From cc841e66dba301fc2da2861a0e3bd7996e624fa0 Mon Sep 17 00:00:00 2001 From: Juliana Fajardini Date: Mon, 17 Feb 2025 19:13:50 -0300 Subject: [PATCH] pgsql/parser: always use fn for parsing PDU length Some inner parsers were using it, some weren't. Better to standardize this. Also take the time to avoid magic numbers for representing the expected lengths for pgsql PDUs. Also throwing PgsqlParseError and allowing for incomplete results. Related to Task #5566 Bug #5524 --- rust/src/pgsql/parser.rs | 51 +++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/rust/src/pgsql/parser.rs b/rust/src/pgsql/parser.rs index c7c158761c..94ddc676b9 100644 --- a/rust/src/pgsql/parser.rs +++ b/rust/src/pgsql/parser.rs @@ -55,9 +55,22 @@ impl ParseError for PgsqlParseError { } } -fn parse_length(i: &[u8]) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> { +fn parse_gte_length(i: &[u8], expected_length: u32) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> { let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| { - x >= PGSQL_LENGTH_FIELD + x >= expected_length + })(i); + match res { + Ok(result) => Ok((result.0, result.1)), + Err(nom7::Err::Incomplete(needed)) => Err(Err::Incomplete(needed)), + Err(_) => Err(Err::Error(PgsqlParseError::InvalidLength)), + } +} + +fn parse_exact_length( + i: &[u8], expected_length: u32, +) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> { + let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| { + x == expected_length })(i); match res { Ok(result) => Ok((result.0, result.1)), @@ -612,7 +625,7 @@ fn parse_sasl_initial_response_payload( pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = map_parser( take(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload, @@ -631,7 +644,7 @@ pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, P pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; let resp = PgsqlFEMessage::SASLResponse(RegularPacket { identifier, @@ -698,7 +711,7 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, Pg // Password can be encrypted or in cleartext pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, password) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?; Ok(( i, @@ -712,7 +725,7 @@ pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlP fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?; Ok(( i, @@ -735,7 +748,7 @@ fn parse_cancel_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseEr fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_exact_length(i, PGSQL_LENGTH_FIELD)?; Ok(( i, PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }), @@ -751,7 +764,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError b'X' => parse_terminate_message(i)?, _ => { let (i, identifier) = be_u8(i)?; - let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; let unknown = PgsqlFEMessage::UnknownMessageType(RegularPacket { identifier, @@ -766,7 +779,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], PgsqlBEMessage, PgsqlParseError<&'a [u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'R')(i)?; - let (i, length) = verify(be_u32, |&x| x >= 8)(i)?; + let (i, length) = parse_gte_length(i, 8)?; let (i, auth_type) = be_u32(i)?; let (i, message) = map_parser(take(length - 8), |b: &'a [u8]| { match auth_type { @@ -849,7 +862,7 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, param) = map_parser( take(length - PGSQL_LENGTH_FIELD), pgsql_parse_generic_parameter, @@ -874,7 +887,7 @@ pub fn parse_ssl_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParse fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'K')(i)?; - let (i, length) = verify(be_u32, |&x| x == 12)(i)?; + let (i, length) = parse_exact_length(i, 12)?; let (i, pid) = be_u32(i)?; let (i, secret_key) = be_u32(i)?; Ok(( @@ -890,7 +903,7 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pg fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?; - let (i, length) = parse_length(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?; Ok(( i, @@ -904,7 +917,7 @@ fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParse fn parse_ready_for_query(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'Z')(i)?; - let (i, length) = verify(be_u32, |&x| x == 5)(i)?; + let (i, length) = parse_exact_length(i, 5)?; let (i, status) = verify(be_u8, |&x| x == b'I' || x == b'T' || x == b'E')(i)?; Ok(( i, @@ -941,7 +954,7 @@ fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField, PgsqlParseError<&[u8]>> pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'T')(i)?; - let (i, length) = verify(be_u32, |&x| x > 6)(i)?; + let (i, length) = parse_gte_length(i, 7)?; let (i, field_count) = be_u16(i)?; let (i, fields) = map_parser( take(length - 6), @@ -992,7 +1005,7 @@ fn add_up_data_size(columns: Vec) -> u64 { // Later on, we calculate the number of lines the command actually returned by counting ConsolidatedDataRow messages pub fn parse_consolidated_data_row(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'D')(i)?; - let (i, length) = verify(be_u32, |&x| x >= 6)(i)?; + let (i, length) = parse_gte_length(i, 7)?; let (i, field_count) = be_u16(i)?; // 6 here is for skipping length + field_count let (i, rows) = map_parser( @@ -1109,7 +1122,7 @@ pub fn parse_error_notice_fields( fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'E')(i)?; - let (i, length) = verify(be_u32, |&x| x > 10)(i)?; + let (i, length) = parse_gte_length(i, 11)?; let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { parse_error_notice_fields(b, true) })(i)?; @@ -1126,7 +1139,7 @@ fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlP fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'N')(i)?; - let (i, length) = verify(be_u32, |&x| x > 10)(i)?; + let (i, length) = parse_gte_length(i, 11)?; let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { parse_error_notice_fields(b, false) })(i)?; @@ -1143,7 +1156,7 @@ fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pgsql fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { let (i, identifier) = verify(be_u8, |&x| x == b'A')(i)?; // length (u32) + pid (u32) + at least one byte, for we have two str fields - let (i, length) = verify(be_u32, |&x| x > 9)(i)?; + let (i, length) = parse_gte_length(i, 10)?; let (i, data) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| { let (b, pid) = be_u32(b)?; let (b, channel_name) = take_until_and_consume(b"\x00")(b)?; @@ -1175,7 +1188,7 @@ pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlPar b'D' => parse_consolidated_data_row(i)?, _ => { let (i, identifier) = be_u8(i)?; - let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; let unknown = PgsqlBEMessage::UnknownMessageType(RegularPacket { identifier,