diff --git a/src/brei-shared.c b/src/brei-shared.c index fa1fccd..0497dd7 100644 --- a/src/brei-shared.c +++ b/src/brei-shared.c @@ -203,13 +203,27 @@ brei_demarshal(struct brei_context *brei, nargs = 0; while (*s) { + uint32_t remaining = end - p; + switch (*s) { case 'i': case 'u': case 'f': + if (remaining < 1) { + return brei_result_new( + BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, + "Message truncated, need 4 bytes but only %u remaining", + remaining * 4); + } arg->u = *p++; break; case 'x': + if (remaining < 2) { + return brei_result_new( + BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, + "Message truncated, need 8 bytes but only %u remaining", + remaining * 4); + } arg->x = *(int64_t *)p; p++; p++; @@ -217,6 +231,12 @@ brei_demarshal(struct brei_context *brei, case 'o': case 'n': case 't': + if (remaining < 2) { + return brei_result_new( + BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, + "Message truncated, need 8 bytes but only %u remaining", + remaining * 4); + } memcpy(&arg->x, p, sizeof(arg->x)); p++; p++; @@ -225,8 +245,12 @@ brei_demarshal(struct brei_context *brei, arg->h = iobuf_take_fd(buf); break; case 's': { + if (remaining < 1) { + return brei_result_new( + BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, + "Message truncated, need string length but 0 bytes remaining"); + } uint32_t slen = *p; - uint32_t remaining = end - p; uint32_t protolen = brei_string_proto_length(slen); /* in bytes */ uint32_t len32 = protolen / 4; /* p and end are uint32_t* */ diff --git a/test/test_protocol.py b/test/test_protocol.py index 1d6d61b..0890af0 100644 --- a/test/test_protocol.py +++ b/test/test_protocol.py @@ -1417,3 +1417,164 @@ class TestEiProtocol: assert status.reason == EiConnection.EiDisconnectReason.PROTOCOL, ( status.explanation ) + + @pytest.mark.parametrize( + "truncated_payload", + ( + # (description, object_id, opcode, payload_size) + # ei_handshake.handshake_version: opcode 0, sig "u", needs 4 bytes + ("handshake-u-empty", 0, 0, 0), + # ei_handshake.context_type: opcode 2, sig "u", give 0 of 4 + ("handshake-context_type-empty", 0, 2, 0), + # ei_handshake.name: opcode 3, sig "s", give 0 of 4+ (string length) + ("handshake-name-empty", 0, 3, 0), + # ei_handshake.interface_version: opcode 4, sig "su", + # give 4 bytes (only string length prefix, no string data or uint) + ("handshake-interface_version-partial", 0, 4, 4), + ), + ids=lambda t: t[0], + ) + def test_truncated_message_payload(self, eis, truncated_payload): + """ + Ensure the server disconnects us if we send a message with a valid + msglen but a payload that is too short for the opcode's signature. + This tests the bounds checks in brei_demarshal(). + """ + desc, object_id, opcode, payload_size = truncated_payload + + ei = eis.ei + ei.dispatch() + + # The handshake object (id=0) is always available immediately and + # accepts requests without needing a full connection, making it the + # simplest target for testing demarshalling bounds checks. + header_size = 16 + msglen = header_size + payload_size + + # Craft a raw message: valid header with object_id, msglen and opcode, + # followed by payload_size bytes of zeros (truncated payload). + raw_msg = struct.pack("=QII", object_id, msglen, opcode) + if payload_size > 0: + raw_msg += b"\x00" * payload_size + + try: + ei.send(raw_msg) + ei.dispatch() + time.sleep(0.5) + ei.dispatch() + # Try sending more data to detect if the connection was closed + ei.send(raw_msg) + time.sleep(0.1) + ei.dispatch() + except (ConnectionResetError, BrokenPipeError): + # The server closed the connection - this is the expected outcome + # for a protocol violation + return + + # If we're still connected, check whether the server sent us a + # Disconnected event (it may have established a connection first + # and sent us a protocol-error disconnection) + if ei.connection is not None: + for call in ei.connection.calllog: + if call.name == "Disconnected": + assert ( + call.args["reason"] == EiConnection.EiDisconnectReason.PROTOCOL + ) + return + + # If no connection was established and we didn't get a pipe error, + # that's acceptable too - the server may have simply dropped us + # silently before establishing a connection object + + @pytest.mark.parametrize( + "truncated_payload", + ( + # (description, target, opcode, payload_size) + # ei_connection.sync: opcode 0, sig "nu", needs 12 bytes (8+4), + # send 0 to truncate the new_id (n, 8 bytes) + ("connection-sync-empty", "connection", 0, 0), + # same but send only 4 bytes: enough for a u but not for an n + ("connection-sync-partial-n", "connection", 0, 4), + # ei_seat.bind: opcode 1, sig "t", needs 8 bytes, + # send 0 to truncate the uint64 + ("seat-bind-empty", "seat", 1, 0), + # same but send 4 bytes: half of the required 8 + ("seat-bind-partial-t", "seat", 1, 4), + # ei_device.frame: opcode 3, sig "ut", needs 12 bytes (4+8), + # send 4: has the uint32 but truncates the uint64 + ("device-frame-partial-t", "device", 3, 4), + ), + ids=lambda t: t[0], + ) + def test_truncated_payload_after_connect(self, eis, truncated_payload): + """ + Ensure the server disconnects us if we send a truncated message payload + on connection, seat, or device objects. This tests the brei_demarshal() + bounds checks for 64-bit types (n, t) on real protocol objects. + """ + desc, target, opcode, payload_size = truncated_payload + + ei = eis.ei + ei.dispatch() + ei.init_default_sender_connection() + ei.wait_for_connection() + + assert ei.connection is not None + connection = ei.connection + + @dataclass + class Status: + disconnected: bool = False + reason: int = 0 + explanation: Optional[str] = None + + status = Status() + + def on_disconnected(connection, last_serial, reason, explanation): + status.disconnected = True + status.reason = reason + status.explanation = explanation + + connection.connect("Disconnected", on_disconnected) + + if target == "connection": + object_id = connection.object_id + elif target == "seat": + ei.wait_for_seat() + assert ei.seats, "No seat received" + object_id = ei.seats[0].object_id + elif target == "device": + ei.wait_for_seat() + seat = ei.seats[0] + ei.send( + seat.Bind( + seat.bind_mask([InterfaceName.EI_POINTER, InterfaceName.EI_BUTTON]) + ) + ) + # Wait for a device to appear + ei.wait_for(lambda: ei.find_objects_by_interface(InterfaceName.EI_DEVICE)) + devices = ei.find_objects_by_interface(InterfaceName.EI_DEVICE) + assert devices, "No device received" + object_id = devices[0].object_id + else: + assert False, f"Unknown target {target}" + + header_size = 16 + msglen = header_size + payload_size + + raw_msg = struct.pack("=QII", object_id, msglen, opcode) + if payload_size > 0: + raw_msg += b"\x00" * payload_size + + try: + ei.send(raw_msg) + ei.dispatch() + time.sleep(0.5) + ei.dispatch() + except (ConnectionResetError, BrokenPipeError): + return + + assert status.disconnected, f"Expected disconnection for truncated {desc}" + assert status.reason == EiConnection.EiDisconnectReason.PROTOCOL, ( + status.explanation + )