diff --git a/src/brei-shared.c b/src/brei-shared.c index 6846bbf..fa1fccd 100644 --- a/src/brei-shared.c +++ b/src/brei-shared.c @@ -396,8 +396,14 @@ brei_dispatch(struct brei_context *brei, const struct brei_header *header = (const struct brei_header *)data; uint32_t msglen = header->msglen; - if (len < msglen) - break; + /* Max message size: 1MiB, plenty enough for the current protocol */ + static const uint32_t max_msglen = 1024 * 1024; + if (msglen < headersize || msglen > max_msglen) { + result = brei_result_new(BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, + "invalid message length %u", + msglen); + goto error; + } object_id_t object_id = header->sender_id; uint32_t opcode = header->opcode; diff --git a/test/test_protocol.py b/test/test_protocol.py index 60420b3..1d6d61b 100644 --- a/test/test_protocol.py +++ b/test/test_protocol.py @@ -33,6 +33,7 @@ import time import shlex import signal import socket +import struct import structlog try: @@ -1359,3 +1360,60 @@ class TestEiProtocol: ei.wait_for(lambda: status.disconnected) assert status.disconnected is True + + @pytest.mark.parametrize( + "invalid_msglen", + (0, 1, 15, 1024 * 1024 + 1, 0xFFFFFFFF), + ids=("zero", "one", "header-minus-one", "over-max", "uint32-max"), + ) + def test_invalid_message_length(self, eis, invalid_msglen): + """ + Ensure the server disconnects us if we send a message with an invalid + msglen (less than header size or greater than 1MiB). + """ + ei = eis.ei + ei.dispatch() + ei.init_default_sender_connection() + ei.wait_for_connection() + + assert ei.connection is not None + connection = ei.connection + + @dataclass + class Status: + disconnected: bool = False + reason: int = 0 + explanation: Optional[str] = None + + status = Status() + + def on_disconnected(connection, last_serial, reason, explanation): + status.disconnected = True + status.reason = reason + status.explanation = explanation + + connection.connect("Disconnected", on_disconnected) + + # Craft a raw message with a valid object_id (the connection's) but + # an invalid msglen. Use opcode 0 since the msglen check happens + # before opcode validation. + raw_msg = struct.pack("=QII", connection.object_id, invalid_msglen, 0) + # For oversized messages we only need to send the header - the server + # checks msglen from the header before trying to read the full payload. + try: + ei.send(raw_msg) + ei.dispatch() + time.sleep(0.5) + ei.dispatch() + except (ConnectionResetError, BrokenPipeError): + # The server may have already closed the connection + return + + # If we didn't get a socket error, we should have gotten a + # Disconnected event with a protocol error reason + assert status.disconnected, ( + f"Expected disconnection for invalid msglen {invalid_msglen}" + ) + assert status.reason == EiConnection.EiDisconnectReason.PROTOCOL, ( + status.explanation + )