diff --git a/rust/src/pgsql/parser.rs b/rust/src/pgsql/parser.rs index c7c158761c..94ddc676b9 100644 --- a/rust/src/pgsql/parser.rs +++ b/rust/src/pgsql/parser.rs @@ -55,9 +55,22 @@ impl ParseError for PgsqlParseError { } } -fn parse_length(i: &[u8]) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> { +fn parse_gte_length(i: &[u8], expected_length: u32) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> { let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| { - x >= PGSQL_LENGTH_FIELD + x >= expected_length + })(i); + match res { + Ok(result) => Ok((result.0, result.1)), + Err(nom7::Err::Incomplete(needed)) => Err(Err::Incomplete(needed)), + Err(_) => Err(Err::Error(PgsqlParseError::InvalidLength)), + } +} + +fn parse_exact_length( + i: &[u8], expected_length: u32, +) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> { + let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| { + x == expected_length })(i); match res { Ok(result) => Ok((result.0, result.1)), @@ -612,7 +625,7 @@ fn parse_sasl_initial_response_payload( pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = map_parser( take(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload, @@ -631,7 +644,7 @@ pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, P pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; let resp = PgsqlFEMessage::SASLResponse(RegularPacket { identifier, @@ -698,7 +711,7 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, Pg // Password can be encrypted or in cleartext pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, password) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?; Ok(( i, @@ -712,7 +725,7 @@ pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlP fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?; Ok(( i, @@ -735,7 +748,7 @@ fn parse_cancel_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseEr fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_exact_length(i, PGSQL_LENGTH_FIELD)?; Ok(( i, PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }), @@ -751,7 +764,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError b'X' => parse_terminate_message(i)?, _ => { let (i, identifier) = be_u8(i)?; - let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; let unknown = PgsqlFEMessage::UnknownMessageType(RegularPacket { identifier, @@ -766,7 +779,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], PgsqlBEMessage, PgsqlParseError<&'a [u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'R')(i)?; - let (i, length) = verify(be_u32, |&x| x >= 8)(i)?; + let (i, length) = parse_gte_length(i, 8)?; let (i, auth_type) = be_u32(i)?; let (i, message) = map_parser(take(length - 8), |b: &'a [u8]| { match auth_type { @@ -849,7 +862,7 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, param) = map_parser( take(length - PGSQL_LENGTH_FIELD), pgsql_parse_generic_parameter, @@ -874,7 +887,7 @@ pub fn parse_ssl_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParse fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'K')(i)?; - let (i, length) = verify(be_u32, |&x| x == 12)(i)?; + let (i, length) = parse_exact_length(i, 12)?; let (i, pid) = be_u32(i)?; let (i, secret_key) = be_u32(i)?; Ok(( @@ -890,7 +903,7 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pg fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?; Ok(( i, @@ -904,7 +917,7 @@ fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParse fn parse_ready_for_query(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'Z')(i)?; - let (i, length) = verify(be_u32, |&x| x == 5)(i)?; + let (i, length) = parse_exact_length(i, 5)?; let (i, status) = verify(be_u8, |&x| x == b'I' || x == b'T' || x == b'E')(i)?; Ok(( i, @@ -941,7 +954,7 @@ fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField, PgsqlParseError<&[u8]>> pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'T')(i)?; - let (i, length) = verify(be_u32, |&x| x > 6)(i)?; + let (i, length) = parse_gte_length(i, 7)?; let (i, field_count) = be_u16(i)?; let (i, fields) = map_parser( take(length - 6), @@ -992,7 +1005,7 @@ fn add_up_data_size(columns: Vec) -> u64 { // Later on, we calculate the number of lines the command actually returned by counting ConsolidatedDataRow messages pub fn parse_consolidated_data_row(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'D')(i)?; - let (i, length) = verify(be_u32, |&x| x >= 6)(i)?; + let (i, length) = parse_gte_length(i, 7)?; let (i, field_count) = be_u16(i)?; // 6 here is for skipping length + field_count let (i, rows) = map_parser( @@ -1109,7 +1122,7 @@ pub fn parse_error_notice_fields( fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'E')(i)?; - let (i, length) = verify(be_u32, |&x| x > 10)(i)?; + let (i, length) = parse_gte_length(i, 11)?; let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { parse_error_notice_fields(b, true) })(i)?; @@ -1126,7 +1139,7 @@ fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlP fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'N')(i)?; - let (i, length) = verify(be_u32, |&x| x > 10)(i)?; + let (i, length) = parse_gte_length(i, 11)?; let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { parse_error_notice_fields(b, false) })(i)?; @@ -1143,7 +1156,7 @@ fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pgsql fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'A')(i)?; // length (u32) + pid (u32) + at least one byte, for we have two str fields - let (i, length) = verify(be_u32, |&x| x > 9)(i)?; + let (i, length) = parse_gte_length(i, 10)?; let (i, data) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { let (b, pid) = be_u32(b)?; let (b, channel_name) = take_until_and_consume(b"\x00")(b)?; @@ -1175,7 +1188,7 @@ pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlPar b'D' => parse_consolidated_data_row(i)?, _ => { let (i, identifier) = be_u8(i)?; - let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; let unknown = PgsqlBEMessage::UnknownMessageType(RegularPacket { identifier,