#!/usr/bin/env python3
#
# GENERATED FILE, DO NOT EDIT
#
# SPDX-License-Identifier: MIT
#
{#- this is a jinja template, warning above is for the generated file

   Non-obvious variables set by the scanner that are used in this template:
   - target: because eis is actually eis_client in the code, the target points
     to either "ei" or "eis_client". The various attributes on target resolve
     accordingly.
   - request.fqdn/event.fqdn - the full name of a request/event with the
     interface name prefixed, "ei_foo_request_bar" or "ei_foo_event_bar"
   - incoming/outgoing: points to the list of requests or events, depending
     on which one is the outgoing one from the perspective of the file we're
     generating (ei or eis)
#}

from typing import Any, Callable, Generator, Iterator, Optional, Tuple
from enum import IntEnum
from dataclasses import dataclass, field

try:
    from enum import StrEnum
except ImportError:
    from strenum import StrEnum

import binascii
import itertools
import logging
import struct
import structlog
import time

# type aliases
s = str
i = int
u = int
x = int
t = int
o = int
n = int
f = float
h = int  # FIXME this should be a file-like object

logger = structlog.get_logger()

# Message header layout: object id (Q), message length (I), opcode (I).
# "=" selects native byte order with standard sizes and no alignment padding.
_HEADER_FORMAT = "=QII"

# struct codes for every fixed-size argument signature. "s" (length-prefixed,
# padded string) and "h" (fd) are handled separately by format()/unpack().
_FIXED_SIZE_CODES = {
    "u": "I",
    "i": "i",
    "f": "f",
    "x": "q",
    "n": "Q",
    "o": "Q",
    "t": "Q",
}


def _padded_length(length: int) -> int:
    """Round *length* up to the next multiple of 4 (string payload padding)."""
    return ((length + 3) // 4) * 4


def hexlify(data: bytes) -> bytes:
    """Return *data* as hex, grouped into space-separated 4-byte words."""
    return binascii.hexlify(data, sep=" ", bytes_per_sep=4)


class ObjectId(int):
    """An int that reprs as hex, for readable logs."""

    def __repr__(self) -> str:
        return f"{self:#x}"


@dataclass
class MethodCall:
    """Record of one incoming request/event, appended to Interface.calllog."""

    name: str
    args: dict[str, Any]
    # new objects created by this call, keyed by argument name
    objects: dict[str, "Interface"] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)


@dataclass
class MessageHeader:
    """The fixed-size header that starts every message."""

    object_id: ObjectId
    msglen: int
    opcode: int

    @classmethod
    def size(cls) -> int:
        """Size of the header in bytes (16)."""
        return struct.calcsize(_HEADER_FORMAT)

    @classmethod
    def from_data(cls, data: bytes) -> "MessageHeader":
        """Parse the header from the start of *data*.

        Raises struct.error if *data* is shorter than size().
        """
        object_id, msglen, opcode = struct.unpack(_HEADER_FORMAT, data[:cls.size()])
        return cls(ObjectId(object_id), msglen, opcode)

    @property
    def as_tuple(self) -> Tuple[int, int, int]:
        """(object_id, msglen, opcode) in wire order, ready for struct.pack."""
        return self.object_id, self.msglen, self.opcode


@dataclass
class Context:
    """Registry of live protocol objects and entry point for incoming data."""

    # live objects keyed by their integer object id
    objects: dict[int, "Interface"] = field(default_factory=dict)
    # signal name -> {callback id -> callback}
    _callbacks: dict[str, dict[int, Callable]] = field(
        init=False, default_factory=lambda: {"register": {}, "unregister": {}}
    )
    _ids: Iterator[int] = field(init=False, default_factory=itertools.count)

    def register(self, object: "Interface") -> None:
        """Add *object* to the registry and notify "register" callbacks."""
        assert object.object_id not in self.objects
        logger.debug("registering object", interface=object.name, object_id=f"{object.object_id:#x}")
        self.objects[object.object_id] = object
        for callback in self._callbacks["register"].values():
            callback(object)

    def unregister(self, object: "Interface") -> None:
        """Remove *object* from the registry and notify "unregister" callbacks."""
        assert object.object_id in self.objects
        logger.debug("unregistering object", interface=object.name, object=object, object_id=f"{object.object_id:#x}")
        del self.objects[object.object_id]
        for callback in self._callbacks["unregister"].values():
            callback(object)

    def connect(self, signal: str, callback: Callable) -> int:
        """Subscribe *callback* to "register" or "unregister".

        Returns an id that can be passed to disconnect().
        Raises KeyError for any other signal name.
        """
        callbacks = self._callbacks[signal]
        callback_id = next(self._ids)
        callbacks[callback_id] = callback
        return callback_id

    def disconnect(self, signal: str, id: int) -> None:
        """Remove the callback previously returned by connect()."""
        del self._callbacks[signal][id]

    def dispatch(self, data: bytes) -> Optional[int]:
        """Dispatch the first message in *data* to its target object.

        Returns None if *data* does not yet hold a full header. Otherwise it
        returns the number of bytes the message occupies. Messages for an
        unknown object or with an unknown opcode are logged and skipped, and
        the header's msglen is returned for them.
        """
        if len(data) < MessageHeader.size():
            return None

        header = MessageHeader.from_data(data)
        object_id, opcode, msglen = header.object_id, header.opcode, header.msglen
        logger.debug(f"incoming packet ({msglen} bytes)", object_id=f"{object_id:x}", opcode=opcode, bytes=hexlify(data[:msglen]))

        dispatcher = self.objects.get(object_id)
        if dispatcher is None:
            logger.error("Message from unknown object", object_id=f"{object_id:x}")
            return msglen

        funcname = dispatcher.incoming.get(opcode)
        if funcname is None:
            logger.error("Invalid opcode for object", object_id=f"{object_id:x}", opcode=opcode)
            return msglen

        logger.debug("incoming packet: dispatching", func=f"{dispatcher.name}.{funcname}()", object=dispatcher)
        return dispatcher.dispatch(data, context=self)

    @classmethod
    def create(cls) -> "Context":
        """Create a context with the initial handshake object (id 0) registered."""
        ctx = cls()
        ctx.register(EiHandshake.create(object_id=0, version=1))
        return ctx


class InterfaceName(StrEnum):
{% for interface in interfaces %}
    {{ interface.name.upper() }} = "{{interface.protocol_name}}"
{% endfor %}


@dataclass(eq=False)
class Interface:
    """Base class for all generated protocol interfaces.

    eq=False keeps identity-based equality and hashing. The generated
    subclasses below use the same setting so that they stay consistent.
    """

    object_id: int
    version: int
    # incoming message name -> user callback, see connect()
    callbacks: dict[str, Callable] = field(init=False, default_factory=dict, repr=False)
    # every incoming message handled by this object, in order
    calllog: list[MethodCall] = field(init=False, default_factory=list, repr=False)
    name: str = field(default="")
    # opcode -> message name for messages received / sent by this object
    incoming: dict[int, str] = field(default_factory=dict, repr=False)
    outgoing: dict[int, str] = field(default_factory=dict, repr=False)

    def format(self, *args, opcode: int, signature: str) -> bytes:
        """Encode a message from this object with the given opcode.

        *signature* has one character per argument in *args*. A string ("s")
        is sent as a u32 byte length that includes the NUL terminator,
        followed by the UTF-8 bytes padded with NULs to a multiple of 4.
        A None string is sent as length 0 with no payload.

        Raises NotImplementedError for fds ("h") and ValueError for unknown
        signature characters.
        """
        encoding = [_HEADER_FORMAT]
        arguments = []
        for sig, arg in zip(signature, args):
            if sig in _FIXED_SIZE_CODES:
                encoding.append(_FIXED_SIZE_CODES[sig])
            elif sig == "s":
                if arg is None:
                    # zero length marks a NULL string, see unpack()
                    encoding.append("I0s")
                    arguments.append(0)
                    arg = b""
                else:
                    # Length and padding are byte counts of the encoded string,
                    # not character counts, or non-ASCII text gets truncated.
                    arg = arg.encode("utf8")
                    encoding.append(f"I{_padded_length(len(arg) + 1)}s")
                    arguments.append(len(arg) + 1)
            elif sig == "h":
                raise NotImplementedError("fd passing is not yet supported here")
            else:
                raise ValueError(f"Unknown signature character {sig!r}")
            arguments.append(arg)

        fmt = "".join(encoding)
        length = struct.calcsize(fmt)
        header = MessageHeader(self.object_id, length, opcode)
        return struct.pack(fmt, *header.as_tuple, *arguments)

    def unpack(self, data, signature: str, names: list[str]) -> Tuple[int, dict[str, Any]]:
        """Decode the message at the start of *data* according to *signature*.

        Returns (number of bytes consumed, {name: value}), pairing *names*
        with the decoded arguments in order. A zero-length string decodes to
        None. struct.error is logged and re-raised if *data* is too short.
        """
        encoding = [_HEADER_FORMAT]
        for sig in signature:
            if sig in _FIXED_SIZE_CODES:
                encoding.append(_FIXED_SIZE_CODES[sig])
            elif sig == "s":
                # The payload size depends on the length prefix, so read the
                # prefix at the current offset before extending the format.
                offset = struct.calcsize("".join(encoding))
                (slen,) = struct.unpack("=I", data[offset:offset + 4])
                encoding.append(f"I{_padded_length(slen)}s")
            elif sig == "h":
                raise NotImplementedError("fd passing is not yet supported here")
            else:
                raise ValueError(f"Unknown signature character {sig!r}")

        fmt = "".join(encoding)
        msglen = struct.calcsize(fmt)
        try:
            values = struct.unpack(fmt, data[:msglen])
        except struct.error as e:
            logger.error(f"{e}", bytes=hexlify(data), length=len(data), encoding=fmt)
            raise

        # Skip the three header fields (object id, length, opcode).
        remaining = iter(values[3:])
        results = []
        for sig in signature:
            if sig == "s":
                next(remaining)  # the length prefix we inserted into the format
                raw = next(remaining)
                # zero-length string is None; otherwise strip the NUL padding
                results.append(raw.decode("utf8").rstrip("\x00") if raw else None)
            else:
                results.append(next(remaining))

        return msglen, dict(zip(names, results))

    def connect(self, event: str, callback: Callable):
        """Set the callback invoked when the incoming message *event* arrives."""
        self.callbacks[event] = callback

    @classmethod
    def lookup(cls, name: str) -> "type[Interface]":
        """Return the generated Interface subclass for interface *name*.

        Raises KeyError for unknown names.
        """
        return {
{% for interface in interfaces %}
            "{{interface.name}}": {{interface.camel_name}},
{% endfor %}
        }[name]

{% for interface in interfaces %}


@dataclass(eq=False)
class {{interface.camel_name}}(Interface):
    {% for enum in interface.enums %}
    class {{component.capitalize()}}{{enum.camel_name}}(IntEnum):
        {% for entry in enum.entries %}
        {{entry.name.upper()}} = {{entry.value}}
        {% endfor %}

    {% endfor %}
    {% for outgoing in interface.outgoing %}
    def {{outgoing.camel_name}}(self{%- for arg in outgoing.arguments %}, {{arg.name}}: {{arg.signature}}{% endfor -%}) -> bytes:
        data = self.format({%- for arg in outgoing.arguments %}{{arg.name}}, {% endfor -%}opcode={{outgoing.opcode}}, signature="{{outgoing.signature}}")
        logger.debug("composing message", object=self, func="{{interface.name}}.{{outgoing.name}}", args={ {%- for arg in outgoing.arguments %}"{{arg.name}}": {{arg.name}}, {% endfor -%} }, result=hexlify(data))
        return data

    {% endfor %}
    {% for incoming in interface.incoming %}
    def on{{incoming.camel_name}}(self, context: Context{%- for arg in incoming.arguments %}, {{arg.name}}: {{arg.signature}}{% endfor -%}):
        # Create and register any objects announced by new-id ("n") arguments.
        new_objects = {
            {% for arg in incoming.arguments %}
            {% if arg.signature == "n" %}
            {% if arg.interface %}
            "{{arg.name}}": {{arg.interface.camel_name}}.create({{arg.name}}, {{arg.version_arg.name}}),
            {% else %}
            # NOTE(review): assumes the message has an argument named interface_name
            "{{arg.name}}": Interface.lookup(interface_name).create({{arg.name}}, {{arg.version_arg.name}}),
            {% endif %}
            {% endif %}
            {% endfor %}
        }

        for obj in new_objects.values():
            context.register(obj)

        cb = self.callbacks.get("{{incoming.camel_name}}", None)
        if cb is not None:
            if new_objects:
                cb(self{%- for arg in incoming.arguments %}, {{arg.name}}{% endfor -%}, new_objects=new_objects)
            else:
                cb(self{%- for arg in incoming.arguments %}, {{arg.name}}{% endfor -%})

        m = MethodCall(name="{{incoming.camel_name}}", args={
            {% for arg in incoming.arguments %}
            "{{arg.name}}": {{arg.name}},
            {% endfor %}
            }, objects=new_objects)
        self.calllog.append(m)

        {% if incoming.is_destructor %}
        context.unregister(self)
        {% endif %}

    {% endfor %}
    def dispatch(self, data: bytes, context: Context) -> int:
        """Decode one incoming message for this object and call its on* handler.

        Returns the number of bytes consumed; raises NotImplementedError for
        an opcode this interface does not receive.
        """
        opcode = MessageHeader.from_data(data).opcode
        if False:
            pass
        {% for incoming in interface.incoming %}
        elif opcode == {{incoming.opcode}}:
            consumed, args = self.unpack(data, signature="{{incoming.signature}}", names=[
                {% for arg in incoming.arguments %}
                "{{arg.name}}",
                {% endfor %}
                ])
            logger.debug("dispatching", object=self, func="{{incoming.camel_name}}", args=args)
            self.on{{incoming.camel_name}}(context, **args)
        {% endfor %}
        else:
            raise NotImplementedError(f"Invalid opcode {opcode}")

        return consumed

    @classmethod
    def create(cls, object_id: int, version: int):
        """Create an instance with this interface's opcode tables filled in."""
        incoming = {
            {% for incoming in interface.incoming %}
            {{incoming.opcode}}: "{{incoming.name}}",
            {% endfor %}
        }
        outgoing = {
            {% for outgoing in interface.outgoing %}
            {{outgoing.opcode}}: "{{outgoing.name}}",
            {% endfor %}
        }
        return cls(object_id=object_id, version=version, name="{{interface.name}}", incoming=incoming, outgoing=outgoing)
{% endfor %}