diff --git a/rust/src/pgsql/pgsql.rs b/rust/src/pgsql/pgsql.rs index eda9d48485..bd76346f80 100644 --- a/rust/src/pgsql/pgsql.rs +++ b/rust/src/pgsql/pgsql.rs @@ -22,11 +22,11 @@ use super::parser::{self, ConsolidatedDataRowPacket, PgsqlBEMessage, PgsqlFEMessage}; use crate::applayer::*; use crate::conf::*; -use crate::core::{AppProto, Flow, ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP}; use nom7::{Err, IResult}; use std; use std::collections::VecDeque; use std::ffi::CString; +use crate::core::{Flow, AppProto, Direction, ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP, *}; pub const PGSQL_CONFIG_DEFAULT_STREAM_DEPTH: u32 = 0; @@ -313,7 +313,7 @@ impl PgsqlState { } } - fn parse_request(&mut self, input: &[u8]) -> AppLayerResult { + fn parse_request(&mut self, flow: *const Flow, input: &[u8]) -> AppLayerResult { // We're not interested in empty requests. if input.is_empty() { return AppLayerResult::ok(); @@ -341,6 +341,7 @@ impl PgsqlState { ); match PgsqlState::state_based_req_parsing(self.state_progress, start) { Ok((rem, request)) => { + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32); start = rem; if let Some(state) = PgsqlState::request_next_state(&request) { self.state_progress = state; @@ -449,7 +450,7 @@ impl PgsqlState { } } - fn parse_response(&mut self, input: &[u8], flow: *const Flow) -> AppLayerResult { + fn parse_response(&mut self, flow: *const Flow, input: &[u8]) -> AppLayerResult { // We're not interested in empty responses. if input.is_empty() { return AppLayerResult::ok(); @@ -470,6 +471,7 @@ impl PgsqlState { while !start.is_empty() { match PgsqlState::state_based_resp_parsing(self.state_progress, start) { Ok((rem, response)) => { + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32); start = rem; SCLogDebug!("Response is {:?}", &response); if let Some(state) = self.response_process_next_state(&response, flow) { @@ -633,7 +635,7 @@ pub extern "C" fn rs_pgsql_state_tx_free(state: *mut std::os::raw::c_void, tx_id #[no_mangle] pub unsafe extern "C" fn rs_pgsql_parse_request( - _flow: *const Flow, state: *mut std::os::raw::c_void, pstate: *mut std::os::raw::c_void, + flow: *const Flow, state: *mut std::os::raw::c_void, pstate: *mut std::os::raw::c_void, stream_slice: StreamSlice, _data: *const std::os::raw::c_void, ) -> AppLayerResult { if stream_slice.is_empty() { @@ -651,7 +653,7 @@ pub unsafe extern "C" fn rs_pgsql_parse_request( if stream_slice.is_gap() { state_safe.on_request_gap(stream_slice.gap_size()); } else if !stream_slice.is_empty() { - return state_safe.parse_request(stream_slice.as_slice()); + return state_safe.parse_request(flow, stream_slice.as_slice()); } AppLayerResult::ok() } @@ -674,7 +676,7 @@ pub unsafe extern "C" fn rs_pgsql_parse_response( if stream_slice.is_gap() { state_safe.on_response_gap(stream_slice.gap_size()); } else if !stream_slice.is_empty() { - return state_safe.parse_response(stream_slice.as_slice(), flow); + return state_safe.parse_response(flow, stream_slice.as_slice()); } AppLayerResult::ok() } @@ -835,7 +837,8 @@ mod test { let mut state = PgsqlState::new(); // an SSL Request let buf: &[u8] = &[0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f]; - state.parse_request(buf); + // We can pass null here as the only place that uses flow in the parse_request fn isn't run for unittests + state.parse_request(std::ptr::null_mut(), buf); let ok_state = PgsqlStateProgress::SSLRequestReceived; assert_eq!(state.state_progress, ok_state); @@ -849,7 +852,8 @@ mod test { // An SSL Request let buf: &[u8] = &[0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f]; - let r = state.parse_request(&buf[0..0]); + // We can pass null here as the only place that uses flow in the parse_request fn isn't run for unittests + let r = state.parse_request(std::ptr::null_mut(), &buf[0..0]); assert_eq!( r, AppLayerResult { @@ -859,7 +863,7 @@ mod test { } ); - let r = state.parse_request(&buf[0..1]); + let r = state.parse_request(std::ptr::null_mut(), &buf[0..1]); assert_eq!( r, AppLayerResult { @@ -869,7 +873,7 @@ mod test { } ); - let r = state.parse_request(&buf[0..2]); + let r = state.parse_request(std::ptr::null_mut(), &buf[0..2]); assert_eq!( r, AppLayerResult {