Merge branch 'vici-python-timeout'

Closes strongswan/strongswan#1416
This commit is contained in:
Tobias Brunner 2022-12-12 14:38:46 +01:00
commit 3dd5dc5011
8 changed files with 151 additions and 30 deletions

View File

@ -4,9 +4,9 @@ EXTRA_DIST = LICENSE README.rst MANIFEST.in \
tox.sh \
test/__init__.py \
test/test_protocol.py \
test/test_session.py \
vici/__init__.py \
vici/command_wrappers.py \
vici/compat.py \
vici/exception.py \
vici/protocol.py \
vici/session.py

View File

@ -20,7 +20,6 @@ setup(
"Intended Audience :: System Administrators",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",

View File

@ -1,6 +1,7 @@
import pytest
import socket
from vici.protocol import Packet, Message, FiniteStream
from vici.protocol import Packet, Message, FiniteStream, Transport
from vici.exception import DeserializationException
@ -142,3 +143,26 @@ class TestMessage(object):
assert deserialized_message["key1"] == b"value1"
assert deserialized_section["sub-section"]["key2"] == b"value2"
assert deserialized_section["list1"] == [b"item1", b"item2"]
class TestTransport(object):
def interconnect(self):
c, s = socket.socketpair(socket.AF_UNIX)
return Transport(c), Transport(s)
def test_sendrecv(self):
c, s = self.interconnect()
c.send(b"foo")
assert s.receive() == b"foo"
s.send(b"foobarbaz")
s.send(b"")
assert c.receive() == b"foobarbaz"
assert c.receive() == b""
def test_timeout(self):
c, s = self.interconnect()
c.send(b"foo")
assert s.receive(timeout=1) == b"foo"
with pytest.raises(socket.timeout):
s.receive(timeout=0.1)

View File

@ -0,0 +1,100 @@
import pytest
import socket
import struct
from collections import OrderedDict
from vici.session import Session
from vici.protocol import Transport, Packet, Message, FiniteStream
from vici.exception import DeserializationException
class MockedServer(object):
def __init__(self, sock):
self.transport = Transport(sock)
def send(self, kind, name=None, message=None):
if name is None:
payload = struct.pack("!B", kind)
else:
name = name.encode("UTF-8")
payload = struct.pack("!BB", kind, len(name)) + name
if message is not None:
payload += Message.serialize(message)
self.transport.send(payload)
def recv(self):
stream = FiniteStream(self.transport.receive())
kind, length = struct.unpack("!BB", stream.read(2))
name = stream.read(length)
data = stream.read()
if len(data):
return kind, name, Message.deserialize(data)
return kind, name
class TestSession(object):
events = [
OrderedDict([('event', b'1')]),
OrderedDict([('event', b'2')]),
OrderedDict([('event', b'3')]),
]
def interconnect(self):
c, s = socket.socketpair(socket.AF_UNIX)
return Session(c), MockedServer(s)
def test_request(self):
c, s = self.interconnect()
s.send(Packet.CMD_RESPONSE)
assert c.request("doit") == {}
assert s.recv() == (Packet.CMD_REQUEST, b"doit")
s.send(Packet.CMD_RESPONSE, message={"hey": b"hou"})
assert c.request("heyhou") == {"hey": b"hou"}
assert s.recv() == (Packet.CMD_REQUEST, b"heyhou")
def test_streamed(self):
c, s = self.interconnect()
s.send(Packet.EVENT_CONFIRM)
for e in self.events:
s.send(Packet.EVENT, name="stream", message=e)
s.send(Packet.CMD_RESPONSE)
s.send(Packet.EVENT_CONFIRM)
assert list(c.streamed_request("streamit", "stream")) == self.events
assert s.recv() == (Packet.EVENT_REGISTER, b"stream")
assert s.recv() == (Packet.CMD_REQUEST, b"streamit")
assert s.recv() == (Packet.EVENT_UNREGISTER, b"stream")
def test_timeout(self):
c, s = self.interconnect()
s.send(Packet.EVENT_CONFIRM)
s.send(Packet.EVENT_CONFIRM)
for e in self.events:
s.send(Packet.EVENT, name="event", message=e)
r = []
i = 0
for name, msg in c.listen(["xyz", "event"], timeout=0.1):
if name is None:
i += 1
if i > 2:
s.send(Packet.EVENT, name="event", message={"late": b'1'})
s.send(Packet.EVENT_CONFIRM)
s.send(Packet.EVENT_CONFIRM)
break
else:
assert name == b"event"
r.append(msg)
assert s.recv() == (Packet.EVENT_REGISTER, b"xyz")
assert s.recv() == (Packet.EVENT_REGISTER, b"event")
assert s.recv() == (Packet.EVENT_UNREGISTER, b"xyz")
assert s.recv() == (Packet.EVENT_UNREGISTER, b"event")
assert r == self.events

View File

@ -1,5 +1,5 @@
[tox]
envlist = py27, py36, py37, py38, py39
envlist = py36, py37, py38, py39
[testenv]
deps =
@ -7,10 +7,6 @@ deps =
pytest-pycodestyle
commands = pytest --pycodestyle
[testenv:py{27}]
deps = pytest
commands = pytest
[pycodestyle]
max-line-length = 80
show-source = True

View File

@ -1,14 +0,0 @@
# Help functions for compatibility between python version 2 and 3
# From http://legacy.python.org/dev/peps/pep-0469
try:
dict.iteritems
except AttributeError:
# python 3
def iteritems(d):
return iter(d.items())
else:
# python 2
def iteritems(d):
return d.iteritems()

View File

@ -5,7 +5,6 @@ import struct
from collections import namedtuple
from collections import OrderedDict
from .compat import iteritems
from .exception import DeserializationException
@ -19,8 +18,8 @@ class Transport(object):
def send(self, packet):
self.socket.sendall(struct.pack("!I", len(packet)) + packet)
def receive(self):
raw_length = self._recvall(self.HEADER_LENGTH)
def receive(self, timeout=None):
raw_length = self._recvall(self.HEADER_LENGTH, timeout)
length, = struct.unpack("!I", raw_length)
payload = self._recvall(length)
return payload
@ -29,11 +28,14 @@ class Transport(object):
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
def _recvall(self, count):
def _recvall(self, count, timeout=None):
"""Ensure to read count bytes from the socket"""
data = b""
if count > 0:
self.socket.settimeout(timeout)
while len(data) < count:
buf = self.socket.recv(count - len(data))
self.socket.settimeout(None)
if not buf:
raise socket.error('Connection closed')
data += buf
@ -121,7 +123,7 @@ class Message(object):
def serialize_dict(d):
segment = bytes()
for key, value in iteritems(d):
for key, value in d.items():
if isinstance(value, dict):
segment += (
encode_named_type(cls.SECTION_START, key)

View File

@ -40,7 +40,9 @@ class Session(CommandWrappers, object):
raise EventUnknownException(
"Unknown event type '{event}'".format(event=event_type)
)
elif response.response_type != Packet.EVENT_CONFIRM:
while response.response_type == Packet.EVENT:
response = Packet.parse(self.transport.receive())
if response.response_type != Packet.EVENT_CONFIRM:
raise SessionException(
"Unexpected response type {type}, "
"expected '{confirm}' (EVENT_CONFIRM)".format(
@ -139,11 +141,19 @@ class Session(CommandWrappers, object):
)
)
def listen(self, event_types):
def listen(self, event_types, timeout=None):
"""Register and listen for the given events.
If a timeout is given, the generator produces a (None, None) tuple
if no event has been received for that time. This allows the caller
to either abort by breaking from the generator, or perform periodic
tasks while staying registered within listen(), and then continue
waiting for more events.
:param event_types: event types to register
:type event_types: list
:param timeout: timeout to wait for events, in fractions of a second
:type timeout: float
:return: generator for streamed event responses as (event_type, dict)
:rtype: generator
"""
@ -152,7 +162,11 @@ class Session(CommandWrappers, object):
try:
while True:
response = Packet.parse(self.transport.receive())
try:
response = Packet.parse(self.transport.receive(timeout))
except socket.timeout:
yield None, None
continue
if response.response_type == Packet.EVENT:
try:
msg = Message.deserialize(response.payload)