From 552f6dcbd037b5f271a0a5b3a792f23b4d32220f Mon Sep 17 00:00:00 2001 From: Peter Hutterer Date: Thu, 25 May 2023 13:09:27 +1000 Subject: [PATCH] ei-scanner: expose `version_arg` and `version_arg_for` Points to the correspoding "version" argument, or points back to the argument this version argument is for. --- proto/ei-scanner | 37 +++++++++++++++++++++++++++++++++++++ test/eiproto.py.tmpl | 6 ++---- test/test_scanner.py | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 4 deletions(-) diff --git a/proto/ei-scanner b/proto/ei-scanner index 5c209f8..4aefd0a 100755 --- a/proto/ei-scanner +++ b/proto/ei-scanner @@ -82,6 +82,16 @@ class Argument: For an argument referenced by another argument through "interface_name", this field points to the other argument that references this argument. """ + version_arg: Optional["Argument"] = attr.ib(default=None) + """ + For an argument with type "new_id", this field points to the argument that + contains the version for this new object. + """ + version_arg_for: Optional["Argument"] = attr.ib(default=None) + """ + For an argument referenced by another argument of type "new_id", this field + points to the other argument that references this argument. + """ @property def signature(self) -> str: @@ -465,6 +475,7 @@ class ProtocolParser(xml.sax.handler.ContentHandler): current_description: Optional[Description] = attr.ib(init=False, default=None) # A dict of arg name to interface_arg name mappings current_interface_arg_names: Dict[str, str] = attr.ib(init=False, default=attr.Factory(dict)) # type: ignore + current_new_id_arg: Optional[Argument] = attr.ib(init=False, default=None) _run_counter: int = attr.ib(init=False, default=0, repr=False) @@ -706,6 +717,13 @@ class ProtocolParser(xml.sax.handler.ContentHandler): interface=interface, ) self.current_message.add_argument(arg) + if proto_type == "new_id": + if self.current_new_id_arg is not None: + raise XmlError.create( + f"Multiple args of type '{proto_type}' for '{self.current_interface.name}.{self.current_message.name}'", + self.location, + ) + self.current_new_id_arg = arg elif element == "entry": if self.current_interface is None: raise XmlError.create( @@ -762,6 +780,7 @@ class ProtocolParser(xml.sax.handler.ContentHandler): # Populate `interface_arg` and `interface_arg_for`, now we have all arguments if name in ["request", "event"]: assert isinstance(self.current_message, Message) + assert isinstance(self.current_interface, Interface) # obj is the argument of type object that the interface applies to # iname is the argument of type "interface_name" that specifies the interface for obj, iname in self.current_interface_arg_names.items(): @@ -774,6 +793,24 @@ class ProtocolParser(xml.sax.handler.ContentHandler): obj_arg.interface_arg = iname_arg iname_arg.interface_arg_for = obj_arg self.current_interface_arg_names = {} + + if self.current_new_id_arg is not None: + arg = self.current_new_id_arg + version_arg = self.current_message.find_argument("version") + if version_arg is None: + # Sigh, protocol bug: ei_connection.sync one doesn't have a version arg + if ( + f"{self.current_interface.plainname}.{self.current_message.name}" + != "connection.sync" + ): + raise XmlError.create( + f"Unable to find a version argument for {self.current_interface.plainname}.{self.current_message.name}::{arg.name}", + self.location, + ) + else: + arg.version_arg = version_arg + version_arg.version_arg_for = arg + self.current_new_id_arg = None if name == "request": assert isinstance(self.current_message, Request) self.current_message = None diff --git a/test/eiproto.py.tmpl b/test/eiproto.py.tmpl index 77e84ce..9769e07 100644 --- a/test/eiproto.py.tmpl +++ b/test/eiproto.py.tmpl @@ -275,11 +275,9 @@ class {{interface.camel_name}}(Interface): {% for arg in incoming.arguments %} {% if arg.signature == "n" %} {% if arg.interface %} - {# Note: this only works while the version argument is always called version #} - "{{arg.name}}": {{arg.interface.camel_name}}.create({{arg.name}}, version), + "{{arg.name}}": {{arg.interface.camel_name}}.create({{arg.name}}, {{arg.version_arg.name}}), {% else %} - {# Note: this only works while the argument is called interface_name #} - "{{arg.name}}": Interface.lookup(interface_name).create({{arg.name}}, version), + "{{arg.name}}": Interface.lookup(interface_name).create({{arg.name}}, {{arg.version_arg.name}}), {% endif %} {% endif %} {% endfor %} diff --git a/test/test_scanner.py b/test/test_scanner.py index 841490a..de190a3 100644 --- a/test/test_scanner.py +++ b/test/test_scanner.py @@ -58,6 +58,47 @@ class TestScanner: assert version.interface_arg is None assert version.interface_arg_for is None + def iterate_args(self, protocol: Protocol): + for interface in protocol.interfaces: + for request in interface.requests: + for arg in request.arguments: + yield interface, request, arg + for event in interface.events: + for arg in event.arguments: + yield interface, event, arg + + @pytest.mark.parametrize("component", ("ei",)) + def test_versione_arg(self, protocol: Protocol): + for interface, message, arg in self.iterate_args(protocol): + if arg.protocol_type == "new_id": + if f"{interface.plainname}.{message.name}" not in [ + "connection.sync", + ]: + assert ( + arg.version_arg is not None + ), f"{interface.name}.{message.name}::{arg.name}" + assert ( + arg.version_arg.name == "version" + ), f"{interface.name}.{message.name}::{arg.name}" + elif arg.name == "version": + if f"{interface.plainname}.{message.name}" not in [ + "handshake.handshake_version", + "handshake.interface_version", + ]: + assert ( + arg.version_arg_for is not None + ), f"{interface.name}.{message.name}::{arg.name}" + assert ( + arg.version_arg_for.name != "version" + ), f"{interface.name}.{message.name}::{arg.name}" + else: + assert ( + arg.version_arg is None + ), f"{interface.name}.{message.name}::{arg.name}" + assert ( + arg.version_arg_for is None + ), f"{interface.name}.{message.name}::{arg.name}" + @pytest.mark.parametrize("method", ("yamlfile", "jsonfile", "string")) def test_cli_extra_data(self, tmp_path, method): result_path = tmp_path / "result"