ei-scanner: expose version_arg and version_arg_for

Points to the correspoding "version" argument, or points back to the
argument this version argument is for.
This commit is contained in:
Peter Hutterer 2023-05-25 13:09:27 +10:00
parent 1d8cd84c56
commit 552f6dcbd0
3 changed files with 80 additions and 4 deletions

View file

@ -82,6 +82,16 @@ class Argument:
For an argument referenced by another argument through "interface_name", this field
points to the other argument that references this argument.
"""
version_arg: Optional["Argument"] = attr.ib(default=None)
"""
For an argument with type "new_id", this field points to the argument that
contains the version for this new object.
"""
version_arg_for: Optional["Argument"] = attr.ib(default=None)
"""
For an argument referenced by another argument of type "new_id", this field
points to the other argument that references this argument.
"""
@property
def signature(self) -> str:
@ -465,6 +475,7 @@ class ProtocolParser(xml.sax.handler.ContentHandler):
current_description: Optional[Description] = attr.ib(init=False, default=None)
# A dict of arg name to interface_arg name mappings
current_interface_arg_names: Dict[str, str] = attr.ib(init=False, default=attr.Factory(dict)) # type: ignore
current_new_id_arg: Optional[Argument] = attr.ib(init=False, default=None)
_run_counter: int = attr.ib(init=False, default=0, repr=False)
@ -706,6 +717,13 @@ class ProtocolParser(xml.sax.handler.ContentHandler):
interface=interface,
)
self.current_message.add_argument(arg)
if proto_type == "new_id":
if self.current_new_id_arg is not None:
raise XmlError.create(
f"Multiple args of type '{proto_type}' for '{self.current_interface.name}.{self.current_message.name}'",
self.location,
)
self.current_new_id_arg = arg
elif element == "entry":
if self.current_interface is None:
raise XmlError.create(
@ -762,6 +780,7 @@ class ProtocolParser(xml.sax.handler.ContentHandler):
# Populate `interface_arg` and `interface_arg_for`, now we have all arguments
if name in ["request", "event"]:
assert isinstance(self.current_message, Message)
assert isinstance(self.current_interface, Interface)
# obj is the argument of type object that the interface applies to
# iname is the argument of type "interface_name" that specifies the interface
for obj, iname in self.current_interface_arg_names.items():
@ -774,6 +793,24 @@ class ProtocolParser(xml.sax.handler.ContentHandler):
obj_arg.interface_arg = iname_arg
iname_arg.interface_arg_for = obj_arg
self.current_interface_arg_names = {}
if self.current_new_id_arg is not None:
arg = self.current_new_id_arg
version_arg = self.current_message.find_argument("version")
if version_arg is None:
# Sigh, protocol bug: ei_connection.sync one doesn't have a version arg
if (
f"{self.current_interface.plainname}.{self.current_message.name}"
!= "connection.sync"
):
raise XmlError.create(
f"Unable to find a version argument for {self.current_interface.plainname}.{self.current_message.name}::{arg.name}",
self.location,
)
else:
arg.version_arg = version_arg
version_arg.version_arg_for = arg
self.current_new_id_arg = None
if name == "request":
assert isinstance(self.current_message, Request)
self.current_message = None

View file

@ -275,11 +275,9 @@ class {{interface.camel_name}}(Interface):
{% for arg in incoming.arguments %}
{% if arg.signature == "n" %}
{% if arg.interface %}
{# Note: this only works while the version argument is always called version #}
"{{arg.name}}": {{arg.interface.camel_name}}.create({{arg.name}}, version),
"{{arg.name}}": {{arg.interface.camel_name}}.create({{arg.name}}, {{arg.version_arg.name}}),
{% else %}
{# Note: this only works while the argument is called interface_name #}
"{{arg.name}}": Interface.lookup(interface_name).create({{arg.name}}, version),
"{{arg.name}}": Interface.lookup(interface_name).create({{arg.name}}, {{arg.version_arg.name}}),
{% endif %}
{% endif %}
{% endfor %}

View file

@ -58,6 +58,47 @@ class TestScanner:
assert version.interface_arg is None
assert version.interface_arg_for is None
def iterate_args(self, protocol: Protocol):
for interface in protocol.interfaces:
for request in interface.requests:
for arg in request.arguments:
yield interface, request, arg
for event in interface.events:
for arg in event.arguments:
yield interface, event, arg
@pytest.mark.parametrize("component", ("ei",))
def test_versione_arg(self, protocol: Protocol):
for interface, message, arg in self.iterate_args(protocol):
if arg.protocol_type == "new_id":
if f"{interface.plainname}.{message.name}" not in [
"connection.sync",
]:
assert (
arg.version_arg is not None
), f"{interface.name}.{message.name}::{arg.name}"
assert (
arg.version_arg.name == "version"
), f"{interface.name}.{message.name}::{arg.name}"
elif arg.name == "version":
if f"{interface.plainname}.{message.name}" not in [
"handshake.handshake_version",
"handshake.interface_version",
]:
assert (
arg.version_arg_for is not None
), f"{interface.name}.{message.name}::{arg.name}"
assert (
arg.version_arg_for.name != "version"
), f"{interface.name}.{message.name}::{arg.name}"
else:
assert (
arg.version_arg is None
), f"{interface.name}.{message.name}::{arg.name}"
assert (
arg.version_arg_for is None
), f"{interface.name}.{message.name}::{arg.name}"
@pytest.mark.parametrize("method", ("yamlfile", "jsonfile", "string"))
def test_cli_extra_data(self, tmp_path, method):
result_path = tmp_path / "result"