libei/test/eiproto.py.tmpl
Peter Hutterer 552f6dcbd0 ei-scanner: expose version_arg and version_arg_for
Points to the corresponding "version" argument, or points back to the
argument this version argument is for.
2023-05-26 16:56:13 +10:00

343 lines
12 KiB
Python

#!/usr/bin/env python3
#
# GENERATED FILE, DO NOT EDIT
#
# SPDX-License-Identifier: MIT
#
{#- this is a jinja template, warning above is for the generated file
Non-obvious variables set by the scanner that are used in this template:
- target: because eis is actually eis_client in the code, the target points to
either "ei" or "eis_client". The various attributes on target resolve
accordingly.
- request.fqdn/event.fqdn - the full name of a request/event with the
interface name prefixed, "ei_foo_request_bar" or "ei_foo_event_bar"
- incoming/outgoing: points to the list of requests or events, depending
on which one is the outgoing one from the perspective of the file we're
generating (ei or eis)
#}
import binascii
import itertools
import logging
import struct
import time
from enum import IntEnum
from typing import Any, Callable, Generator, Optional, Tuple

try:
    from enum import StrEnum
except ImportError:
    from strenum import StrEnum

import attr
import structlog
# type aliases
#
# Each single-letter alias corresponds to one signature character handled by
# Interface.format() and Interface.unpack(). The generated method signatures
# use these letters as parameter annotations, so they must resolve to real
# Python types at import time.
s = str  # sent as a length prefix followed by the UTF-8 bytes
i = int  # packed with struct code "i"
u = int  # packed with struct code "I"
x = int  # packed with struct code "q"
t = int  # packed with struct code "Q"
o = int  # packed with struct code "Q"
n = int  # packed with struct code "Q"
f = float  # packed with struct code "f"
h = int # FIXME this should be a file-like object
# Module-wide structlog logger used by every class below.
logger = structlog.get_logger()
def hexlify(data):
    """Return ``data`` as hex digits, space-separated in groups of 4 bytes.

    Used to make raw protocol buffers readable in log output. The result is
    a ``bytes`` object, exactly as produced by ``binascii.hexlify``.
    """
    separator = " "
    group_size = 4
    return binascii.hexlify(data, separator, group_size)
@attr.s
class MethodCall:
    """Record of one incoming message that was handled by an interface.

    The generated handlers append one of these to ``Interface.calllog`` for
    every message they process.
    """

    # Name of the handled message.
    name: str = attr.ib()
    # Argument name -> value as passed to the handler.
    args: dict[str, Any] = attr.ib()
    # Argument name -> object created for that argument, if any.
    objects: dict[str, "Interface"] = attr.ib(factory=dict)
    # Wall-clock time at which this record was created.
    timestamp: float = attr.ib(factory=time.time)
@attr.s
class MessageHeader:
    """The fixed-size header found at the start of every wire message.

    Layout is ``"=QII"``: a 64-bit object id, then the 32-bit message length
    and the 32-bit opcode.
    """

    object_id: int = attr.ib(repr=lambda value: f"{value:#x}")
    msglen: int = attr.ib()
    opcode: int = attr.ib()

    @classmethod
    def size(cls) -> int:
        """Return the size of the packed header in bytes."""
        # 8 (Q) + 4 (I) + 4 (I)
        return 16

    @classmethod
    def from_data(cls, data: bytes) -> "MessageHeader":
        """Parse a header from the first ``size()`` bytes of ``data``.

        Raises ``struct.error`` if ``data`` is shorter than the header.
        """
        raw_header = data[:cls.size()]
        fields = struct.unpack("=QII", raw_header)
        return cls(*fields)

    @property
    def as_tuple(self) -> Tuple[int, int, int]:
        """The header as ``(object_id, msglen, opcode)``, ready for packing."""
        return (self.object_id, self.msglen, self.opcode)
@attr.s
class Context:
    """Registry of live protocol objects and entry point for dispatching.

    Objects are stored by their numeric object id. Listeners can subscribe to
    the ``"register"`` and ``"unregister"`` signals through ``connect()``.
    """

    # object id -> interface instance (see register()).
    objects: dict[int, "Interface"] = attr.ib(default=attr.Factory(dict))
    # signal name -> {connection id -> callback}
    _callbacks: dict[str, dict[int, Callable]] = attr.ib(init=False)
    # Source of connection ids handed out by connect(); an itertools.count.
    _ids: Generator = attr.ib(init=False, default=attr.Factory(itertools.count))

    @_callbacks.default  # type: ignore
    def _default_callbacks(self) -> dict[str, dict[int, Callable]]:
        """Return an empty callback table for each supported signal."""
        return {"register": {}, "unregister": {}}

    def register(self, object: "Interface") -> None:
        """Add ``object`` to the registry and notify "register" listeners."""
        assert object.object_id not in self.objects
        logger.debug("registering object", interface=object.name, object_id=f"{object.object_id:#x}")
        self.objects[object.object_id] = object
        # Iterate over a snapshot so a callback may disconnect itself.
        for cb in list(self._callbacks["register"].values()):
            cb(object)

    def unregister(self, object: "Interface") -> None:
        """Remove ``object`` from the registry and notify "unregister" listeners."""
        assert object.object_id in self.objects
        logger.debug("unregistering object", interface=object.name, object=object, object_id=f"{object.object_id:#x}")
        del self.objects[object.object_id]
        # Iterate over a snapshot so a callback may disconnect itself.
        for cb in list(self._callbacks["unregister"].values()):
            cb(object)

    def connect(self, signal: str, callback: Callable) -> int:
        """Subscribe ``callback`` to ``signal`` and return the connection id.

        Raises ``KeyError`` if ``signal`` is not a known signal name.
        """
        handlers = self._callbacks[signal]
        handle = next(self._ids)
        handlers[handle] = callback
        return handle

    def disconnect(self, signal: str, id: int) -> None:
        """Remove the callback that ``connect()`` registered under ``id``."""
        del self._callbacks[signal][id]

    def dispatch(self, data: bytes) -> Optional[int]:
        """Dispatch the message at the start of ``data`` to its target object.

        Returns the number of bytes consumed. For an unknown object id or an
        unknown opcode the message is skipped and the length from its header
        is returned. Returns ``None`` if ``data`` is too short to contain a
        message header.
        """
        if len(data) < MessageHeader.size():
            return None

        header = MessageHeader.from_data(data)
        object_id, opcode, msglen = header.object_id, header.opcode, header.msglen
        logger.debug(f"incoming packet ({msglen} bytes)", object_id=f"{object_id:x}", opcode=opcode, bytes=hexlify(data[:msglen]))

        try:
            dispatcher = self.objects[object_id]
        except KeyError:
            logger.error("Message from unknown object", object_id=f"{object_id:x}")
            return msglen

        try:
            method = dispatcher.incoming[opcode]
        except KeyError:
            logger.error("Invalid opcode for object", object_id=f"{object_id:x}", opcode=opcode)
            return msglen
        logger.debug("incoming packet: dispatching", func=f"{dispatcher.name}.{method}()", object=dispatcher)

        return dispatcher.dispatch(data, context=self)

    @classmethod
    def create(cls) -> "Context":
        """Return a new context with the handshake object registered as id 0."""
        ctx = cls()
        ctx.register(EiHandshake.create(object_id=0, version=1))
        return ctx
# One member per protocol interface known to the scanner: the member name is
# the upper-cased interface name, the value is the interface's protocol name.
class InterfaceName(StrEnum):
    {% for interface in interfaces %}
    {{ interface.name.upper() }} = "{{interface.protocol_name}}"
    {% endfor %}
@attr.s(eq=False)
class Interface:
    """Base class of the generated protocol interfaces.

    Provides the wire (de)serialisation shared by all interfaces and the
    per-object callback table. The generated subclasses fill in ``name``,
    ``incoming`` and ``outgoing`` through their ``create()`` classmethod.
    """

    # struct codes of the fixed-size signature characters.
    _SCALAR_CODES = {
        "u": "I",
        "i": "i",
        "f": "f",
        "n": "Q",
        "o": "Q",
        "t": "Q",
        "x": "q",
    }

    object_id: int = attr.ib(repr=lambda id: f"{id:#x}")
    version: int = attr.ib()
    callbacks: dict[str, Callable] = attr.ib(init=False, default=attr.Factory(dict), repr=False)
    calllog: list[MethodCall] = attr.ib(init=False, default=attr.Factory(list), repr=False)
    name: str = attr.ib(default="<overridden by subclass>")
    # opcode -> message name. These default to an empty dict so that looking
    # up an unknown opcode raises KeyError, which Context.dispatch() catches.
    incoming: dict[int, str] = attr.ib(default=attr.Factory(dict), repr=False)
    outgoing: dict[int, str] = attr.ib(default=attr.Factory(dict), repr=False)

    def format(self, *args, opcode: int, signature: str) -> bytes:
        """Serialise one outgoing message, header included.

        ``signature`` holds one character per entry of ``args``. Strings are
        written as a 32-bit length (UTF-8 byte count plus the terminating
        NUL) followed by the bytes, zero-padded to a multiple of four so that
        ``unpack()`` reads back the same number of bytes. ``None`` is written
        as a zero-length string, which ``unpack()`` decodes to ``None``.

        Raises ``NotImplementedError`` for file descriptors ("h") and
        ``ValueError`` for an unknown signature character.
        """
        encoding = ["=QII"]  # the header
        arguments = []
        for sig, arg in zip(signature, args):
            if sig == "s":
                if arg is None:
                    encoding.append("I")
                    arguments.append(0)
                    continue
                raw = arg.encode("utf8")
                slen = len(raw) + 1  # the length on the wire includes the NUL
                padded = ((slen + 3) // 4) * 4
                # struct zero-fills "<n>s" fields, which supplies both the
                # terminator and the padding.
                encoding.append(f"I{padded}s")
                arguments.extend((slen, raw))
            elif sig == "h":
                raise NotImplementedError("fd passing is not yet supported here")
            else:
                code = self._SCALAR_CODES.get(sig)
                if code is None:
                    raise ValueError(f"unknown signature character {sig!r}")
                encoding.append(code)
                arguments.append(arg)

        fmt = "".join(encoding)
        length = struct.calcsize(fmt)
        header = MessageHeader(self.object_id, length, opcode)
        return struct.pack(fmt, *header.as_tuple, *arguments)

    def unpack(self, data, signature: str, names: list[str]) -> Tuple[int, dict[str, Any]]:
        """Deserialise the message at the start of ``data``.

        Returns ``(msglen, args)`` where ``msglen`` is the number of bytes the
        message occupies (header included) and ``args`` maps each entry of
        ``names`` to the decoded value. Zero-length strings decode to
        ``None``.

        Raises ``struct.error`` if ``data`` is too short for ``signature``,
        ``NotImplementedError`` for file descriptors ("h") and ``ValueError``
        for an unknown signature character.
        """
        encoding = ["=QII"]  # the header
        for sig in signature:
            if sig == "s":
                # The size of a string field is only known once its length
                # prefix has been read from the data itself.
                offset = struct.calcsize("".join(encoding))
                slen, = struct.unpack("=I", data[offset:offset + 4])
                padded = ((slen + 3) // 4) * 4
                encoding.append(f"I{padded}s")
            elif sig == "h":
                raise NotImplementedError("fd passing is not yet supported here")
            else:
                code = self._SCALAR_CODES.get(sig)
                if code is None:
                    raise ValueError(f"unknown signature character {sig!r}")
                encoding.append(code)

        fmt = "".join(encoding)
        msglen = struct.calcsize(fmt)
        try:
            values = struct.unpack(fmt, data[:msglen])
        except struct.error as e:
            logger.error(f"{e}", bytes=hexlify(data), length=len(data), encoding=fmt)
            raise

        fields = iter(values[3:])  # drop id, length, opcode
        results = []
        for sig in signature:
            if sig == "s":
                # The length prefix was only needed to build the format.
                next(fields)
                raw = next(fields)
                if raw:
                    # strip the terminator and the padding
                    results.append(raw.decode("utf8").rstrip("\x00"))
                else:
                    results.append(None)  # zero-length string is None
            else:
                results.append(next(fields))

        return (msglen, dict(zip(names, results)))

    def connect(self, event: str, callback: Callable):
        """Set ``callback`` as the handler for ``event``, replacing any previous one."""
        self.callbacks[event] = callback

    @classmethod
    def lookup(cls, name: str) -> "type[Interface]":
        """Return the generated interface class for the interface ``name``.

        Raises ``KeyError`` if no interface of that name exists.
        """
        return {
            {% for interface in interfaces %}
            "{{interface.name}}": {{interface.camel_name}},
            {% endfor %}
        }[name]
{% for interface in interfaces %}
# Generated class for one protocol interface. Outgoing messages become
# methods that return the serialised bytes, incoming messages become on...()
# handlers that are invoked from dispatch().
@attr.s
class {{interface.camel_name}}(Interface):
    {% for enum in interface.enums %}
    class {{component.capitalize()}}{{enum.camel_name}}(IntEnum):
        {% for entry in enum.entries %}
        {{entry.name.upper()}} = {{entry.value}}
        {% endfor %}

    {% endfor %}
    {% for outgoing in interface.outgoing %}
    def {{outgoing.camel_name}}(self{%- for arg in outgoing.arguments %}, {{arg.name}}: {{arg.signature}}{% endfor -%}) -> bytes:
        # Only builds the message; the caller is responsible for sending it.
        data = self.format({%- for arg in outgoing.arguments %}{{arg.name}}, {% endfor -%}opcode={{outgoing.opcode}}, signature="{{outgoing.signature}}")
        logger.debug("composing message", object=self, func="{{interface.name}}.{{outgoing.name}}", args={ {%- for arg in outgoing.arguments %}"{{arg.name}}": {{arg.name}}, {% endfor -%} }, result=hexlify(data))
        return data

    {% endfor %}
    {% for incoming in interface.incoming %}
    def on{{incoming.camel_name}}(self, context: Context{%- for arg in incoming.arguments %}, {{arg.name}}: {{arg.signature}}{% endfor -%}):
        # Create an object for every new-id ("n") argument, using the id and
        # the accompanying version argument, and register it with the context
        # before any callback runs.
        new_objects = {
            {% for arg in incoming.arguments %}
            {% if arg.signature == "n" %}
            {% if arg.interface %}
            "{{arg.name}}": {{arg.interface.camel_name}}.create({{arg.name}}, {{arg.version_arg.name}}),
            {% else %}
            "{{arg.name}}": Interface.lookup(interface_name).create({{arg.name}}, {{arg.version_arg.name}}),
            {% endif %}
            {% endif %}
            {% endfor %}
        }
        for o in new_objects.values():
            context.register(o)

        # The callback receives new_objects only if this message created any.
        cb = self.callbacks.get("{{incoming.camel_name}}", None)
        if cb is not None:
            if new_objects:
                cb(self{%- for arg in incoming.arguments %}, {{arg.name}}{% endfor -%}, new_objects=new_objects)
            else:
                cb(self{%- for arg in incoming.arguments %}, {{arg.name}}{% endfor -%})

        # Record the call whether or not a callback was connected.
        m = MethodCall(name="{{incoming.camel_name}}", args={
            {% for arg in incoming.arguments %}
            "{{arg.name}}": {{arg.name}},
            {% endfor %}
        }, objects=new_objects)
        self.calllog.append(m)
        {% if incoming.is_destructor %}
        # Destructor message: this object is gone after handling it.
        context.unregister(self)
        {% endif %}

    {% endfor %}
    def dispatch(self, data: bytes, context: Context) -> int:
        # Decodes the message at the start of data, calls the matching on...()
        # handler and returns the number of bytes the message occupied.
        header = MessageHeader.from_data(data)
        opcode = header.opcode
        # Placeholder branch so that every generated opcode check can be an elif.
        if False:
            pass
        {% for incoming in interface.incoming %}
        elif opcode == {{incoming.opcode}}:
            consumed, args = self.unpack(data, signature="{{incoming.signature}}", names=[
                {% for arg in incoming.arguments %}
                "{{arg.name}}",
                {% endfor %}
            ])
            logger.debug("dispatching", object=self, func="{{incoming.camel_name}}", args=args)
            self.on{{incoming.camel_name}}(context, **args)
        {% endfor %}
        else:
            raise NotImplementedError(f"Invalid opcode {opcode}")
        return consumed

    @classmethod
    def create(cls, object_id: int, version: int):
        # Builds an instance with the opcode -> message name tables filled in.
        incoming = {
            {% for incoming in interface.incoming %}
            {{incoming.opcode}}: "{{incoming.name}}",
            {% endfor %}
        }
        outgoing = {
            {% for outgoing in interface.outgoing %}
            {{outgoing.opcode}}: "{{outgoing.name}}",
            {% endfor %}
        }
        return cls(object_id=object_id, version=version, name="{{interface.name}}", incoming=incoming, outgoing=outgoing)

{% endfor %}