// suricata/rust/src/websocket/websocket.rs

/* Copyright (C) 2023 Open Information Security Foundation
*
* You can copy, redistribute or modify this Program under the terms of
* the GNU General Public License version 2 as published by the Free
* Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* version 2 along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
* 02110-1301, USA.
*/
use super::parser;
use crate::applayer::{self, *};
use crate::conf::conf_get;
use crate::core::{ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP};
use crate::direction::Direction;
use crate::flow::Flow;
use crate::frames::Frame;
use nom7 as nom;
use nom7::Needed;
use flate2::read::DeflateDecoder;
use suricata_sys::sys::AppProto;
use std;
use std::collections::VecDeque;
use std::ffi::CString;
use std::io::Read;
use std::os::raw::{c_char, c_int, c_void};
// Application-layer protocol id for WebSocket. Starts as ALPROTO_UNKNOWN and is
// overwritten in rs_websocket_register_parser with the id returned by
// AppLayerRegisterProtocolDetection.
pub(super) static mut ALPROTO_WEBSOCKET: AppProto = ALPROTO_UNKNOWN;
// Size limit handed to parser::parse_message and used as the cap for the
// per-direction reassembly buffers. Defaults to 0xFFFF; replaced during
// registration by `app-layer.protocols.websocket.max-payload-size` when that
// setting is present and parses as a u32.
static mut WEBSOCKET_MAX_PAYLOAD_SIZE: u32 = 0xFFFF;
// Frame types registered for each parsed WebSocket message.
#[derive(AppLayerFrameType)]
pub enum WebSocketFrameType {
// Bytes of the message that precede the payload.
Header,
// The whole parsed message (header and payload).
Pdu,
// The payload bytes only.
Data,
}
// Events that can be set on a WebSocket transaction.
#[derive(AppLayerEvent)]
pub enum WebSocketEvent {
// Set when the parser reports a non-zero `to_skip` for a message, i.e. the
// remaining bytes of that message will be skipped instead of parsed.
SkipEndOfPayload,
// Set when appending a fragment would reach the configured maximum payload
// size and the reassembled payload is therefore truncated.
ReassemblyLimitReached,
}
// One transaction per parsed WebSocket message.
#[derive(Default)]
pub struct WebSocketTransaction {
// Internal id assigned by WebSocketState::new_tx; the first transaction gets 1.
tx_id: u64,
// The parsed message. For the final fragment of a fragmented message the
// payload is replaced with the reassembled (and possibly inflated) payload.
pub pdu: parser::WebSocketPdu,
// Generic per-transaction app-layer data (events are set through it).
tx_data: AppLayerTxData,
}
impl WebSocketTransaction {
pub fn new(direction: Direction) -> WebSocketTransaction {
Self {
tx_data: AppLayerTxData::for_direction(direction),
..Default::default()
}
}
}
// Exposes the internal transaction id through the generic `Transaction` trait.
impl Transaction for WebSocketTransaction {
// Returns the id stored by WebSocketState::new_tx (ids start at 1).
fn id(&self) -> u64 {
self.tx_id
}
}
// Accumulates the payload of a fragmented message for one direction.
#[derive(Default)]
struct WebSocketReassemblyBuffer {
// Payload bytes collected so far; handed to the transaction of the PDU
// that has `fin` set and then emptied.
data: Vec<u8>,
// Copy of the `compress` flag of the PDU that started filling `data`.
compress: bool,
}
// Per-flow WebSocket parser state.
#[derive(Default)]
pub struct WebSocketState {
// Generic per-state app-layer data.
state_data: AppLayerStateData,
// Id given to the most recently created transaction; also reported as the
// transaction count by rs_websocket_state_get_tx_count.
tx_id: u64,
// Transactions that have been created and not yet freed.
transactions: VecDeque<WebSocketTransaction>,
// Reassembly buffer used for Direction::ToServer input.
c2s_buf: WebSocketReassemblyBuffer,
// Reassembly buffer used for Direction::ToClient input.
s2c_buf: WebSocketReassemblyBuffer,
// Number of bytes still to be skipped at the start of upcoming
// to-client input.
to_skip_tc: u64,
// Number of bytes still to be skipped at the start of upcoming
// to-server input.
to_skip_ts: u64,
}
// Read-only access to the transaction queue through the generic `State` trait.
// NOTE(review): presumably required by applayer::state_get_tx_iterator, which
// is instantiated with these types at registration — confirm.
impl State<WebSocketTransaction> for WebSocketState {
// Number of transactions currently held (not the number ever created).
fn get_transaction_count(&self) -> usize {
self.transactions.len()
}
// Transaction at position `index` in the queue, or None when out of range.
fn get_transaction_by_index(&self, index: usize) -> Option<&WebSocketTransaction> {
self.transactions.get(index)
}
}
impl WebSocketState {
pub fn new() -> Self {
Default::default()
}
// Free a transaction by ID.
fn free_tx(&mut self, tx_id: u64) {
let len = self.transactions.len();
let mut found = false;
let mut index = 0;
for i in 0..len {
let tx = &self.transactions[i];
if tx.tx_id == tx_id + 1 {
found = true;
index = i;
break;
}
}
if found {
self.transactions.remove(index);
}
}
pub fn get_tx(&mut self, tx_id: u64) -> Option<&WebSocketTransaction> {
self.transactions.iter().find(|tx| tx.tx_id == tx_id + 1)
}
fn new_tx(&mut self, direction: Direction) -> WebSocketTransaction {
let mut tx = WebSocketTransaction::new(direction);
self.tx_id += 1;
tx.tx_id = self.tx_id;
return tx;
}
fn parse(
&mut self, stream_slice: StreamSlice, direction: Direction, flow: *const Flow,
) -> AppLayerResult {
let to_skip = if direction == Direction::ToClient {
&mut self.to_skip_tc
} else {
&mut self.to_skip_ts
};
let input = stream_slice.as_slice();
let mut start = input;
if *to_skip > 0 {
if *to_skip >= input.len() as u64 {
*to_skip -= input.len() as u64;
return AppLayerResult::ok();
} else {
start = &input[*to_skip as usize..];
*to_skip = 0;
}
}
let max_pl_size = unsafe { WEBSOCKET_MAX_PAYLOAD_SIZE };
while !start.is_empty() {
match parser::parse_message(start, max_pl_size) {
Ok((rem, pdu)) => {
let mut tx = self.new_tx(direction);
let _pdu = Frame::new(
flow,
&stream_slice,
start,
(start.len() - rem.len() - pdu.payload.len()) as i64,
WebSocketFrameType::Header as u8,
Some(tx.tx_id),
);
let _pdu = Frame::new(
flow,
&stream_slice,
start,
(start.len() - rem.len()) as i64,
WebSocketFrameType::Pdu as u8,
Some(tx.tx_id),
);
let _pdu = Frame::new(
flow,
&stream_slice,
&start[(start.len() - rem.len() - pdu.payload.len())..],
pdu.payload.len() as i64,
WebSocketFrameType::Data as u8,
Some(tx.tx_id),
);
start = rem;
if pdu.to_skip > 0 {
if direction == Direction::ToClient {
self.to_skip_tc = pdu.to_skip;
} else {
self.to_skip_ts = pdu.to_skip;
}
tx.tx_data.set_event(WebSocketEvent::SkipEndOfPayload as u8);
}
let buf = if direction == Direction::ToClient {
&mut self.s2c_buf
} else {
&mut self.c2s_buf
};
if !buf.data.is_empty() || !pdu.fin {
if buf.data.is_empty() {
buf.compress = pdu.compress;
}
if buf.data.len() + pdu.payload.len() < max_pl_size as usize {
buf.data.extend(&pdu.payload);
} else if buf.data.len() < max_pl_size as usize {
buf.data
.extend(&pdu.payload[..max_pl_size as usize - buf.data.len()]);
tx.tx_data
.set_event(WebSocketEvent::ReassemblyLimitReached as u8);
}
}
tx.pdu = pdu;
if tx.pdu.fin && !buf.data.is_empty() {
// the final PDU gets the full reassembled payload
std::mem::swap(&mut tx.pdu.payload, &mut buf.data);
buf.data.clear();
}
if buf.compress && tx.pdu.fin {
buf.compress = false;
// cf RFC 7692 section-7.2.2
tx.pdu.payload.extend_from_slice(&[0, 0, 0xFF, 0xFF]);
let mut deflater = DeflateDecoder::new(&tx.pdu.payload[..]);
let mut v = Vec::new();
// do not check result because
// deflate with rust backend fails on good input cf https://github.com/rust-lang/flate2-rs/issues/389
let _ = deflater.read_to_end(&mut v);
if !v.is_empty() {
std::mem::swap(&mut tx.pdu.payload, &mut v);
}
}
self.transactions.push_back(tx);
}
Err(nom::Err::Incomplete(needed)) => {
if let Needed::Size(n) = needed {
let n = usize::from(n);
// Not enough data. just ask for one more byte.
let consumed = input.len() - start.len();
let needed = start.len() + n;
return AppLayerResult::incomplete(consumed as u32, needed as u32);
}
return AppLayerResult::err();
}
Err(_) => {
return AppLayerResult::err();
}
}
}
// Input was fully consumed.
return AppLayerResult::ok();
}
}
// C exports.
/// Probing parser shared by both directions.
///
/// Returns `ALPROTO_UNKNOWN` when there is no input to look at,
/// `ALPROTO_WEBSOCKET` when bits 0x30 of the first byte are clear, and
/// `ALPROTO_FAILED` otherwise.
#[no_mangle]
pub unsafe extern "C" fn rs_websocket_probing_parser(
    _flow: *const Flow, _direction: u8, input: *const u8, input_len: u32, _rdir: *mut u8,
) -> AppProto {
    if input.is_null() {
        return ALPROTO_UNKNOWN;
    }
    let data = build_slice!(input, input_len as usize);
    match data.first() {
        None => ALPROTO_UNKNOWN,
        // just check reserved bits are zeroed, except RSV1
        // as RSV1 is used for compression cf RFC 7692
        Some(&first) if first & 0x30 == 0 => ALPROTO_WEBSOCKET,
        Some(_) => ALPROTO_FAILED,
    }
}
/// Allocates a fresh `WebSocketState` on the heap and hands ownership to the
/// caller as an opaque pointer (released by `rs_websocket_state_free`).
extern "C" fn rs_websocket_state_new(
    _orig_state: *mut c_void, _orig_proto: AppProto,
) -> *mut c_void {
    Box::into_raw(Box::new(WebSocketState::new())) as *mut c_void
}
/// Takes back ownership of a state created by `rs_websocket_state_new` and
/// drops it.
unsafe extern "C" fn rs_websocket_state_free(state: *mut c_void) {
    let owned = Box::from_raw(state as *mut WebSocketState);
    drop(owned);
}
/// Frees the transaction `tx_id` held by the given state.
unsafe extern "C" fn rs_websocket_state_tx_free(state: *mut c_void, tx_id: u64) {
    let ws_state = cast_pointer!(state, WebSocketState);
    ws_state.free_tx(tx_id);
}
/// Parses to-server data: forwards `stream_slice` to `WebSocketState::parse`
/// with `Direction::ToServer`.
unsafe extern "C" fn rs_websocket_parse_request(
    flow: *const Flow, state: *mut c_void, _pstate: *mut c_void, stream_slice: StreamSlice,
    _data: *const c_void,
) -> AppLayerResult {
    let ws_state = cast_pointer!(state, WebSocketState);
    ws_state.parse(stream_slice, Direction::ToServer, flow)
}
/// Parses to-client data: forwards `stream_slice` to `WebSocketState::parse`
/// with `Direction::ToClient`.
unsafe extern "C" fn rs_websocket_parse_response(
    flow: *const Flow, state: *mut c_void, _pstate: *mut c_void, stream_slice: StreamSlice,
    _data: *const c_void,
) -> AppLayerResult {
    let ws_state = cast_pointer!(state, WebSocketState);
    ws_state.parse(stream_slice, Direction::ToClient, flow)
}
unsafe extern "C" fn rs_websocket_state_get_tx(state: *mut c_void, tx_id: u64) -> *mut c_void {
let state = cast_pointer!(state, WebSocketState);
match state.get_tx(tx_id) {
Some(tx) => {
return tx as *const _ as *mut _;
}
None => {
return std::ptr::null_mut();
}
}
}
/// Reports the state's `tx_id` counter, i.e. the id given to the most recently
/// created transaction (not the number of transactions currently stored).
unsafe extern "C" fn rs_websocket_state_get_tx_count(state: *mut c_void) -> u64 {
    let ws_state = cast_pointer!(state, WebSocketState);
    ws_state.tx_id
}
/// Progress of a transaction: always 1, whatever the transaction or
/// direction. This is the same value the parser registers as the completion
/// status for both directions.
unsafe extern "C" fn rs_websocket_tx_get_alstate_progress(
    _tx: *mut c_void, _direction: u8,
) -> c_int {
    1
}
// Presumably expands to the `websocket_get_tx_data` item that is installed as
// the `get_tx_data` callback in rs_websocket_register_parser.
export_tx_data_get!(websocket_get_tx_data, WebSocketTransaction);
// Presumably expands to the `websocket_get_state_data` item that is installed
// as the `get_state_data` callback in rs_websocket_register_parser.
export_state_data_get!(websocket_get_state_data, WebSocketState);
// Parser name as a C style string: the trailing `\0` is the terminator, so the
// bytes can be passed as a `*const c_char` without further conversion.
const PARSER_NAME: &[u8] = b"websocket\0";
// Registers WebSocket protocol detection, the parser and its logger.
//
// Nothing is registered unless protocol detection is enabled for
// "tcp"/"websocket" in the configuration. When it is enabled:
//  * detection is registered and the returned id is stored in ALPROTO_WEBSOCKET,
//  * the parser itself is registered only if it is enabled as well,
//  * `app-layer.protocols.websocket.max-payload-size` is read into
//    WEBSOCKET_MAX_PAYLOAD_SIZE (an unparsable value is logged and the
//    previous value is kept),
//  * the logger is registered for the protocol.
#[no_mangle]
pub unsafe extern "C" fn rs_websocket_register_parser() {
let parser = RustParser {
name: PARSER_NAME.as_ptr() as *const c_char,
// No default port.
default_port: std::ptr::null(),
ipproto: IPPROTO_TCP,
// The same probing function is used for both directions.
probe_ts: Some(rs_websocket_probing_parser),
probe_tc: Some(rs_websocket_probing_parser),
min_depth: 0,
max_depth: 16,
state_new: rs_websocket_state_new,
state_free: rs_websocket_state_free,
tx_free: rs_websocket_state_tx_free,
parse_ts: rs_websocket_parse_request,
parse_tc: rs_websocket_parse_response,
get_tx_count: rs_websocket_state_get_tx_count,
get_tx: rs_websocket_state_get_tx,
// Completion status 1 in both directions; rs_websocket_tx_get_alstate_progress
// always returns 1.
tx_comp_st_ts: 1,
tx_comp_st_tc: 1,
tx_get_progress: rs_websocket_tx_get_alstate_progress,
get_eventinfo: Some(WebSocketEvent::get_event_info),
get_eventinfo_byid: Some(WebSocketEvent::get_event_info_by_id),
localstorage_new: None,
localstorage_free: None,
get_tx_files: None,
get_tx_iterator: Some(
applayer::state_get_tx_iterator::<WebSocketState, WebSocketTransaction>,
),
get_tx_data: websocket_get_tx_data,
get_state_data: websocket_get_state_data,
apply_tx_config: None,
flags: 0, // do not accept gaps as there is no good way to resync
get_frame_id_by_name: Some(WebSocketFrameType::ffi_id_from_name),
get_frame_name_by_id: Some(WebSocketFrameType::ffi_name_from_id),
get_state_id_by_name: None,
get_state_name_by_id: None,
};
// The literal contains no interior NUL byte, so CString::new cannot fail here.
let ip_proto_str = CString::new("tcp").unwrap();
if AppLayerProtoDetectConfProtoDetectionEnabled(ip_proto_str.as_ptr(), parser.name) != 0 {
let alproto = AppLayerRegisterProtocolDetection(&parser, 1);
ALPROTO_WEBSOCKET = alproto;
// Detection can be enabled while the parser itself is disabled.
if AppLayerParserConfParserEnabled(ip_proto_str.as_ptr(), parser.name) != 0 {
let _ = AppLayerRegisterParser(&parser, alproto);
}
SCLogDebug!("Rust websocket parser registered.");
// Optional override of the default maximum payload size (0xFFFF).
if let Some(val) = conf_get("app-layer.protocols.websocket.max-payload-size") {
if let Ok(v) = val.parse::<u32>() {
WEBSOCKET_MAX_PAYLOAD_SIZE = v;
} else {
SCLogError!("Invalid value for websocket.max-payload-size");
}
}
AppLayerParserRegisterLogger(IPPROTO_TCP, ALPROTO_WEBSOCKET);
} else {
SCLogDebug!("Protocol detector and parser disabled for WEBSOCKET.");
}
}