/* Copyright (C) 2023 Open Information Security Foundation * * You can copy, redistribute or modify this Program under the terms of * the GNU General Public License version 2 as published by the Free * Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * version 2 along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA * 02110-1301, USA. */ use super::parser; use crate::applayer::{self, *}; use crate::conf::conf_get; use crate::core::{ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP}; use crate::direction::Direction; use crate::flow::Flow; use crate::frames::Frame; use nom7 as nom; use nom7::Needed; use flate2::read::DeflateDecoder; use suricata_sys::sys::AppProto; use std; use std::collections::VecDeque; use std::ffi::CString; use std::io::Read; use std::os::raw::{c_char, c_int, c_void}; pub(super) static mut ALPROTO_WEBSOCKET: AppProto = ALPROTO_UNKNOWN; static mut WEBSOCKET_MAX_PAYLOAD_SIZE: u32 = 0xFFFF; #[derive(AppLayerFrameType)] pub enum WebSocketFrameType { Header, Pdu, Data, } #[derive(AppLayerEvent)] pub enum WebSocketEvent { SkipEndOfPayload, ReassemblyLimitReached, } #[derive(Default)] pub struct WebSocketTransaction { tx_id: u64, pub pdu: parser::WebSocketPdu, tx_data: AppLayerTxData, } impl WebSocketTransaction { pub fn new(direction: Direction) -> WebSocketTransaction { Self { tx_data: AppLayerTxData::for_direction(direction), ..Default::default() } } } impl Transaction for WebSocketTransaction { fn id(&self) -> u64 { self.tx_id } } #[derive(Default)] struct WebSocketReassemblyBuffer { data: Vec, compress: bool, } #[derive(Default)] pub struct WebSocketState { state_data: AppLayerStateData, tx_id: u64, transactions: VecDeque, c2s_buf: WebSocketReassemblyBuffer, s2c_buf: WebSocketReassemblyBuffer, to_skip_tc: u64, to_skip_ts: u64, } impl State for WebSocketState { fn get_transaction_count(&self) -> usize { self.transactions.len() } fn get_transaction_by_index(&self, index: usize) -> Option<&WebSocketTransaction> { self.transactions.get(index) } } impl WebSocketState { pub fn new() -> Self { Default::default() } // Free a transaction by ID. fn free_tx(&mut self, tx_id: u64) { let len = self.transactions.len(); let mut found = false; let mut index = 0; for i in 0..len { let tx = &self.transactions[i]; if tx.tx_id == tx_id + 1 { found = true; index = i; break; } } if found { self.transactions.remove(index); } } pub fn get_tx(&mut self, tx_id: u64) -> Option<&WebSocketTransaction> { self.transactions.iter().find(|tx| tx.tx_id == tx_id + 1) } fn new_tx(&mut self, direction: Direction) -> WebSocketTransaction { let mut tx = WebSocketTransaction::new(direction); self.tx_id += 1; tx.tx_id = self.tx_id; return tx; } fn parse( &mut self, stream_slice: StreamSlice, direction: Direction, flow: *const Flow, ) -> AppLayerResult { let to_skip = if direction == Direction::ToClient { &mut self.to_skip_tc } else { &mut self.to_skip_ts }; let input = stream_slice.as_slice(); let mut start = input; if *to_skip > 0 { if *to_skip >= input.len() as u64 { *to_skip -= input.len() as u64; return AppLayerResult::ok(); } else { start = &input[*to_skip as usize..]; *to_skip = 0; } } let max_pl_size = unsafe { WEBSOCKET_MAX_PAYLOAD_SIZE }; while !start.is_empty() { match parser::parse_message(start, max_pl_size) { Ok((rem, pdu)) => { let mut tx = self.new_tx(direction); let _pdu = Frame::new( flow, &stream_slice, start, (start.len() - rem.len() - pdu.payload.len()) as i64, WebSocketFrameType::Header as u8, Some(tx.tx_id), ); let _pdu = Frame::new( flow, &stream_slice, start, (start.len() - rem.len()) as i64, WebSocketFrameType::Pdu as u8, Some(tx.tx_id), ); let _pdu = Frame::new( flow, &stream_slice, &start[(start.len() - rem.len() - pdu.payload.len())..], pdu.payload.len() as i64, WebSocketFrameType::Data as u8, Some(tx.tx_id), ); start = rem; if pdu.to_skip > 0 { if direction == Direction::ToClient { self.to_skip_tc = pdu.to_skip; } else { self.to_skip_ts = pdu.to_skip; } tx.tx_data.set_event(WebSocketEvent::SkipEndOfPayload as u8); } let buf = if direction == Direction::ToClient { &mut self.s2c_buf } else { &mut self.c2s_buf }; if !buf.data.is_empty() || !pdu.fin { if buf.data.is_empty() { buf.compress = pdu.compress; } if buf.data.len() + pdu.payload.len() < max_pl_size as usize { buf.data.extend(&pdu.payload); } else if buf.data.len() < max_pl_size as usize { buf.data .extend(&pdu.payload[..max_pl_size as usize - buf.data.len()]); tx.tx_data .set_event(WebSocketEvent::ReassemblyLimitReached as u8); } } tx.pdu = pdu; if tx.pdu.fin && !buf.data.is_empty() { // the final PDU gets the full reassembled payload std::mem::swap(&mut tx.pdu.payload, &mut buf.data); buf.data.clear(); } if buf.compress && tx.pdu.fin { buf.compress = false; // cf RFC 7692 section-7.2.2 tx.pdu.payload.extend_from_slice(&[0, 0, 0xFF, 0xFF]); let mut deflater = DeflateDecoder::new(&tx.pdu.payload[..]); let mut v = Vec::new(); // do not check result because // deflate with rust backend fails on good input cf https://github.com/rust-lang/flate2-rs/issues/389 let _ = deflater.read_to_end(&mut v); if !v.is_empty() { std::mem::swap(&mut tx.pdu.payload, &mut v); } } self.transactions.push_back(tx); } Err(nom::Err::Incomplete(needed)) => { if let Needed::Size(n) = needed { let n = usize::from(n); // Not enough data. just ask for one more byte. let consumed = input.len() - start.len(); let needed = start.len() + n; return AppLayerResult::incomplete(consumed as u32, needed as u32); } return AppLayerResult::err(); } Err(_) => { return AppLayerResult::err(); } } } // Input was fully consumed. return AppLayerResult::ok(); } } // C exports. #[no_mangle] pub unsafe extern "C" fn rs_websocket_probing_parser( _flow: *const Flow, _direction: u8, input: *const u8, input_len: u32, _rdir: *mut u8, ) -> AppProto { if !input.is_null() { let slice = build_slice!(input, input_len as usize); if !slice.is_empty() { // just check reserved bits are zeroed, except RSV1 // as RSV1 is used for compression cf RFC 7692 if slice[0] & 0x30 == 0 { return ALPROTO_WEBSOCKET; } return ALPROTO_FAILED; } } return ALPROTO_UNKNOWN; } extern "C" fn rs_websocket_state_new( _orig_state: *mut c_void, _orig_proto: AppProto, ) -> *mut c_void { let state = WebSocketState::new(); let boxed = Box::new(state); return Box::into_raw(boxed) as *mut c_void; } unsafe extern "C" fn rs_websocket_state_free(state: *mut c_void) { std::mem::drop(Box::from_raw(state as *mut WebSocketState)); } unsafe extern "C" fn rs_websocket_state_tx_free(state: *mut c_void, tx_id: u64) { let state = cast_pointer!(state, WebSocketState); state.free_tx(tx_id); } unsafe extern "C" fn rs_websocket_parse_request( flow: *const Flow, state: *mut c_void, _pstate: *mut c_void, stream_slice: StreamSlice, _data: *const c_void, ) -> AppLayerResult { let state = cast_pointer!(state, WebSocketState); state.parse(stream_slice, Direction::ToServer, flow) } unsafe extern "C" fn rs_websocket_parse_response( flow: *const Flow, state: *mut c_void, _pstate: *mut c_void, stream_slice: StreamSlice, _data: *const c_void, ) -> AppLayerResult { let state = cast_pointer!(state, WebSocketState); state.parse(stream_slice, Direction::ToClient, flow) } unsafe extern "C" fn rs_websocket_state_get_tx(state: *mut c_void, tx_id: u64) -> *mut c_void { let state = cast_pointer!(state, WebSocketState); match state.get_tx(tx_id) { Some(tx) => { return tx as *const _ as *mut _; } None => { return std::ptr::null_mut(); } } } unsafe extern "C" fn rs_websocket_state_get_tx_count(state: *mut c_void) -> u64 { let state = cast_pointer!(state, WebSocketState); return state.tx_id; } unsafe extern "C" fn rs_websocket_tx_get_alstate_progress( _tx: *mut c_void, _direction: u8, ) -> c_int { return 1; } export_tx_data_get!(websocket_get_tx_data, WebSocketTransaction); export_state_data_get!(websocket_get_state_data, WebSocketState); // Parser name as a C style string. const PARSER_NAME: &[u8] = b"websocket\0"; #[no_mangle] pub unsafe extern "C" fn rs_websocket_register_parser() { let parser = RustParser { name: PARSER_NAME.as_ptr() as *const c_char, default_port: std::ptr::null(), ipproto: IPPROTO_TCP, probe_ts: Some(rs_websocket_probing_parser), probe_tc: Some(rs_websocket_probing_parser), min_depth: 0, max_depth: 16, state_new: rs_websocket_state_new, state_free: rs_websocket_state_free, tx_free: rs_websocket_state_tx_free, parse_ts: rs_websocket_parse_request, parse_tc: rs_websocket_parse_response, get_tx_count: rs_websocket_state_get_tx_count, get_tx: rs_websocket_state_get_tx, tx_comp_st_ts: 1, tx_comp_st_tc: 1, tx_get_progress: rs_websocket_tx_get_alstate_progress, get_eventinfo: Some(WebSocketEvent::get_event_info), get_eventinfo_byid: Some(WebSocketEvent::get_event_info_by_id), localstorage_new: None, localstorage_free: None, get_tx_files: None, get_tx_iterator: Some( applayer::state_get_tx_iterator::, ), get_tx_data: websocket_get_tx_data, get_state_data: websocket_get_state_data, apply_tx_config: None, flags: 0, // do not accept gaps as there is no good way to resync get_frame_id_by_name: Some(WebSocketFrameType::ffi_id_from_name), get_frame_name_by_id: Some(WebSocketFrameType::ffi_name_from_id), get_state_id_by_name: None, get_state_name_by_id: None, }; let ip_proto_str = CString::new("tcp").unwrap(); if AppLayerProtoDetectConfProtoDetectionEnabled(ip_proto_str.as_ptr(), parser.name) != 0 { let alproto = AppLayerRegisterProtocolDetection(&parser, 1); ALPROTO_WEBSOCKET = alproto; if AppLayerParserConfParserEnabled(ip_proto_str.as_ptr(), parser.name) != 0 { let _ = AppLayerRegisterParser(&parser, alproto); } SCLogDebug!("Rust websocket parser registered."); if let Some(val) = conf_get("app-layer.protocols.websocket.max-payload-size") { if let Ok(v) = val.parse::() { WEBSOCKET_MAX_PAYLOAD_SIZE = v; } else { SCLogError!("Invalid value for websocket.max-payload-size"); } } AppLayerParserRegisterLogger(IPPROTO_TCP, ALPROTO_WEBSOCKET); } else { SCLogDebug!("Protocol detector and parser disabled for WEBSOCKET."); } }