diff --git a/src/app-layer-ssl.c b/src/app-layer-ssl.c index 8c2421de9a..457d5413be 100644 --- a/src/app-layer-ssl.c +++ b/src/app-layer-ssl.c @@ -90,10 +90,12 @@ SslConfig ssl_config; #define SSLV2_MT_CLIENT_CERTIFICATE 8 #define SSLV3_RECORD_LEN 5 +#define SSLV3_MESSAGE_HDR_LEN 4 static void SSLParserReset(SSLState *ssl_state) { ssl_state->bytes_processed = 0; + ssl_state->message_start = 0; } static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input, @@ -114,11 +116,13 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input, switch (ssl_state->bytes_processed) { case 9: ssl_state->bytes_processed++; + parsed++; ssl_state->handshake_client_hello_ssl_version = *(input++) << 8; if (--input_len == 0) break; case 10: ssl_state->bytes_processed++; + parsed++; ssl_state->handshake_client_hello_ssl_version |= *(input++); if (--input_len == 0) break; @@ -132,6 +136,7 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input, if (rc >= 0) { ssl_state->bytes_processed += rc; input += rc; + parsed += rc; } break; @@ -164,17 +169,14 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input, rc = DecodeTLSHandshakeServerCertificate(ssl_state, ssl_state->trec, ssl_state->trec_pos); if (rc > 0) { - ssl_state->bytes_processed += rc; - input += rc; - } - if (rc == 0) { - /* packet is incomplete - do not mark as parsed */ - } - if (rc < 0) { - /* error, skip packet */ - parsed += input_len; - ssl_state->bytes_processed += input_len; - return parsed; + /* do not return normally if the packet was fragmented: + * we would return the size of the *entire* message, + * while we expect only the number of bytes parsed bytes + * from the *current* fragment + */ + uint32_t diff = input_len - (ssl_state->trec_pos - rc); + ssl_state->bytes_processed += diff; + return diff; } break; case SSLV3_HS_HELLO_REQUEST: @@ -188,74 +190,69 @@ static int SSLv3ParseHandshakeType(SSLState *ssl_state, uint8_t *input, break; } - /* looks like we have another record */ - parsed += (input - initial_input); - if ((input_len + ssl_state->bytes_processed) >= ssl_state->record_length + SSLV3_RECORD_LEN) { - uint32_t diff = ssl_state->record_length + SSLV3_RECORD_LEN - ssl_state->bytes_processed; - parsed += diff; - ssl_state->bytes_processed += diff; - return parsed; - - /* we still don't have the entire record for the one we are - * currently parsing */ - } else { + /* skip the rest of the current message */ + uint32_t next_msg_offset = ssl_state->message_start + SSLV3_MESSAGE_HDR_LEN + ssl_state->message_length; + if (ssl_state->bytes_processed + input_len < next_msg_offset) { + /* we don't have enough data */ parsed += input_len; ssl_state->bytes_processed += input_len; return parsed; } + uint32_t diff = next_msg_offset - ssl_state->bytes_processed; + parsed += diff; + ssl_state->bytes_processed += diff; + return parsed; } static int SSLv3ParseHandshakeProtocol(SSLState *ssl_state, uint8_t *input, uint32_t input_len) { uint8_t *initial_input = input; + int retval; if (input_len == 0) { return 0; } - switch (ssl_state->bytes_processed) { - case 5: - if (input_len >= 4) { - ssl_state->handshake_type = *(input++); - // XXX we should *not* skip the next 3 bytes, they contain the Message length - input += 3; - input_len -= 4; - ssl_state->bytes_processed += 4; + if (ssl_state->message_start == 0) { + ssl_state->message_start = SSLV3_RECORD_LEN; + } + + switch (ssl_state->bytes_processed - ssl_state->message_start) { + case 0: + ssl_state->handshake_type = *(input++); + ssl_state->bytes_processed++; + if (--input_len == 0) break; - } else { - ssl_state->handshake_type = *(input++); - ssl_state->bytes_processed++; - if (--input_len == 0) - break; - } - case 6: + case 1: + ssl_state->message_length = *(input++) << 16; ssl_state->bytes_processed++; - input++; if (--input_len == 0) break; - case 7: + case 2: + ssl_state->message_length |= *(input++) << 8; ssl_state->bytes_processed++; - input++; if (--input_len == 0) break; - case 8: + case 3: + ssl_state->message_length |= *(input++); ssl_state->bytes_processed++; - input++; if (--input_len == 0) break; } - if (input_len == 0) - return (input - initial_input); - - int retval = SSLv3ParseHandshakeType(ssl_state, input, input_len); - if (retval == -1) { - SCReturnInt(-1); - } else { - input += retval; - return (input - initial_input); + retval = SSLv3ParseHandshakeType(ssl_state, input, input_len); + if (retval < 0) { + SCReturnInt(retval); + } + uint32_t next_msg_offset = ssl_state->message_start + SSLV3_MESSAGE_HDR_LEN + ssl_state->message_length; + if (ssl_state->bytes_processed >= next_msg_offset) { + ssl_state->handshake_type = 0; + ssl_state->message_length = 0; + ssl_state->message_start = next_msg_offset; } + input += retval; + return (input - initial_input); } static int SSLv3ParseRecord(uint8_t direction, SSLState *ssl_state, @@ -702,6 +699,12 @@ static int SSLv3Decode(uint8_t direction, SSLState *ssl_state, SCLogDebug("Error parsing SSLv3.x. Let's get outta here"); return -1; } else { + if ((uint32_t)retval > input_len) { + SCLogDebug("Error parsing SSLv3.x. Reseting parser " + "state. Let's get outta here"); + SSLParserReset(ssl_state); + return -1; + } parsed += retval; input_len -= retval; if (ssl_state->bytes_processed == ssl_state->record_length + SSLV3_RECORD_LEN) { @@ -803,6 +806,11 @@ static int SSLDecode(uint8_t direction, void *alstate, AppLayerParserState *psta } else { input_len -= retval; input += retval; + if (ssl_state->bytes_processed == SSLV3_RECORD_LEN + && ssl_state->record_length == 0) { + /* empty record */ + SSLParserReset(ssl_state); + } } } @@ -830,14 +838,24 @@ static int SSLDecode(uint8_t direction, void *alstate, AppLayerParserState *psta "previously left off"); retval = SSLv3Decode(direction, ssl_state, pstate, input, input_len); - if (retval == -1) { + if (retval < 0) { SCLogDebug("Error parsing SSLv3.x. Reseting parser " "state. Let's get outta here"); SSLParserReset(ssl_state); return 0; } else { + if ((uint32_t)retval > input_len) { + SCLogDebug("Error parsing SSLv3.x. Reseting parser " + "state. Let's get outta here"); + SSLParserReset(ssl_state); + } input_len -= retval; input += retval; + if (ssl_state->bytes_processed == SSLV3_RECORD_LEN + && ssl_state->record_length == 0) { + /* empty record */ + SSLParserReset(ssl_state); + } } } diff --git a/src/app-layer-ssl.h b/src/app-layer-ssl.h index c8aaed1f9f..7853b3f567 100644 --- a/src/app-layer-ssl.h +++ b/src/app-layer-ssl.h @@ -69,6 +69,10 @@ typedef struct SSLState_ { /* record length's length for SSLv2 */ uint32_t record_lengths_length; + /* offset of the beginning of the current message (including header) */ + uint32_t message_start; + uint32_t message_length; + /* holds some state flags we need */ uint32_t flags; diff --git a/src/app-layer-tls-handshake.c b/src/app-layer-tls-handshake.c index 56dc36e75c..88282ca0a7 100644 --- a/src/app-layer-tls-handshake.c +++ b/src/app-layer-tls-handshake.c @@ -84,8 +84,7 @@ int DecodeTLSHandshakeServerHello(SSLState *ssl_state, uint8_t *input, uint32_t SCLogDebug("TLS Handshake Version %.4x Cipher %d Compression %d\n", version, ciphersuite, compressionmethod); - /* return the message length (TLS record - (handshake type + length)) */ - return ssl_state->record_length-4; + return ssl_state->message_length; } int DecodeTLSHandshakeServerCertificate(SSLState *ssl_state, uint8_t *input, uint32_t input_len)