diff --git a/proto/scanner.py b/proto/scanner.py index 8211dd8..c4ba011 100755 --- a/proto/scanner.py +++ b/proto/scanner.py @@ -71,6 +71,12 @@ class Argument: signature: str = attr.ib(converter=proto_to_type) summary: str = attr.ib() enum: Optional["Enum"] = attr.ib() + interface: Optional["Interface"] = attr.ib() + + @interface.validator # type: ignore + def _validate_interface(self, attribute, value): + if value is not None and self.signature not in ["n", "o"]: + raise ValueError("Interface may only be set for object types") @property def as_arg(self) -> str: @@ -107,9 +113,14 @@ class Argument: @classmethod def create( - cls, name: str, signature: str, summary: str = "", enum: Optional["Enum"] = None + cls, + name: str, + signature: str, + summary: str = "", + enum: Optional["Enum"] = None, + interface: Optional["Interface"] = None, ) -> "Argument": - return cls(name, signature, summary, enum) + return cls(name, signature, summary, enum, interface=interface) @attr.s @@ -291,6 +302,8 @@ class Protocol(xml.sax.handler.ContentHandler): current_interface: Optional[Interface] = attr.ib(init=False, default=None) current_message: Optional[Union[Message, Enum]] = attr.ib(init=False, default=None) + run: int = attr.ib(init=False, default=0) + def xmlerror(self, msg) -> None: line = self._locator.getLineNumber() # type: ignore col = self._locator.getColumnNumber() # type: ignore @@ -303,6 +316,15 @@ class Protocol(xml.sax.handler.ContentHandler): col = self._locator.getColumnNumber() # type: ignore return line, col + def interface_by_name(self, name) -> Interface: + try: + return [iface for iface in self.interfaces if iface.name == name].pop() + except IndexError: + raise XmlError(*self.location, f"Unable to find interface {name}") + + def startDocument(self): + self.run += 1 + def startElement(self, element: str, attrs: dict): if element == "interface": if self.current_interface is not None: @@ -321,10 +343,21 @@ class Protocol(xml.sax.handler.ContentHandler): if name.startswith("ei"): name = f"{self.component}{name[2:]}" - intf = Interface.create(name=name, version=version, mode=self.component) + # We only create the interface on the first run, in subsequent runs we + # re-use them so we can cross reference correctly + if self.run > 1: + intf = self.interface_by_name(name) + else: + intf = Interface.create(name=name, version=version, mode=self.component) + self.interfaces.append(intf) + self.current_interface = intf - self.interfaces.append(intf) - elif element == "request": + + # first run only parses interfaces + if self.run <= 1: + return + + if element == "request": if self.current_interface is None: raise XmlError( *self.location, @@ -419,6 +452,13 @@ class Protocol(xml.sax.handler.ContentHandler): name = attrs["name"] sig = attrs["type"] summary = attrs.get("summary", "") + interface_name = attrs.get("interface", None) + if interface_name is not None: + if interface_name.startswith("ei"): + interface_name = f"{self.component}{interface_name[2:]}" + interface = self.interface_by_name(interface_name) + else: + interface = None enum = attrs.get("enum", None) if enum is not None and enum not in [ e.name for e in self.current_interface.enums @@ -427,7 +467,13 @@ class Protocol(xml.sax.handler.ContentHandler): *self.location, f"Failed to find enum '{self.current_interface.name}.{enum}'", ) - arg = Argument.create(name=name, signature=sig, summary=summary, enum=enum) + arg = Argument.create( + name=name, + signature=sig, + summary=summary, + enum=enum, + interface=interface, + ) self.current_message.add_argument(arg) elif element == "entry": if self.current_interface is None: @@ -452,7 +498,12 @@ class Protocol(xml.sax.handler.ContentHandler): if name == "interface": assert self.current_interface is not None self.current_interface = None - elif name == "request": + + # first run only parses interfaces + if self.run <= 1: + return + + if name == "request": assert isinstance(self.current_message, Request) self.current_message = None elif name == "event": @@ -474,6 +525,8 @@ class Protocol(xml.sax.handler.ContentHandler): def parse(protofile: Path, component: str) -> Protocol: proto = Protocol.create(component=component) xml.sax.parse(os.fspath(protofile), proto) + # We parse two times, once to fetch all the interfaces, then to parse the details + xml.sax.parse(os.fspath(protofile), proto) return proto