From e75fcffa298daa9a58b9d0c497645a6e348f4efb Mon Sep 17 00:00:00 2001 From: Juliana Fajardini Date: Tue, 1 Apr 2025 20:00:07 -0700 Subject: [PATCH] pgsql: add initial support to copy-out subproto This sub-protocol inspects messages exchanged between postgresql backend and frontend after a 'COPY TO STDOUT' has been processed. Parses new messages: - CopyOutResponse -- initiates copy-out mode/sub-protocol - CopyData -- data transfer messages - CopyDone -- signals that no more CopyData messages will be seen from the sender for the current transaction Task #4854 --- etc/schema.json | 19 ++++++++++ rust/src/pgsql/logger.rs | 23 +++++++++++- rust/src/pgsql/parser.rs | 75 +++++++++++++++++++++++++++++++++++----- rust/src/pgsql/pgsql.rs | 31 +++++++++++++++++ 4 files changed, 139 insertions(+), 9 deletions(-) diff --git a/etc/schema.json b/etc/schema.json index a69bcec35e..95b9a90f58 100644 --- a/etc/schema.json +++ b/etc/schema.json @@ -3676,6 +3676,17 @@ "code": { "type": "string" }, + "copy_data_out": { + "type": "object", + "properties": { + "row_count": { + "type": "integer" + }, + "data_size": { + "type": "integer" + } + } + }, "command_completed": { "type": "string" }, @@ -3755,6 +3766,14 @@ "severity_non_localizable": { "type": "string" }, + "copy_out_response": { + "type": "object", + "properties": { + "copy_column_count": { + "type": "integer" + } + } + }, "ssl_accepted": { "type": "boolean" } diff --git a/rust/src/pgsql/logger.rs b/rust/src/pgsql/logger.rs index edf87e4ec4..7fb3343654 100644 --- a/rust/src/pgsql/logger.rs +++ b/rust/src/pgsql/logger.rs @@ -40,6 +40,7 @@ fn log_pgsql(tx: &PgsqlTransaction, flags: u32, js: &mut JsonBuilder) -> Result< } if !tx.responses.is_empty() { + SCLogDebug!("Responses length: {}", tx.responses.len()); js.set_object("response", &log_response_object(tx)?)?; } js.close()?; @@ -197,7 +198,8 @@ fn log_response(res: &PgsqlBEMessage, jb: &mut JsonBuilder) -> Result<(), JsonEr PgsqlBEMessage::AuthenticationOk(_) | 
PgsqlBEMessage::AuthenticationCleartextPassword(_) | PgsqlBEMessage::AuthenticationSASL(_) - | PgsqlBEMessage::AuthenticationSASLContinue(_) => { + | PgsqlBEMessage::AuthenticationSASLContinue(_) + | PgsqlBEMessage::CopyDone(_) => { jb.set_string("message", res.to_str())?; } PgsqlBEMessage::ParameterStatus(ParameterStatusMessage { @@ -207,6 +209,15 @@ fn log_response(res: &PgsqlBEMessage, jb: &mut JsonBuilder) -> Result<(), JsonEr }) => { // We take care of these elsewhere } + PgsqlBEMessage::CopyOutResponse(CopyOutResponse { + identifier: _, + length: _, + column_cnt, + }) => { + jb.open_object(res.to_str())?; + jb.set_uint("copy_column_count", *column_cnt)?; + jb.close()?; + } PgsqlBEMessage::BackendKeyData(BackendKeyDataMessage { identifier: _, length: _, @@ -223,6 +234,16 @@ fn log_response(res: &PgsqlBEMessage, jb: &mut JsonBuilder) -> Result<(), JsonEr }) => { // We don't want to log this one } + PgsqlBEMessage::ConsolidatedCopyDataOut(ConsolidatedDataRowPacket { + identifier: _, + row_cnt, + data_size, + }) => { + jb.open_object(res.to_str())?; + jb.set_uint("row_count", *row_cnt)?; + jb.set_uint("data_size", *data_size)?; + jb.close()?; + } PgsqlBEMessage::RowDescription(RowDescriptionMessage { identifier: _, length: _, diff --git a/rust/src/pgsql/parser.rs b/rust/src/pgsql/parser.rs index 67029cad5c..034d3bf2e8 100644 --- a/rust/src/pgsql/parser.rs +++ b/rust/src/pgsql/parser.rs @@ -29,7 +29,7 @@ use nom7::multi::{many1, many_m_n, many_till}; use nom7::number::streaming::{be_i16, be_i32}; use nom7::number::streaming::{be_u16, be_u32, be_u8}; use nom7::sequence::terminated; -use nom7::{Err, IResult}; +use nom7::{Err, IResult, ToUsize}; pub const PGSQL_LENGTH_FIELD: u32 = 4; @@ -247,7 +247,7 @@ pub struct BackendKeyDataMessage { #[derive(Debug, PartialEq, Eq)] pub struct ConsolidatedDataRowPacket { pub identifier: u8, - pub row_cnt: u64, + pub row_cnt: u64, // row or msg cnt pub data_size: u64, } @@ -268,6 +268,21 @@ pub struct NotificationResponse { pub 
payload: Vec, } +#[derive(Debug, PartialEq, Eq)] +pub struct CopyOutResponse { + pub identifier: u8, + pub length: u32, + pub column_cnt: u16, + // for each column, there are column_cnt u16 format codes received + // for now, we're not storing those +} + +#[derive(Debug, PartialEq, Eq)] +pub struct TerminationMessage { + pub identifier: u8, + pub length: u32, +} + #[derive(Debug, PartialEq, Eq)] pub enum PgsqlBEMessage { SSLResponse(SSLResponseMessage), @@ -283,6 +298,9 @@ pub enum PgsqlBEMessage { ParameterStatus(ParameterStatusMessage), BackendKeyData(BackendKeyDataMessage), CommandComplete(RegularPacket), + CopyOutResponse(CopyOutResponse), + ConsolidatedCopyDataOut(ConsolidatedDataRowPacket), + CopyDone(TerminationMessage), ReadyForQuery(ReadyForQueryMessage), RowDescription(RowDescriptionMessage), ConsolidatedDataRow(ConsolidatedDataRowPacket), @@ -309,6 +327,9 @@ impl PgsqlBEMessage { PgsqlBEMessage::ParameterStatus(_) => "parameter_status", PgsqlBEMessage::BackendKeyData(_) => "backend_key_data", PgsqlBEMessage::CommandComplete(_) => "command_completed", + PgsqlBEMessage::CopyOutResponse(_) => "copy_out_response", + PgsqlBEMessage::ConsolidatedCopyDataOut(_) => "copy_data_out", + PgsqlBEMessage::CopyDone(_) => "copy_done", PgsqlBEMessage::ReadyForQuery(_) => "ready_for_query", PgsqlBEMessage::RowDescription(_) => "row_description", PgsqlBEMessage::SSLResponse(SSLResponseMessage::InvalidResponse) => { @@ -348,12 +369,6 @@ impl SASLAuthenticationMechanism { type SASLInitialResponse = (SASLAuthenticationMechanism, u32, Vec); -#[derive(Debug, PartialEq, Eq)] -pub struct TerminationMessage { - pub identifier: u8, - pub length: u32, -} - #[derive(Debug, PartialEq, Eq)] pub struct CancelRequestMessage { pub pid: u32, @@ -1017,6 +1032,47 @@ fn add_up_data_size(columns: Vec) -> u64 { data_size } +pub fn parse_copy_out_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { + let (i, identifier) = verify(be_u8, |&x| x == b'H')(i)?; + // copy out 
message : identifier (u8), length (u32), format (u8), cols (u16), formats (u16*cols) + let (i, length) = parse_gte_length(i, 8)?; + let (i, _format) = be_u8(i)?; + let (i, columns) = be_u16(i)?; + let (i, _formats) = many_m_n(0, columns.to_usize(), be_u16)(i)?; + Ok(( + i, + PgsqlBEMessage::CopyOutResponse(CopyOutResponse { + identifier, + length, + column_cnt: columns, + }) + )) +} + +pub fn parse_consolidated_copy_data_out(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { + let (i, identifier) = verify(be_u8, |&x| x == b'd')(i)?; + let (i, length) = parse_gte_length(i, 5)?; + let (i, _data) = take(length - PGSQL_LENGTH_FIELD)(i)?; + SCLogDebug!("data_size is {:?}", _data); + Ok(( + i, PgsqlBEMessage::ConsolidatedCopyDataOut(ConsolidatedDataRowPacket { + identifier, + row_cnt: 1, + data_size: (length - PGSQL_LENGTH_FIELD) as u64 }) + )) +} + +fn parse_copy_done(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> { + let (i, identifier) = verify(be_u8, |&x| x == b'c')(i)?; + let (i, length) = parse_exact_length(i, PGSQL_LENGTH_FIELD)?; + Ok(( + i, PgsqlBEMessage::CopyDone(TerminationMessage { + identifier, + length + }) + )) +} + // Currently, we don't store the actual DataRow messages, as those could easily become a burden, memory-wise // We use ConsolidatedDataRow to store info we still want to log: message size. 
// Later on, we calculate the number of lines the command actually returned by counting ConsolidatedDataRow messages @@ -1211,10 +1267,13 @@ pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlPar b'R' => pgsql_parse_authentication_message(i)?, b'S' => parse_parameter_status_message(i)?, b'C' => parse_command_complete(i)?, + b'c' => parse_copy_done(i)?, b'Z' => parse_ready_for_query(i)?, b'T' => parse_row_description(i)?, b'A' => parse_notification_response(i)?, b'D' => parse_consolidated_data_row(i)?, + b'd' => parse_consolidated_copy_data_out(i)?, + b'H' => parse_copy_out_response(i)?, _ => { let (i, identifier) = be_u8(i)?; let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?; diff --git a/rust/src/pgsql/pgsql.rs b/rust/src/pgsql/pgsql.rs index b1b8f87c69..aca283a5f9 100644 --- a/rust/src/pgsql/pgsql.rs +++ b/rust/src/pgsql/pgsql.rs @@ -121,6 +121,9 @@ pub enum PgsqlStateProgress { CancelRequestReceived, ConnectionTerminated, // Related to Backend-received messages // + CopyOutResponseReceived, + CopyDataOutReceived, + CopyDoneReceived, SSLRejectedReceived, // SSPIAuthenticationReceived, // TODO implement SASLAuthenticationReceived, @@ -481,16 +484,24 @@ impl PgsqlState { } PgsqlBEMessage::ReadyForQuery(_) => Some(PgsqlStateProgress::ReadyForQueryReceived), // TODO should we store any Parameter Status in PgsqlState? 
+ // TODO -- For CopyBoth mode, parameterstatus may be important (replication parameter) PgsqlBEMessage::AuthenticationMD5Password(_) | PgsqlBEMessage::AuthenticationCleartextPassword(_) => { Some(PgsqlStateProgress::SimpleAuthenticationReceived) } PgsqlBEMessage::RowDescription(_) => Some(PgsqlStateProgress::RowDescriptionReceived), + PgsqlBEMessage::CopyOutResponse(_) => Some(PgsqlStateProgress::CopyOutResponseReceived), PgsqlBEMessage::ConsolidatedDataRow(msg) => { // Increment tx.data_size here, since we know msg type, so that we can later on log that info self.transactions.back_mut()?.sum_data_size(msg.data_size); Some(PgsqlStateProgress::DataRowReceived) } + PgsqlBEMessage::ConsolidatedCopyDataOut(msg) => { + // Increment tx.data_size here, since we know msg type, so that we can later on log that info + self.transactions.back_mut()?.sum_data_size(msg.data_size); + Some(PgsqlStateProgress::CopyDataOutReceived) + } + PgsqlBEMessage::CopyDone(_) => Some(PgsqlStateProgress::CopyDoneReceived), PgsqlBEMessage::CommandComplete(_) => { // TODO Do we want to compare the command that was stored when // query was sent with what we received here? @@ -504,6 +515,7 @@ impl PgsqlState { PgsqlBEMessage::ErrorResponse(_) => Some(PgsqlStateProgress::ErrorMessageReceived), _ => { // We don't always have to change current state when we see a response... 
+ // NotificationResponse and NoticeResponse fall here None } } @@ -582,6 +594,25 @@ impl PgsqlState { ); tx.responses.push(dummy_resp); tx.responses.push(response); + // reset values + tx.data_row_cnt = 0; + tx.data_size = 0; + } else if state == PgsqlStateProgress::CopyDataOutReceived { + tx.incr_row_cnt(); + } else if state == PgsqlStateProgress::CopyDoneReceived && tx.get_row_cnt() > 0 { + // let's summarize the info from the CopyData messages in one response + let dummy_resp = PgsqlBEMessage::ConsolidatedCopyDataOut( + ConsolidatedDataRowPacket { + identifier: b'd', + row_cnt: tx.get_row_cnt(), + data_size: tx.data_size, // total byte count of all CopyData messages combined + }, + ); + tx.responses.push(dummy_resp); + tx.responses.push(response); + // reset values + tx.data_row_cnt = 0; + tx.data_size = 0; } else { tx.responses.push(response); if Self::response_is_complete(state) {