libeis: check incoming objects' version for correctness

If the server sends a protocol version higher than we support, fail.
This commit is contained in:
Peter Hutterer 2023-05-26 16:20:06 +10:00
parent c99f4ffa2c
commit 1aedabe7c7
3 changed files with 50 additions and 0 deletions

View file

@ -323,6 +323,7 @@ client_msg_sync(struct eis_connection *connection, object_id_t new_id, uint32_t
struct eis_client *client = eis_connection_get_client(connection);
DISCONNECT_IF_INVALID_ID(client, new_id);
DISCONNECT_IF_INVALID_VERSION(client, ei_connection, new_id, version);
struct eis_callback *callback = eis_callback_new(client, new_id, version);
log_debug(eis_client_get_context(client) , "object %#" PRIx64 ": connection sync done", new_id);

View file

@ -169,3 +169,14 @@ eis_log_msg_va(struct eis *eis,
eis_log_msg((T_), EIS_LOG_PRIORITY_ERROR, __FILE__, __LINE__, __func__, "🪳 libeis bug: " __VA_ARGS__)
#define log_bug_client(T_, ...) \
eis_log_msg((T_), EIS_LOG_PRIORITY_ERROR, __FILE__, __LINE__, __func__, "🪲 Bug: " __VA_ARGS__)
#define DISCONNECT_IF_INVALID_VERSION(eis_client_, intf_, id_, version_) do { \
struct eis_client *_client = (eis_client_); \
uint32_t _version = (version_); \
uint64_t _id = (id_); \
if (_client->interface_versions.intf_ < _version) { \
struct eis *_eis = eis_client_get_context(_client); \
log_bug(_eis, "Received invalid version %u for object id %#" PRIx64 ". Disconnecting", _version, _id); \
return brei_result_new(EIS_CONNECTION_DISCONNECT_REASON_PROTOCOL, "Received invalid version %u for object id %#" PRIx64 ".", _version, _id); \
} \
} while(0)

View file

@ -947,6 +947,44 @@ class TestEiProtocol:
), status.explanation
assert status.explanation is not None
def test_invalid_callback_version(self, eis):
"""
Expect to get disconnected if we allocate a client object outside the agreed version range.
Right now only callbacks are client-created, so that's all we can test here.
"""
ei = eis.ei
ei.dispatch()
ei.init_default_sender_connection()
ei.dispatch()
ei.wait_for_connection()
@attr.s
class Status:
disconnected: bool = attr.ib(default=False)
reason: int = attr.ib(default=0)
explanation: Optional[str] = attr.ib(default=None)
status = Status()
def on_disconnected(connection, last_serial, reason, explanation):
status.disconnected = True
status.reason = reason
status.explanation = explanation
ei.connection.connect("Disconnected", on_disconnected)
cb = EiCallback.create(0x100, VERSION_V(100))
ei.context.register(cb)
ei.send(ei.connection.Sync(cb.object_id, cb.version))
ei.wait_for(lambda: status.disconnected)
assert status.disconnected
assert (
status.reason == EiConnection.EiDisconnectReason.PROTOCOL
), status.explanation
assert status.explanation is not None
@pytest.mark.parametrize(
"wanted_interface",
(