mirror of
https://github.com/strongswan/strongswan.git
synced 2025-10-04 00:00:14 -04:00
commit
3dd5dc5011
@ -4,9 +4,9 @@ EXTRA_DIST = LICENSE README.rst MANIFEST.in \
|
||||
tox.sh \
|
||||
test/__init__.py \
|
||||
test/test_protocol.py \
|
||||
test/test_session.py \
|
||||
vici/__init__.py \
|
||||
vici/command_wrappers.py \
|
||||
vici/compat.py \
|
||||
vici/exception.py \
|
||||
vici/protocol.py \
|
||||
vici/session.py
|
||||
|
@ -20,7 +20,6 @@ setup(
|
||||
"Intended Audience :: System Administrators",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 2.7",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
|
@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
import socket
|
||||
|
||||
from vici.protocol import Packet, Message, FiniteStream
|
||||
from vici.protocol import Packet, Message, FiniteStream, Transport
|
||||
from vici.exception import DeserializationException
|
||||
|
||||
|
||||
@ -142,3 +143,26 @@ class TestMessage(object):
|
||||
assert deserialized_message["key1"] == b"value1"
|
||||
assert deserialized_section["sub-section"]["key2"] == b"value2"
|
||||
assert deserialized_section["list1"] == [b"item1", b"item2"]
|
||||
|
||||
|
||||
class TestTransport(object):
|
||||
|
||||
def interconnect(self):
|
||||
c, s = socket.socketpair(socket.AF_UNIX)
|
||||
return Transport(c), Transport(s)
|
||||
|
||||
def test_sendrecv(self):
|
||||
c, s = self.interconnect()
|
||||
c.send(b"foo")
|
||||
assert s.receive() == b"foo"
|
||||
s.send(b"foobarbaz")
|
||||
s.send(b"")
|
||||
assert c.receive() == b"foobarbaz"
|
||||
assert c.receive() == b""
|
||||
|
||||
def test_timeout(self):
|
||||
c, s = self.interconnect()
|
||||
c.send(b"foo")
|
||||
assert s.receive(timeout=1) == b"foo"
|
||||
with pytest.raises(socket.timeout):
|
||||
s.receive(timeout=0.1)
|
||||
|
100
src/libcharon/plugins/vici/python/test/test_session.py
Normal file
100
src/libcharon/plugins/vici/python/test/test_session.py
Normal file
@ -0,0 +1,100 @@
|
||||
import pytest
|
||||
import socket
|
||||
import struct
|
||||
from collections import OrderedDict
|
||||
|
||||
from vici.session import Session
|
||||
from vici.protocol import Transport, Packet, Message, FiniteStream
|
||||
from vici.exception import DeserializationException
|
||||
|
||||
|
||||
class MockedServer(object):
|
||||
|
||||
def __init__(self, sock):
|
||||
self.transport = Transport(sock)
|
||||
|
||||
def send(self, kind, name=None, message=None):
|
||||
if name is None:
|
||||
payload = struct.pack("!B", kind)
|
||||
else:
|
||||
name = name.encode("UTF-8")
|
||||
payload = struct.pack("!BB", kind, len(name)) + name
|
||||
if message is not None:
|
||||
payload += Message.serialize(message)
|
||||
self.transport.send(payload)
|
||||
|
||||
def recv(self):
|
||||
stream = FiniteStream(self.transport.receive())
|
||||
kind, length = struct.unpack("!BB", stream.read(2))
|
||||
name = stream.read(length)
|
||||
data = stream.read()
|
||||
if len(data):
|
||||
return kind, name, Message.deserialize(data)
|
||||
return kind, name
|
||||
|
||||
|
||||
class TestSession(object):
|
||||
|
||||
events = [
|
||||
OrderedDict([('event', b'1')]),
|
||||
OrderedDict([('event', b'2')]),
|
||||
OrderedDict([('event', b'3')]),
|
||||
]
|
||||
|
||||
def interconnect(self):
|
||||
c, s = socket.socketpair(socket.AF_UNIX)
|
||||
return Session(c), MockedServer(s)
|
||||
|
||||
def test_request(self):
|
||||
c, s = self.interconnect()
|
||||
|
||||
s.send(Packet.CMD_RESPONSE)
|
||||
assert c.request("doit") == {}
|
||||
assert s.recv() == (Packet.CMD_REQUEST, b"doit")
|
||||
|
||||
s.send(Packet.CMD_RESPONSE, message={"hey": b"hou"})
|
||||
assert c.request("heyhou") == {"hey": b"hou"}
|
||||
assert s.recv() == (Packet.CMD_REQUEST, b"heyhou")
|
||||
|
||||
def test_streamed(self):
|
||||
c, s = self.interconnect()
|
||||
|
||||
s.send(Packet.EVENT_CONFIRM)
|
||||
for e in self.events:
|
||||
s.send(Packet.EVENT, name="stream", message=e)
|
||||
s.send(Packet.CMD_RESPONSE)
|
||||
s.send(Packet.EVENT_CONFIRM)
|
||||
|
||||
assert list(c.streamed_request("streamit", "stream")) == self.events
|
||||
assert s.recv() == (Packet.EVENT_REGISTER, b"stream")
|
||||
assert s.recv() == (Packet.CMD_REQUEST, b"streamit")
|
||||
assert s.recv() == (Packet.EVENT_UNREGISTER, b"stream")
|
||||
|
||||
def test_timeout(self):
|
||||
c, s = self.interconnect()
|
||||
|
||||
s.send(Packet.EVENT_CONFIRM)
|
||||
s.send(Packet.EVENT_CONFIRM)
|
||||
for e in self.events:
|
||||
s.send(Packet.EVENT, name="event", message=e)
|
||||
|
||||
r = []
|
||||
i = 0
|
||||
for name, msg in c.listen(["xyz", "event"], timeout=0.1):
|
||||
if name is None:
|
||||
i += 1
|
||||
if i > 2:
|
||||
s.send(Packet.EVENT, name="event", message={"late": b'1'})
|
||||
s.send(Packet.EVENT_CONFIRM)
|
||||
s.send(Packet.EVENT_CONFIRM)
|
||||
break
|
||||
else:
|
||||
assert name == b"event"
|
||||
r.append(msg)
|
||||
|
||||
assert s.recv() == (Packet.EVENT_REGISTER, b"xyz")
|
||||
assert s.recv() == (Packet.EVENT_REGISTER, b"event")
|
||||
assert s.recv() == (Packet.EVENT_UNREGISTER, b"xyz")
|
||||
assert s.recv() == (Packet.EVENT_UNREGISTER, b"event")
|
||||
|
||||
assert r == self.events
|
@ -1,5 +1,5 @@
|
||||
[tox]
|
||||
envlist = py27, py36, py37, py38, py39
|
||||
envlist = py36, py37, py38, py39
|
||||
|
||||
[testenv]
|
||||
deps =
|
||||
@ -7,10 +7,6 @@ deps =
|
||||
pytest-pycodestyle
|
||||
commands = pytest --pycodestyle
|
||||
|
||||
[testenv:py{27}]
|
||||
deps = pytest
|
||||
commands = pytest
|
||||
|
||||
[pycodestyle]
|
||||
max-line-length = 80
|
||||
show-source = True
|
||||
|
@ -1,14 +0,0 @@
|
||||
# Help functions for compatibility between python version 2 and 3
|
||||
|
||||
|
||||
# From http://legacy.python.org/dev/peps/pep-0469
|
||||
try:
|
||||
dict.iteritems
|
||||
except AttributeError:
|
||||
# python 3
|
||||
def iteritems(d):
|
||||
return iter(d.items())
|
||||
else:
|
||||
# python 2
|
||||
def iteritems(d):
|
||||
return d.iteritems()
|
@ -5,7 +5,6 @@ import struct
|
||||
from collections import namedtuple
|
||||
from collections import OrderedDict
|
||||
|
||||
from .compat import iteritems
|
||||
from .exception import DeserializationException
|
||||
|
||||
|
||||
@ -19,8 +18,8 @@ class Transport(object):
|
||||
def send(self, packet):
|
||||
self.socket.sendall(struct.pack("!I", len(packet)) + packet)
|
||||
|
||||
def receive(self):
|
||||
raw_length = self._recvall(self.HEADER_LENGTH)
|
||||
def receive(self, timeout=None):
|
||||
raw_length = self._recvall(self.HEADER_LENGTH, timeout)
|
||||
length, = struct.unpack("!I", raw_length)
|
||||
payload = self._recvall(length)
|
||||
return payload
|
||||
@ -29,11 +28,14 @@ class Transport(object):
|
||||
self.socket.shutdown(socket.SHUT_RDWR)
|
||||
self.socket.close()
|
||||
|
||||
def _recvall(self, count):
|
||||
def _recvall(self, count, timeout=None):
|
||||
"""Ensure to read count bytes from the socket"""
|
||||
data = b""
|
||||
if count > 0:
|
||||
self.socket.settimeout(timeout)
|
||||
while len(data) < count:
|
||||
buf = self.socket.recv(count - len(data))
|
||||
self.socket.settimeout(None)
|
||||
if not buf:
|
||||
raise socket.error('Connection closed')
|
||||
data += buf
|
||||
@ -121,7 +123,7 @@ class Message(object):
|
||||
|
||||
def serialize_dict(d):
|
||||
segment = bytes()
|
||||
for key, value in iteritems(d):
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
segment += (
|
||||
encode_named_type(cls.SECTION_START, key)
|
||||
|
@ -40,7 +40,9 @@ class Session(CommandWrappers, object):
|
||||
raise EventUnknownException(
|
||||
"Unknown event type '{event}'".format(event=event_type)
|
||||
)
|
||||
elif response.response_type != Packet.EVENT_CONFIRM:
|
||||
while response.response_type == Packet.EVENT:
|
||||
response = Packet.parse(self.transport.receive())
|
||||
if response.response_type != Packet.EVENT_CONFIRM:
|
||||
raise SessionException(
|
||||
"Unexpected response type {type}, "
|
||||
"expected '{confirm}' (EVENT_CONFIRM)".format(
|
||||
@ -139,11 +141,19 @@ class Session(CommandWrappers, object):
|
||||
)
|
||||
)
|
||||
|
||||
def listen(self, event_types):
|
||||
def listen(self, event_types, timeout=None):
|
||||
"""Register and listen for the given events.
|
||||
|
||||
If a timeout is given, the generator produces a (None, None) tuple
|
||||
if no event has been received for that time. This allows the caller
|
||||
to either abort by breaking from the generator, or perform periodic
|
||||
tasks while staying registered within listen(), and then continue
|
||||
waiting for more events.
|
||||
|
||||
:param event_types: event types to register
|
||||
:type event_types: list
|
||||
:param timeout: timeout to wait for events, in fractions of a second
|
||||
:type timeout: float
|
||||
:return: generator for streamed event responses as (event_type, dict)
|
||||
:rtype: generator
|
||||
"""
|
||||
@ -152,7 +162,11 @@ class Session(CommandWrappers, object):
|
||||
|
||||
try:
|
||||
while True:
|
||||
response = Packet.parse(self.transport.receive())
|
||||
try:
|
||||
response = Packet.parse(self.transport.receive(timeout))
|
||||
except socket.timeout:
|
||||
yield None, None
|
||||
continue
|
||||
if response.response_type == Packet.EVENT:
|
||||
try:
|
||||
msg = Message.deserialize(response.payload)
|
||||
|
Loading…
x
Reference in New Issue
Block a user