mirror of
https://gitlab.freedesktop.org/libinput/libei.git
synced 2025-12-31 15:50:11 +01:00
Continue to find nested ei_foo.bar.baz but avoid including a trailing full stop, if it exists. Part-of: <https://gitlab.freedesktop.org/libinput/libei/-/merge_requests/336>
1044 lines
34 KiB
Python
Executable file
1044 lines
34 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
EI protocol parser
|
|
|
|
This parser is intended to be generically useful for language bindings
|
|
other than libei/libeis. If it isn't, please file a bug.
|
|
|
|
When used as ei-scanner, it renders a Jinja2 template with the
|
|
scanned protocol. Otherwise, use the `parse()` function to
|
|
parse the protocol and return its structure as a set of Python
|
|
classes.
|
|
|
|
Opcodes for events and requests are assigned in order as they
|
|
appear in the XML file.
|
|
"""
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
from pathlib import Path
|
|
from textwrap import dedent
|
|
from dataclasses import dataclass, field
|
|
|
|
import argparse
|
|
import jinja2
|
|
import jinja2.environment
|
|
import os
|
|
import sys
|
|
import xml.sax
|
|
import xml.sax.handler
|
|
|
|
"""
|
|
Mapping of allowed protocol types to the single-character signature strings
|
|
used in the various code pieces.
|
|
"""
|
|
# Keys are the values allowed in an <arg type="..."> attribute; values are the
# one-letter codes returned by Argument.signature. Insertion order is kept.
PROTOCOL_TYPES = dict(
    uint32="u",
    int32="i",
    uint64="t",
    int64="x",
    float="f",
    fd="h",
    new_id="n",
    object="o",
    string="s",
)
|
|
|
|
|
|
def snake2camel(s: str) -> str:
    """
    Convert a snake_case name to CamelCase (strictly speaking, PascalCase),
    e.g. ``foo_bar`` becomes ``FooBar``.
    """
    # Turning underscores into spaces lets str.title() capitalise every word;
    # splitting on those spaces and re-joining then removes them again.
    titled = s.replace("_", " ").title()
    return "".join(titled.split(" "))
|
|
|
|
|
|
@dataclass
class Description:
    """
    The content of a <description> element: its summary attribute and body text.
    """

    # value of the element's "summary" attribute
    summary: str = field(default="")
    # the element's character data (the parser appends to it piecewise)
    text: str = field(default="")
|
|
|
|
|
|
@dataclass
class Argument:
    """
    Argument to a request or an event.
    """

    name: str
    protocol_type: str
    summary: str
    enum: Optional["Enum"]
    interface: Optional["Interface"]

    # For an argument with "interface_arg", this field points to the argument
    # that contains the interface name.
    interface_arg: Optional["Argument"] = None
    # For an argument referenced by another argument through "interface_arg",
    # this field points to the other argument that references this argument.
    interface_arg_for: Optional["Argument"] = None
    # For an argument with type "new_id", this field points to the argument
    # that contains the version for this new object.
    version_arg: Optional["Argument"] = None
    # For an argument referenced by another argument of type "new_id", this
    # field points to the other argument that references this argument.
    version_arg_for: Optional["Argument"] = None
    # For an argument of type string, specify if the argument may be NULL.
    allow_null: bool = False

    def __post_init__(self):
        """
        Validate the argument after construction.

        Raises:
            ValueError: if ``protocol_type`` is not a known protocol type, or
                if ``interface`` is set on an argument that is neither a
                ``new_id`` nor an ``object``.
        """
        # The hook must be spelled exactly ``__post_init__``: dataclasses look
        # it up by that name, and a ``__post_init`` (name-mangled) is never run.
        if self.protocol_type is None or self.protocol_type not in PROTOCOL_TYPES:
            raise ValueError(f"Failed to parse protocol_type {self.protocol_type}")
        if self.interface is not None and self.signature not in ["n", "o"]:
            raise ValueError("Interface may only be set for object types")

    @property
    def signature(self) -> str:
        """
        The single-character signature for this argument
        """
        return PROTOCOL_TYPES[self.protocol_type]

    @property
    def as_c_arg(self) -> str:
        """The C parameter declaration, e.g. ``uint32_t serial``."""
        return f"{self.c_type} {self.name}"

    @property
    def c_type(self) -> str:
        """The C type used for this argument's protocol type."""
        return {
            "uint32": "uint32_t",
            "int32": "int32_t",
            "uint64": "uint64_t",
            "int64": "int64_t",
            "string": "const char *",
            "fd": "int",
            "float": "float",
            "object": "object_id_t",
            "new_id": "new_id_t",
        }[self.protocol_type]

    @classmethod
    def create(
        cls,
        name: str,
        protocol_type: str,
        summary: str = "",
        enum: Optional["Enum"] = None,
        interface: Optional["Interface"] = None,
        allow_null: bool = False,
    ) -> "Argument":
        """
        Build an Argument; the cross-reference fields start out as None.

        Raises:
            ValueError: see ``__post_init__``.
        """
        return cls(
            name=name,
            protocol_type=protocol_type,
            summary=summary,
            enum=enum,
            interface=interface,
            allow_null=allow_null,
        )
|
|
|
|
|
|
@dataclass
class Message:
    """
    Parent class for a wire message (Request or Event).
    """

    name: str
    since: int
    opcode: int
    interface: "Interface"
    description: Optional["Description"] = None
    is_destructor: bool = False
    # None, "sender" or "receiver"; only the construction-time value is
    # validated (see __post_init__), later assignments are not checked.
    context_type: Optional[str] = None

    arguments: List["Argument"] = field(init=False, default_factory=list)

    def __post_init__(self):
        """
        Validate the message after construction.

        Raises:
            ValueError: if ``context_type`` is not None, "sender" or "receiver".
        """
        # Must be spelled exactly ``__post_init__`` for dataclasses to call it;
        # the former ``__post_init`` was name-mangled and never ran.
        if self.context_type not in [None, "sender", "receiver"]:
            raise ValueError(f"Invalid context type {self.context_type}")

    def add_argument(self, arg: "Argument") -> None:
        """
        Append ``arg`` to this message's arguments.

        Raises:
            ValueError: if an argument with the same name already exists.
        """
        if any(a.name == arg.name for a in self.arguments):
            raise ValueError(f"Duplicate argument name '{arg.name}'")
        self.arguments.append(arg)

    @property
    def num_arguments(self) -> int:
        return len(self.arguments)

    @property
    def signature(self) -> str:
        """The concatenated single-character signatures of all arguments."""
        return "".join(a.signature for a in self.arguments)

    @property
    def camel_name(self) -> str:
        return snake2camel(self.name)

    def find_argument(self, name: str) -> Optional["Argument"]:
        """Return the argument called ``name``, or None if there is none."""
        return next((a for a in self.arguments if a.name == name), None)
|
|
|
|
|
|
@dataclass
class Request(Message):
    """
    A <request> of an interface.

    Requests are the messages the "ei" side sends and the "eis" side receives
    (see Interface.outgoing and Interface.incoming).
    """

    @classmethod
    def create(
        cls,
        name: str,
        opcode: int,
        interface: "Interface",
        since: int = 1,
        is_destructor: bool = False,
    ) -> "Request":
        """Build a Request; its argument list starts out empty."""
        kwargs = {
            "name": name,
            "opcode": opcode,
            "since": since,
            "interface": interface,
            "is_destructor": is_destructor,
        }
        return cls(**kwargs)

    @property
    def fqdn(self) -> str:
        """
        The full name of this Request as <interface name>_request_<request name>
        """
        return "{}_request_{}".format(self.interface.name, self.name)
|
|
|
|
|
|
@dataclass
class Event(Message):
    """
    An <event> of an interface.

    Events are the messages the "eis" side sends and the "ei" side receives
    (see Interface.outgoing and Interface.incoming).
    """

    @classmethod
    def create(
        cls,
        name: str,
        opcode: int,
        interface: "Interface",
        since: int = 1,
        is_destructor: bool = False,
    ) -> "Event":
        """Build an Event; its argument list starts out empty."""
        kwargs = {
            "name": name,
            "opcode": opcode,
            "since": since,
            "interface": interface,
            "is_destructor": is_destructor,
        }
        return cls(**kwargs)

    @property
    def fqdn(self) -> str:
        """
        The full name of this Event as <interface name>_event_<event name>
        """
        return "{}_event_{}".format(self.interface.name, self.name)
|
|
|
|
|
|
@dataclass
class Entry:
    """
    A single <entry> of an enum: a named integer value.
    """

    name: str
    value: int
    enum: "Enum"  # the enum this entry belongs to
    summary: str
    since: int  # value of the entry's "since" attribute

    @classmethod
    def create(
        cls, name: str, value: int, enum: "Enum", summary: str = "", since: int = 1
    ) -> "Entry":
        """Build an Entry; ``summary`` defaults to empty and ``since`` to 1."""
        # positional order follows the field declaration order above
        return cls(name, value, enum, summary, since)

    @property
    def fqdn(self) -> str:
        """
        The full name of this Entry as <interface name>_<enum name>_<entry name>
        """
        return "{}_{}".format(self.enum.fqdn, self.name)
|
|
|
|
|
|
@dataclass
class Enum:
    """
    An <enum> of an interface together with its entries.
    """

    name: str
    since: int
    interface: "Interface"
    is_bitfield: bool = False
    description: Optional["Description"] = None

    entries: List["Entry"] = field(init=False, default_factory=list)

    @classmethod
    def create(
        cls,
        name: str,
        interface: "Interface",
        since: int = 1,
        is_bitfield: bool = False,
    ) -> "Enum":
        return cls(name=name, since=since, interface=interface, is_bitfield=is_bitfield)

    def add_entry(self, entry: "Entry") -> None:
        """
        Append ``entry`` to this enum.

        Raises:
            ValueError: if an entry with the same name or value already
                exists, or, for bitfields, if the new value is negative or has
                more than one bit set (a value of 0 is accepted).
        """
        # Validate the entry being added. The checks must look at ``entry``,
        # not at the entries already stored, otherwise the most recently added
        # entry would never be validated.
        if self.is_bitfield:
            if entry.value < 0:
                raise ValueError("Bitmasks must not be less than zero")
            # bin().count() instead of int.bit_count(), which needs Python 3.10
            if bin(entry.value).count("1") > 1:
                raise ValueError("Bitmasks must have exactly one bit set")

        for e in self.entries:
            if e.name == entry.name:
                raise ValueError(f"Duplicate enum name '{entry.name}'")
            if e.value == entry.value:
                raise ValueError(f"Duplicate enum value '{entry.value}'")

        self.entries.append(entry)

    @property
    def fqdn(self):
        """
        The full name of this Enum as <interface name>_<enum name>
        """
        return f"{self.interface.name}_{self.name}"

    @property
    def camel_name(self) -> str:
        return snake2camel(self.name)
|
|
|
|
|
|
@dataclass
class Interface:
    """
    A protocol <interface> together with its requests, events and enums.

    The same class serves every component: ``mode`` ("ei", "eis" or "brei")
    selects the name prefix (see ``name``) and, for "ei"/"eis", which
    messages are incoming or outgoing.
    """

    protocol_name: str  # name as in the XML, e.g. ei_pointer
    version: int
    requests: List["Request"] = field(init=False, default_factory=list)
    events: List["Event"] = field(init=False, default_factory=list)
    enums: List["Enum"] = field(init=False, default_factory=list)

    mode: str
    description: Optional["Description"] = None

    def __post_init__(self):
        """
        Raises:
            ValueError: if ``mode`` is not one of "ei", "eis" or "brei".
        """
        # Must be spelled exactly ``__post_init__`` for dataclasses to call it;
        # the former ``__post_init`` was name-mangled and never ran.
        if self.mode not in ["ei", "eis", "brei"]:
            raise ValueError(f"Invalid mode {self.mode}")

    @property
    def name(self) -> str:
        """
        Returns the mode-adjusted name of the interface, i.e. this may return
        "ei_pointer", "eis_pointer", "brei_pointer", etc. depending on the
        mode.
        """
        return Interface.mangle_name(self.protocol_name, self.mode)

    @property
    def plainname(self) -> str:
        """
        Returns the plain name of the interface, i.e. this returns
        "pointer", "handshake", etc. without the "ei_" prefix.
        """
        if self.protocol_name.startswith("ei_"):
            return f"{self.protocol_name[3:]}"
        return self.protocol_name

    @staticmethod
    def mangle_name(name: str, component: str) -> str:
        """
        Returns the mangled interface name with the component as prefix (e.g. eis_device).
        The XML only uses `ei_` as prefix, so let's replace that accordingly.
        Names not starting with "ei" are returned unchanged.
        """
        if name.startswith("ei"):
            return f"{component}{name[2:]}"
        return name

    def add_request(self, request: "Request") -> None:
        """Append ``request``; raises ValueError on a duplicate name."""
        if request.name in [r.name for r in self.requests]:
            raise ValueError(f"Duplicate request name '{request.name}'")
        self.requests.append(request)

    def add_event(self, event: "Event") -> None:
        """Append ``event``; raises ValueError on a duplicate name."""
        if event.name in [r.name for r in self.events]:
            raise ValueError(f"Duplicate event name '{event.name}'")
        self.events.append(event)

    def add_enum(self, enum: "Enum") -> None:
        """Append ``enum``; raises ValueError on a duplicate name."""
        if enum.name in [r.name for r in self.enums]:
            raise ValueError(f"Duplicate enum name '{enum.name}'")
        self.enums.append(enum)

    def find_enum(self, name: str) -> Optional["Enum"]:
        """Return the enum called ``name``, or None if there is none."""
        for e in self.enums:
            if e.name == name:
                return e
        return None

    @property
    def outgoing(self) -> List["Message"]:
        """
        Returns the list of messages outgoing from this implementation.

        We use the same class for both ei and eis. To make the
        template simpler, the class maps requests/events to
        incoming/outgoing as correct relative to the implementation.

        Raises:
            NotImplementedError: for any mode other than "ei" or "eis".
        """
        if self.mode == "ei":
            return self.requests  # type: ignore
        elif self.mode == "eis":
            return self.events  # type: ignore
        else:
            raise NotImplementedError(
                f"Interface.outgoing is not supported for mode {self.mode}"
            )

    @property
    def incoming(self) -> List["Message"]:
        """
        Returns the list of messages incoming to this implementation.

        We use the same class for both ei and eis. To make the
        template simpler, the class maps requests/events to
        incoming/outgoing as correct relative to the implementation.

        Raises:
            NotImplementedError: for any mode other than "ei" or "eis".
        """
        if self.mode == "ei":
            return self.events  # type: ignore
        elif self.mode == "eis":
            return self.requests  # type: ignore
        else:
            raise NotImplementedError(
                f"Interface.incoming is not supported for mode {self.mode}"
            )

    @property
    def c_type(self) -> str:
        return f"struct {self.name} *"

    @property
    def as_c_arg(self) -> str:
        return f"{self.c_type} {self.name}"

    @property
    def camel_name(self) -> str:
        return snake2camel(self.name)

    @classmethod
    def create(cls, protocol_name: str, version: int, mode: str = "ei") -> "Interface":
        assert mode in ["ei", "eis", "brei"]
        return cls(protocol_name=protocol_name, version=version, mode=mode)
|
|
|
|
|
|
@dataclass
class XmlError(Exception):
    """
    A problem in the protocol XML, tagged with the line and column it was found at.
    """

    line: int
    column: int
    message: str

    def __str__(self) -> str:
        return "line {}:{}: {}".format(self.line, self.column, self.message)

    @classmethod
    def create(cls, message: str, location: Tuple[int, int] = (0, 0)) -> "XmlError":
        """Build an XmlError from ``message`` and a ``(line, column)`` pair."""
        line, column = location[0], location[1]
        return cls(line=line, column=column, message=message)
|
|
|
|
|
|
@dataclass
class Copyright:
    """
    The content of the protocol's <copyright> element.
    """

    # character data of the element, collected while it is open
    text: str = field(default="")
    # flipped to True once </copyright> is seen; not an __init__ parameter
    is_complete: bool = field(default=False, init=False)
|
|
|
|
|
|
@dataclass
class Protocol:
    """
    The parsed protocol: its copyright text (if any) and all of its interfaces.
    """

    copyright: Optional[str] = field(default=None)
    interfaces: List[Interface] = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class ProtocolParser(xml.sax.handler.ContentHandler):
    """
    SAX content handler that builds the protocol structure from the XML.

    parse() feeds the same handler the document three times, counted by
    startDocument(): the first run only creates the interfaces, the second
    run only creates the enums, and the third run handles requests, events,
    arguments, enum entries, descriptions and the copyright. Later runs re-use
    the objects created by earlier ones, so references to interfaces and
    enums resolve no matter where in the file they are declared.

    Errors in the XML are reported as XmlError.
    """

    component: str
    interfaces: List[Interface] = field(default_factory=list)
    copyright: Optional[Copyright] = field(init=False, default=None)

    current_interface: Optional[Interface] = field(init=False, default=None)
    current_message: Optional[Union[Message, Enum]] = field(init=False, default=None)
    current_description: Optional[Description] = field(init=False, default=None)
    # A dict of arg name to interface_arg name mappings
    current_interface_arg_names: Dict[str, str] = field(
        init=False, default_factory=dict
    )
    current_new_id_arg: Optional[Argument] = field(init=False, default=None)

    _run_counter: int = field(init=False, default=0, repr=False)

    @property
    def location(self) -> Tuple[int, int]:
        """The current (line, column) of the SAX parser in the document."""
        line = self._locator.getLineNumber()  # type: ignore
        col = self._locator.getColumnNumber()  # type: ignore
        return line, col

    def interface_by_name(self, protocol_name: str) -> Interface:
        """
        Look up an interface by its protocol name (i.e. always "ei_foo", regardless of
        what we're generating).

        Raises:
            XmlError: if no interface with that name has been parsed.
        """
        matches = [
            iface for iface in self.interfaces if iface.protocol_name == protocol_name
        ]
        if not matches:
            raise self._error(f"Unable to find interface {protocol_name}")
        return matches[-1]

    # --- small helpers --------------------------------------------------

    def _error(self, message: str) -> XmlError:
        """Build an XmlError for ``message`` at the current parser location."""
        return XmlError.create(message, self.location)

    def _attr(self, element: str, attrs: Any, name: str) -> str:
        """Return required attribute ``name`` of ``element``; XmlError if missing."""
        try:
            return attrs[name]
        except KeyError:
            # {name!r} keeps the historic "Missing attribute 'name'" wording
            raise self._error(
                f"Missing attribute {name!r} in element '{element}'"
            ) from None

    def _int_attr(
        self, element: str, attrs: Any, name: str, default: Optional[int] = None
    ) -> int:
        """
        Return attribute ``name`` of ``element`` converted to an int.

        When ``default`` is given a missing attribute yields ``default``,
        otherwise the attribute is required. Raises XmlError if a required
        attribute is missing or the value is not a (decimal) integer.
        """
        if default is not None and attrs.get(name) is None:
            return default
        value = self._attr(element, attrs, name)
        try:
            return int(value)
        except ValueError:
            raise self._error(
                f"Invalid integer value {value!r} for attribute {name!r} in element '{element}'"
            ) from None

    def _require_interface(self, element: str) -> Interface:
        """Return the enclosing interface; XmlError if ``element`` is outside one."""
        if self.current_interface is None:
            raise self._error(f"Invalid element '{element}' outside an <interface>")
        return self.current_interface

    def _forbid_nesting(self, element: str) -> None:
        """XmlError if ``element`` appears inside a request, event or enum."""
        if self.current_message is not None:
            raise self._error(
                f"Invalid element '{element}' inside '{self.current_message.name}'"
            )

    def _check_since(self, interface: Interface, name: str, since: int) -> None:
        """XmlError if ``since`` is newer than the interface's version."""
        if since > interface.version:
            raise self._error(
                f"Invalid 'since' {since} for '{interface.name}.{name}'"
            )

    # --- SAX callbacks --------------------------------------------------

    def startDocument(self):
        # one increment per pass of parse(); gates what each pass handles
        self._run_counter += 1

    def startElement(self, element: str, attrs: dict):
        if element == "interface":
            self._start_interface(element, attrs)

        # first run only parses interfaces
        if self._run_counter <= 1:
            return

        if element == "enum":
            self._start_enum(element, attrs)

        # second run only parses enums
        if self._run_counter <= 2:
            return

        if element in ("request", "event"):
            self._start_message(element, attrs)
        elif element == "arg":
            self._start_arg(element, attrs)
        elif element == "entry":
            self._start_entry(element, attrs)
        elif element == "description":
            summary = attrs.get("summary", "")
            self.current_description = Description(summary=summary)
        elif element == "copyright":
            if self.copyright is not None:
                raise self._error("Multiple <copyright> tags in file")
            self.copyright = Copyright()

    def _start_interface(self, element: str, attrs: Any) -> None:
        """Handle <interface>: create it on the first run, re-select it later."""
        if self.current_interface is not None:
            raise self._error(
                f"Invalid element '{element}' inside interface '{self.current_interface.name}'"
            )

        protocol_name = self._attr(element, attrs, "name")
        version = self._int_attr(element, attrs, "version")

        # We only create the interface on the first run, in subsequent runs we
        # re-use them so we can cross reference correctly
        if self._run_counter > 1:
            intf = self.interface_by_name(protocol_name)
        else:
            if any(i.protocol_name == protocol_name for i in self.interfaces):
                raise self._error(f"Duplicate interface name '{protocol_name}'")
            intf = Interface.create(
                protocol_name=protocol_name,
                version=version,
                mode=self.component,
            )
            self.interfaces.append(intf)

        self.current_interface = intf

    def _start_enum(self, element: str, attrs: Any) -> None:
        """Handle <enum>: create it on the second run, re-select it later."""
        interface = self._require_interface(element)
        self._forbid_nesting(element)

        name = self._attr(element, attrs, "name")
        since = self._int_attr(element, attrs, "since")
        self._check_since(interface, name, since)

        bitfield = attrs.get("bitfield", "false")
        if bitfield not in ("true", "false"):
            raise self._error(
                f"Invalid value {bitfield!r} for boolean bitfield attribute in '{element}'"
            )
        is_bitfield = bitfield == "true"

        # We only create the enum on the second run, in subsequent runs
        # we re-use them so we can cross-reference correctly
        if self._run_counter > 2:
            enum = interface.find_enum(name)
            if enum is None:
                raise self._error(f"Invalid enum {name}. This is a parser bug")
        else:
            enum = Enum.create(
                name=name,
                since=since,
                interface=interface,
                is_bitfield=is_bitfield,
            )
            try:
                interface.add_enum(enum)
            except ValueError as e:
                raise self._error(str(e)) from e
        self.current_message = enum

    def _start_message(self, element: str, attrs: Any) -> None:
        """Handle <request>/<event>: create the message and add it to the interface."""
        interface = self._require_interface(element)
        # requests and events alike must not be nested in another message/enum
        self._forbid_nesting(element)

        name = self._attr(element, attrs, "name")
        since = self._int_attr(element, attrs, "since")
        self._check_since(interface, name, since)

        is_destructor = attrs.get("type", "") == "destructor"
        message: Message
        # Opcodes are assigned in order of appearance within the interface
        if element == "request":
            message = Request.create(
                name=name,
                since=since,
                opcode=len(interface.requests),
                interface=interface,
                is_destructor=is_destructor,
            )
            add = interface.add_request
        else:
            message = Event.create(
                name=name,
                since=since,
                opcode=len(interface.events),
                interface=interface,
                is_destructor=is_destructor,
            )
            add = interface.add_event
        message.context_type = attrs.get("context-type")
        try:
            add(message)  # type: ignore[arg-type]
        except ValueError as e:
            raise self._error(str(e)) from e
        self.current_message = message

    def _start_arg(self, element: str, attrs: Any) -> None:
        """Handle <arg>: create the Argument and add it to the current message."""
        interface = self._require_interface(element)
        message = self.current_message
        if not isinstance(message, Message):
            raise self._error(
                f"Invalid element '{element}' must be inside <request> or <event>"
            )

        name = self._attr(element, attrs, "name")
        proto_type = self._attr(element, attrs, "type")
        if proto_type not in PROTOCOL_TYPES:
            raise self._error(
                f"Invalid type '{proto_type}' for '{interface.name}.{message.name}::{name}'"
            )

        summary = attrs.get("summary", "")
        interface_name = attrs.get("interface", None)
        arg_interface = (
            self.interface_by_name(interface_name)
            if interface_name is not None
            else None
        )

        # interface_arg is set to the name of some other arg that specifies the actual
        # interface name for this argument; it is resolved in endElement once
        # all arguments of the message are known
        interface_arg_name = attrs.get("interface_arg", None)
        if interface_arg_name is not None:
            self.current_interface_arg_names[name] = interface_arg_name

        enum = self._find_arg_enum(interface, attrs.get("enum", None))
        allow_null = attrs.get("allow-null", "false") == "true"

        try:
            arg = Argument.create(
                name=name,
                protocol_type=proto_type,
                summary=summary,
                enum=enum,
                interface=arg_interface,
                allow_null=allow_null,
            )
            message.add_argument(arg)
        except ValueError as e:
            # e.g. a duplicate argument name
            raise self._error(str(e)) from e

        if proto_type == "new_id":
            if self.current_new_id_arg is not None:
                raise self._error(
                    f"Multiple args of type '{proto_type}' for '{interface.name}.{message.name}'"
                )
            self.current_new_id_arg = arg

    def _find_arg_enum(
        self, interface: Interface, enum_name: Optional[str]
    ) -> Optional[Enum]:
        """
        Resolve an <arg>'s "enum" attribute: either "name" (an enum of
        ``interface``) or "ei_foo.name" (an enum of interface ei_foo).
        Returns None if the attribute is absent; XmlError if not found.
        """
        if enum_name is None:
            return None
        if "." in enum_name:
            # partition() instead of split(): extra dots end up in the enum
            # name and produce a "Failed to find enum" error, not a crash
            iname, _, enum_name = enum_name.partition(".")
            intf = self.interface_by_name(iname)
        else:
            intf = interface

        enum = intf.find_enum(enum_name)
        if enum is None:
            raise self._error(f"Failed to find enum '{intf.name}.{enum_name}'")
        return enum

    def _start_entry(self, element: str, attrs: Any) -> None:
        """Handle <entry>: add a new Entry to the current enum."""
        self._require_interface(element)
        enum = self.current_message
        if not isinstance(enum, Enum):
            raise self._error(f"Invalid element '{element}' must be inside <enum>")

        entry = Entry.create(
            name=self._attr(element, attrs, "name"),
            value=self._int_attr(element, attrs, "value"),
            enum=enum,
            summary=attrs.get("summary", ""),
            since=self._int_attr(element, attrs, "since", default=1),
        )
        try:
            enum.add_entry(entry)
        except ValueError as e:
            raise self._error(str(e)) from e

    def endElement(self, name):
        if name == "interface":
            assert self.current_interface is not None
            self.current_interface = None

        # first run only parses interfaces
        if self._run_counter <= 1:
            return

        if name == "enum":
            assert isinstance(self.current_message, Enum)
            self.current_message = None

        # second run only parses interfaces and enums
        if self._run_counter <= 2:
            return

        if name in ["request", "event"]:
            self._finish_message()

        if name == "request":
            assert isinstance(self.current_message, Request)
            self.current_message = None
        elif name == "event":
            assert isinstance(self.current_message, Event)
            self.current_message = None
        elif name == "description":
            assert self.current_description is not None
            self.current_description.text = dedent(self.current_description.text)
            if self.current_message is None:
                assert self.current_interface is not None
                self.current_interface.description = self.current_description
            else:
                self.current_message.description = self.current_description
            self.current_description = None
        elif name == "copyright":
            assert self.copyright is not None
            self.copyright.text = dedent(self.copyright.text)
            self.copyright.is_complete = True

    def _finish_message(self) -> None:
        """
        Cross-link the arguments of the request/event being closed, now that
        all of them are known: populate `interface_arg`/`interface_arg_for`
        and `version_arg`/`version_arg_for`.
        """
        message = self.current_message
        interface = self.current_interface
        assert isinstance(message, Message)
        assert isinstance(interface, Interface)

        # obj is the argument of type object that the interface applies to
        # iname is the argument of type "interface_name" that specifies the interface
        for obj, iname in self.current_interface_arg_names.items():
            obj_arg = message.find_argument(obj)
            iname_arg = message.find_argument(iname)

            # obj_arg was added in _start_arg right after recording the mapping
            assert obj_arg is not None
            if iname_arg is None:
                raise self._error(
                    f"Unable to find interface_arg '{iname}' for {interface.plainname}.{message.name}::{obj}"
                )

            obj_arg.interface_arg = iname_arg
            iname_arg.interface_arg_for = obj_arg
        self.current_interface_arg_names = {}

        if self.current_new_id_arg is not None:
            arg = self.current_new_id_arg
            version_arg = message.find_argument("version")
            if version_arg is None:
                # Sigh, protocol bug: ei_connection.sync one doesn't have a version arg
                if f"{interface.plainname}.{message.name}" != "connection.sync":
                    raise self._error(
                        f"Unable to find a version argument for {interface.plainname}.{message.name}::{arg.name}"
                    )
            else:
                arg.version_arg = version_arg
                version_arg.version_arg_for = arg
            self.current_new_id_arg = None

    def characters(self, content):
        # Character data belongs to an open <description> or to the
        # <copyright> element until it has been closed.
        if self.current_description is not None:
            self.current_description.text += content
        elif self.copyright is not None and not self.copyright.is_complete:
            self.copyright.text += content

    @classmethod
    def create(cls, component: str) -> "ProtocolParser":
        h = cls(component=component)
        return h
|
|
|
|
|
|
def parse(protofile: Path, component: str) -> Protocol:
    """
    Parse the protocol XML file ``protofile`` for ``component`` and return
    the resulting Protocol.

    The same handler reads the file three times: once to fetch all the
    interfaces, once for the enums, then once more for the details. This way
    cross-references resolve regardless of declaration order.

    Raises:
        XmlError: for invalid protocol content (raised by the handler).
    """
    handler = ProtocolParser.create(component=component)
    source = os.fspath(protofile)
    for _pass in range(3):
        xml.sax.parse(source, handler)

    copyright_text = None if handler.copyright is None else handler.copyright.text
    return Protocol(copyright=copyright_text, interfaces=handler.interfaces)
|
|
|
|
|
|
def generate_source(
    proto: Protocol, template: str, component: str, extra_data: Optional[dict]
) -> jinja2.environment.TemplateStream:
    """
    Render ``template`` against the parsed protocol ``proto``.

    Args:
        proto: the parsed protocol, as returned by parse()
        template: path to a Jinja2 template file, or "-" to read the template
            from stdin
        component: one of "ei", "eis" or "brei"
        extra_data: arbitrary data exposed to the template as ``extra``

    Returns:
        The Jinja2 TemplateStream of the rendered template.

    The template sees the variables ``component``, ``interfaces`` and
    ``extra`` and the filters ``c_type``, ``as_c_arg``, ``camel`` and
    ``ei_escape_names``.
    """
    import re

    assert component in ["ei", "eis", "brei"]

    data: dict[str, Any] = {
        "component": component,
        "interfaces": proto.interfaces,
        "extra": extra_data,
    }

    loader: jinja2.BaseLoader
    if template == "-":
        loader = jinja2.FunctionLoader(lambda _: sys.stdin.read())
        filename = "<stdin>"
    else:
        path = Path(template)
        assert path.exists(), f"Failed to find template {path}"
        filename = path.name
        loader = jinja2.FileSystemLoader(os.fspath(path.parent))

    env = jinja2.Environment(
        loader=loader,
        trim_blocks=True,
        lstrip_blocks=True,
    )

    # jinja filter to convert foo into "struct foo *"
    def filter_c_type(name):
        return f"struct {name} *"

    # jinja filter to convert foo into "struct foo *foo"
    def filter_as_c_arg(name):
        return f"struct {name} *{name}"

    # Matches "<component>_foo" plus any nested ".bar.baz" member references.
    # Each "." must be followed by a word character, so a sentence's trailing
    # full stop stays outside the quotes. The leading \b prevents matching in
    # the middle of a longer identifier (e.g. the "ei_foo" inside "brei_foo").
    # Compiled once here rather than on every filter call.
    name_pattern = re.compile(rf"\b({re.escape(component)}[_-]\w*)((\.\w+)*)")

    # jinja filter to wrap any <component>_foo.bar name in markdown backticks
    # (or the given quotes); empty/None input is returned unchanged
    def filter_ei_escape_names(text, quotes="`"):
        if not text:
            return text
        return name_pattern.sub(rf"{quotes}\1\2{quotes}", text)

    env.filters["c_type"] = filter_c_type
    env.filters["as_c_arg"] = filter_as_c_arg
    env.filters["camel"] = snake2camel
    env.filters["ei_escape_names"] = filter_ei_escape_names
    jtemplate = env.get_template(filename)
    return jtemplate.stream(data)
|
|
|
|
|
|
def scanner(argv: list[str]) -> None:
    """
    Entry point of the ei-scanner command line tool.

    Parses ``argv`` (the arguments without the program name), parses the
    protocol XML and writes the rendered template to ``--output`` (stdout by
    default). Exits with status 1 on XML/protocol errors or an unknown
    extra-data file format, and with argparse's usage error if the protocol
    file does not exist.
    """
    parser = argparse.ArgumentParser(
        description=dedent(
            """
            ei-scanner is a tool to parse the EI protocol description XML and
            pass the data to a Jinja2 template. That template can then be
            used to generate protocol bindings for the desired language.

            typical usages:
                ei-scanner --component=ei protocol.xml my-template.tpl
                ei-scanner --component=eis --output=bindings.rs protocol.xml bindings.rs.tpl

            Elements in the XML file are provided as variables with attributes
            generally matching the XML file. For example, each interface has requests,
            events and enums, and each of those has a name.

            ei-scanner additionally provides the following values to the Jinja2 templates:
            - interface.incoming and interface.outgoing: maps to the requests/events of
              the interface, depending on the component.
            - argument.signature: a single-character signature type mapping
              from the protocol XML type:
                uint32 -> "u"
                int32 -> "i"
                uint64 -> "t"
                int64 -> "x"
                float -> "f"
                fd -> "h"
                new_id -> "n"
                object -> "o"
                string -> "s"
            - extra: the data given via --jinja-extra-data or --jinja-extra-data-file

            ei-scanner adds the following Jinja2 filters for convenience:
                {{foo|c_type}} ... resolves to "struct foo *"
                {{foo|as_c_arg}} ... resolves to "struct foo *foo"
                {{foo_bar|camel}} ... resolves to "FooBar"
                {{text|ei_escape_names}} ... wraps <component>_foo(.bar) names in backticks

            """
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--component", type=str, choices=["ei", "eis", "brei"], default="ei"
    )
    parser.add_argument(
        "--output", type=str, default="-", help="Output file to write to"
    )
    parser.add_argument("protocol", type=Path, help="The protocol XML file")
    parser.add_argument(
        "--jinja-extra-data",
        type=str,
        help="Extra data (in JSON format) to pass through to the Jinja template as 'extra'",
        default=None,
    )
    parser.add_argument(
        "--jinja-extra-data-file",
        type=Path,
        help="Path to file with extra data to pass through to the Jinja template as 'extra'",
        default=None,
    )
    parser.add_argument(
        "template", type=str, help="The Jinja2 compatible template file"
    )

    ns = parser.parse_args(argv)
    # parser.error() prints usage and exits; unlike assert it survives -O
    if not ns.protocol.exists():
        parser.error(f"Protocol file {ns.protocol} does not exist")

    try:
        proto = parse(
            protofile=ns.protocol,
            component=ns.component,
        )
    except xml.sax.SAXParseException as e:
        print(f"Parser error: {e}", file=sys.stderr)
        raise SystemExit(1)
    except XmlError as e:
        print(f"Protocol XML error: {e}", file=sys.stderr)
        raise SystemExit(1)

    extra_data = None
    if ns.jinja_extra_data is not None:
        import json

        extra_data = json.loads(ns.jinja_extra_data)
    elif ns.jinja_extra_data_file is not None:
        data_filename = ns.jinja_extra_data_file.name
        if data_filename.endswith((".yml", ".yaml")):
            import yaml

            with open(ns.jinja_extra_data_file, encoding="utf-8") as fd:
                extra_data = yaml.safe_load(fd)
        elif data_filename.endswith(".json"):
            import json

            with open(ns.jinja_extra_data_file, encoding="utf-8") as fd:
                extra_data = json.load(fd)
        else:
            print("Unknown file format for jinja data", file=sys.stderr)
            raise SystemExit(1)

    stream = generate_source(
        proto=proto, template=ns.template, component=ns.component, extra_data=extra_data
    )

    if ns.output == "-":
        stream.dump(sys.stdout)
    else:
        # The context manager flushes and closes the output file
        # deterministically instead of leaking the handle.
        with open(ns.output, "w", encoding="utf-8") as outfile:
            stream.dump(outfile)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand over the command-line arguments without the
    # program name.
    scanner(argv=sys.argv[1:])
|