diff --git a/rust/src/detect.rs b/rust/src/detect.rs index 27c5367172..6bc0ed0f3b 100644 --- a/rust/src/detect.rs +++ b/rust/src/detect.rs @@ -17,13 +17,14 @@ use nom7::branch::alt; use nom7::bytes::complete::{is_a, tag, take_while}; -use nom7::character::complete::digit1; -use nom7::combinator::{all_consuming, map_opt, opt, value, verify}; +use nom7::character::complete::{alpha0, char, digit1}; +use nom7::combinator::{all_consuming, map_opt, map_res, opt, value, verify}; use nom7::error::{make_error, ErrorKind}; use nom7::Err; use nom7::IResult; use std::ffi::CStr; +use std::str::FromStr; #[derive(PartialEq, Clone, Debug)] #[repr(u8)] @@ -92,7 +93,9 @@ fn detect_parse_uint_mode(i: &str) -> IResult<&str, DetectUintMode> { value(DetectUintMode::DetectUintModeLte, tag("<=")), value(DetectUintMode::DetectUintModeGt, tag(">")), value(DetectUintMode::DetectUintModeLt, tag("<")), + value(DetectUintMode::DetectUintModeNe, tag("!=")), value(DetectUintMode::DetectUintModeNe, tag("!")), + value(DetectUintMode::DetectUintModeEqual, tag("=")), ))(i)?; return Ok((i, mode)); } @@ -310,3 +313,78 @@ pub unsafe extern "C" fn rs_detect_u16_free(ctx: &mut DetectUintData) { // Just unbox... std::mem::drop(Box::from_raw(ctx)); } + +#[repr(u8)] +#[derive(Clone, Copy, PartialEq, FromPrimitive, Debug)] +pub enum DetectStreamSizeDataFlags { + StreamSizeServer = 1, + StreamSizeClient = 2, + StreamSizeBoth = 3, + StreamSizeEither = 4, +} + +impl std::str::FromStr for DetectStreamSizeDataFlags { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "server" => Ok(DetectStreamSizeDataFlags::StreamSizeServer), + "client" => Ok(DetectStreamSizeDataFlags::StreamSizeClient), + "both" => Ok(DetectStreamSizeDataFlags::StreamSizeBoth), + "either" => Ok(DetectStreamSizeDataFlags::StreamSizeEither), + _ => Err(format!( + "'{}' is not a valid value for DetectStreamSizeDataFlags", + s + )), + } + } +} + +#[derive(Debug)] +#[repr(C)] +pub struct DetectStreamSizeData { + pub flags: DetectStreamSizeDataFlags, + pub du32: DetectUintData, +} + +pub fn detect_parse_stream_size(i: &str) -> IResult<&str, DetectStreamSizeData> { + let (i, _) = opt(is_a(" "))(i)?; + let (i, flags) = map_res(alpha0, |s: &str| { + DetectStreamSizeDataFlags::from_str(s) + })(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = char(',')(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, mode) = detect_parse_uint_mode(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, _) = char(',')(i)?; + let (i, _) = opt(is_a(" "))(i)?; + let (i, arg1) = map_opt(digit1, |s: &str| s.parse::().ok())(i)?; + let (i, _) = all_consuming(take_while(|c| c == ' '))(i)?; + let du32 = DetectUintData:: { + arg1: arg1, + arg2: 0, + mode: mode, + }; + Ok((i, DetectStreamSizeData { flags, du32 })) +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_stream_size_parse( + ustr: *const std::os::raw::c_char, +) -> *mut DetectStreamSizeData { + let ft_name: &CStr = CStr::from_ptr(ustr); //unsafe + if let Ok(s) = ft_name.to_str() { + if let Ok((_, ctx)) = detect_parse_stream_size(s) { + let boxed = Box::new(ctx); + return Box::into_raw(boxed) as *mut _; + } + } + return std::ptr::null_mut(); +} + +#[no_mangle] +pub unsafe extern "C" fn rs_detect_stream_size_free(ctx: &mut DetectStreamSizeData) { + // Just unbox... + std::mem::drop(Box::from_raw(ctx)); +} diff --git a/src/detect-stream_size.c b/src/detect-stream_size.c index 6897325622..86aabd77c8 100644 --- a/src/detect-stream_size.c +++ b/src/detect-stream_size.c @@ -33,15 +33,11 @@ #include "flow.h" #include "detect-stream_size.h" #include "stream-tcp-private.h" +#include "detect-engine-prefilter-common.h" +#include "detect-engine-uint.h" #include "util-debug.h" #include "util-byte.h" -/** - * \brief Regex for parsing our flow options - */ -#define PARSE_REGEX "^\\s*([A-z_]+)\\s*,\\s*([<=>!]+)\\s*,\\s*([0-9]+)\\s*$" - -static DetectParseRegex parse_regex; /*prototypes*/ static int DetectStreamSizeMatch (DetectEngineThreadCtx *, Packet *, @@ -51,6 +47,8 @@ void DetectStreamSizeFree(DetectEngineCtx *de_ctx, void *); #ifdef UNITTESTS static void DetectStreamSizeRegisterTests(void); #endif +static int PrefilterSetupStreamSize(DetectEngineCtx *de_ctx, SigGroupHead *sgh); +static bool PrefilterStreamSizeIsPrefilterable(const Signature *s); /** * \brief Registration function for stream_size: keyword @@ -67,217 +65,69 @@ void DetectStreamSizeRegister(void) #ifdef UNITTESTS sigmatch_table[DETECT_STREAM_SIZE].RegisterTests = DetectStreamSizeRegisterTests; #endif - DetectSetupParseRegexes(PARSE_REGEX, &parse_regex); + sigmatch_table[DETECT_STREAM_SIZE].SupportsPrefilter = PrefilterStreamSizeIsPrefilterable; + sigmatch_table[DETECT_STREAM_SIZE].SetupPrefilter = PrefilterSetupStreamSize; } -/** - * \brief Function to comapre the stream size against defined size in the user - * options. - * - * \param diff The stream size of server or client stream. - * \param stream_size User defined stream size - * \param mode The mode defined by user. - * - * \retval 1 on success and 0 on failure. - */ - -static int DetectStreamSizeCompare (uint32_t diff, uint32_t stream_size, uint8_t mode) +static int DetectStreamSizeMatchAux(const DetectStreamSizeData *sd, const TcpSession *ssn) { - SCLogDebug("diff %u stream_size %u mode %u", diff, stream_size, mode); - - int ret = 0; - switch (mode) { - case DETECTSSIZE_LT: - if (diff < stream_size) - ret = 1; - break; - case DETECTSSIZE_LEQ: - if (diff <= stream_size) - ret = 1; - break; - case DETECTSSIZE_EQ: - if (diff == stream_size) - ret = 1; - break; - case DETECTSSIZE_NEQ: - if (diff != stream_size) - ret = 1; - break; - case DETECTSSIZE_GEQ: - if (diff >= stream_size) - ret = 1; - break; - case DETECTSSIZE_GT: - if (diff > stream_size) - ret = 1; - break; - } - - SCReturnInt(ret); -} - -/** - * \brief This function is used to match Stream size rule option on a packet with those passed via stream_size: - * - * \param t pointer to thread vars - * \param det_ctx pointer to the pattern matcher thread - * \param p pointer to the current packet - * \param m pointer to the sigmatch that we will cast into DetectStreamSizeData - * - * \retval 0 no match - * \retval 1 match - */ -static int DetectStreamSizeMatch (DetectEngineThreadCtx *det_ctx, Packet *p, - const Signature *s, const SigMatchCtx *ctx) -{ - - const DetectStreamSizeData *sd = (const DetectStreamSizeData *)ctx; - - if (!(PKT_IS_TCP(p))) - return 0; - if (p->flow == NULL || p->flow->protoctx == NULL) - return 0; - - const TcpSession *ssn = (TcpSession *)p->flow->protoctx; int ret = 0; uint32_t csdiff = 0; uint32_t ssdiff = 0; - if (sd->flags & STREAM_SIZE_SERVER) { + if (sd->flags == StreamSizeServer) { /* get the server stream size */ ssdiff = ssn->server.next_seq - ssn->server.isn; - ret = DetectStreamSizeCompare(ssdiff, sd->ssize, sd->mode); + ret = DetectU32Match(ssdiff, &sd->du32); - } else if (sd->flags & STREAM_SIZE_CLIENT) { + } else if (sd->flags == StreamSizeClient) { /* get the client stream size */ csdiff = ssn->client.next_seq - ssn->client.isn; - ret = DetectStreamSizeCompare(csdiff, sd->ssize, sd->mode); + ret = DetectU32Match(csdiff, &sd->du32); - } else if (sd->flags & STREAM_SIZE_BOTH) { + } else if (sd->flags == StreamSizeBoth) { ssdiff = ssn->server.next_seq - ssn->server.isn; csdiff = ssn->client.next_seq - ssn->client.isn; - if (DetectStreamSizeCompare(ssdiff, sd->ssize, sd->mode) && - DetectStreamSizeCompare(csdiff, sd->ssize, sd->mode)) + if (DetectU32Match(ssdiff, &sd->du32) && DetectU32Match(csdiff, &sd->du32)) ret = 1; - } else if (sd->flags & STREAM_SIZE_EITHER) { + } else if (sd->flags == StreamSizeEither) { ssdiff = ssn->server.next_seq - ssn->server.isn; csdiff = ssn->client.next_seq - ssn->client.isn; - if (DetectStreamSizeCompare(ssdiff, sd->ssize, sd->mode) || - DetectStreamSizeCompare(csdiff, sd->ssize, sd->mode)) + if (DetectU32Match(ssdiff, &sd->du32) || DetectU32Match(csdiff, &sd->du32)) ret = 1; } - - SCReturnInt(ret); + return ret; } /** - * \brief This function is used to parse stream options passed via stream_size: keyword + * \brief This function is used to match Stream size rule option on a packet with those passed via + * stream_size: * - * \param de_ctx Pointer to the detection engine context - * \param streamstr Pointer to the user provided stream_size options + * \param t pointer to thread vars + * \param det_ctx pointer to the pattern matcher thread + * \param p pointer to the current packet + * \param m pointer to the sigmatch that we will cast into DetectStreamSizeData * - * \retval sd pointer to DetectStreamSizeData on success - * \retval NULL on failure + * \retval 0 no match + * \retval 1 match */ -static DetectStreamSizeData *DetectStreamSizeParse (DetectEngineCtx *de_ctx, const char *streamstr) +static int DetectStreamSizeMatch( + DetectEngineThreadCtx *det_ctx, Packet *p, const Signature *s, const SigMatchCtx *ctx) { - DetectStreamSizeData *sd = NULL; - char *arg = NULL; - char *value = NULL; - char *mode = NULL; - int res = 0; - size_t pcre2_len; - - int ret = DetectParsePcreExec(&parse_regex, streamstr, 0, 0); - if (ret != 4) { - SCLogError(SC_ERR_PCRE_MATCH, "pcre_exec parse error, ret %" PRId32 ", string %s", ret, streamstr); - goto error; - } - const char *str_ptr; - res = pcre2_substring_get_bynumber(parse_regex.match, 1, (PCRE2_UCHAR8 **)&str_ptr, &pcre2_len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_GET_SUBSTRING, "pcre2_substring_get_bynumber failed"); - goto error; - } - arg = (char *)str_ptr; - - res = pcre2_substring_get_bynumber(parse_regex.match, 2, (PCRE2_UCHAR8 **)&str_ptr, &pcre2_len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_GET_SUBSTRING, "pcre2_substring_get_bynumber failed"); - goto error; - } - mode = (char *)str_ptr; + const DetectStreamSizeData *sd = (const DetectStreamSizeData *)ctx; - res = pcre2_substring_get_bynumber(parse_regex.match, 3, (PCRE2_UCHAR8 **)&str_ptr, &pcre2_len); - if (res < 0) { - SCLogError(SC_ERR_PCRE_GET_SUBSTRING, "pcre2_substring_get_bynumber failed"); - goto error; - } - value = (char *)str_ptr; - - sd = SCMalloc(sizeof(DetectStreamSizeData)); - if (unlikely(sd == NULL)) - goto error; - sd->ssize = 0; - sd->flags = 0; - - if (strlen(mode) == 0) - goto error; - - if (mode[0] == '=') { - sd->mode = DETECTSSIZE_EQ; - } else if (mode[0] == '<') { - sd->mode = DETECTSSIZE_LT; - if (strcmp("<=", mode) == 0) - sd->mode = DETECTSSIZE_LEQ; - } else if (mode[0] == '>') { - sd->mode = DETECTSSIZE_GT; - if (strcmp(">=", mode) == 0) - sd->mode = DETECTSSIZE_GEQ; - } else if (strcmp("!=", mode) == 0) { - sd->mode = DETECTSSIZE_NEQ; - } else { - SCLogError(SC_ERR_INVALID_OPERATOR, "Invalid operator"); - goto error; - } + if (!(PKT_IS_TCP(p))) + return 0; + if (p->flow == NULL || p->flow->protoctx == NULL) + return 0; - /* set the value */ - if (StringParseUint32(&sd->ssize, 10, 0, (const char *)value) < 0) { - SCLogError(SC_ERR_INVALID_VALUE, "Invalid value for stream size: %s", value); - goto error; - } - /* inspect our options and set the flags */ - if (strcmp(arg, "server") == 0) { - sd->flags |= STREAM_SIZE_SERVER; - } else if (strcmp(arg, "client") == 0) { - sd->flags |= STREAM_SIZE_CLIENT; - } else if ((strcmp(arg, "both") == 0)) { - sd->flags |= STREAM_SIZE_BOTH; - } else if (strcmp(arg, "either") == 0) { - sd->flags |= STREAM_SIZE_EITHER; - } else { - goto error; - } + const TcpSession *ssn = (TcpSession *)p->flow->protoctx; - pcre2_substring_free((PCRE2_UCHAR8 *)mode); - pcre2_substring_free((PCRE2_UCHAR8 *)arg); - pcre2_substring_free((PCRE2_UCHAR8 *)value); - return sd; - -error: - if (mode != NULL) - pcre2_substring_free((PCRE2_UCHAR8 *)mode); - if (arg != NULL) - pcre2_substring_free((PCRE2_UCHAR8 *)arg); - if (value != NULL) - pcre2_substring_free((PCRE2_UCHAR8 *)value); - if (sd != NULL) - DetectStreamSizeFree(de_ctx, sd); - return NULL; + SCReturnInt(DetectStreamSizeMatchAux(sd, ssn)); } /** @@ -292,7 +142,7 @@ error: */ static int DetectStreamSizeSetup (DetectEngineCtx *de_ctx, Signature *s, const char *streamstr) { - DetectStreamSizeData *sd = DetectStreamSizeParse(de_ctx, streamstr); + DetectStreamSizeData *sd = rs_detect_stream_size_parse(streamstr); if (sd == NULL) return -1; @@ -316,8 +166,70 @@ static int DetectStreamSizeSetup (DetectEngineCtx *de_ctx, Signature *s, const c */ void DetectStreamSizeFree(DetectEngineCtx *de_ctx, void *ptr) { - DetectStreamSizeData *sd = (DetectStreamSizeData *)ptr; - SCFree(sd); + rs_detect_stream_size_free(ptr); +} + +/* prefilter code */ + +static void PrefilterPacketStreamsizeMatch( + DetectEngineThreadCtx *det_ctx, Packet *p, const void *pectx) +{ + if (!(PKT_IS_TCP(p)) || PKT_IS_PSEUDOPKT(p)) + return; + + if (p->flow == NULL || p->flow->protoctx == NULL) + return; + + /* during setup Suricata will automatically see if there is another + * check that can be added: alproto, sport or dport */ + const PrefilterPacketHeaderCtx *ctx = pectx; + if (!PrefilterPacketHeaderExtraMatch(ctx, p)) + return; + + DetectStreamSizeData dsd; + dsd.du32.mode = ctx->v1.u8[0]; + dsd.flags = ctx->v1.u8[1]; + dsd.du32.arg1 = ctx->v1.u32[2]; + const TcpSession *ssn = (TcpSession *)p->flow->protoctx; + /* if we match, add all the sigs that use this prefilter. This means + * that these will be inspected further */ + if (DetectStreamSizeMatchAux(&dsd, ssn)) { + PrefilterAddSids(&det_ctx->pmq, ctx->sigs_array, ctx->sigs_cnt); + } +} + +static void PrefilterPacketStreamSizeSet(PrefilterPacketHeaderValue *v, void *smctx) +{ + const DetectStreamSizeData *a = smctx; + v->u8[0] = a->du32.mode; + v->u8[1] = a->flags; + v->u32[2] = a->du32.arg1; +} + +static bool PrefilterPacketStreamSizeCompare(PrefilterPacketHeaderValue v, void *smctx) +{ + const DetectStreamSizeData *a = smctx; + if (v.u8[0] == a->du32.mode && v.u8[1] == a->flags && v.u32[2] == a->du32.arg1) + return true; + return false; +} + +static int PrefilterSetupStreamSize(DetectEngineCtx *de_ctx, SigGroupHead *sgh) +{ + return PrefilterSetupPacketHeader(de_ctx, sgh, DETECT_TCPMSS, PrefilterPacketStreamSizeSet, + PrefilterPacketStreamSizeCompare, PrefilterPacketStreamsizeMatch); +} + +static bool PrefilterStreamSizeIsPrefilterable(const Signature *s) +{ + const SigMatch *sm; + for (sm = s->init_data->smlists[DETECT_SM_LIST_MATCH]; sm != NULL; sm = sm->next) { + switch (sm->type) { + case DETECT_STREAM_SIZE: + return true; + } + } + return false; } #ifdef UNITTESTS @@ -330,9 +242,9 @@ static int DetectStreamSizeParseTest01 (void) { int result = 0; DetectStreamSizeData *sd = NULL; - sd = DetectStreamSizeParse(NULL, "server,<,6"); + sd = rs_detect_stream_size_parse("server,<,6"); if (sd != NULL) { - if (sd->flags & STREAM_SIZE_SERVER && sd->mode == DETECTSSIZE_LT && sd->ssize == 6) + if (sd->flags & StreamSizeServer && sd->du32.mode == DETECT_UINT_LT && sd->du32.arg1 == 6) result = 1; DetectStreamSizeFree(NULL, sd); } @@ -349,9 +261,9 @@ static int DetectStreamSizeParseTest02 (void) { int result = 1; DetectStreamSizeData *sd = NULL; - sd = DetectStreamSizeParse(NULL, "invalidoption,<,6"); + sd = rs_detect_stream_size_parse("invalidoption,<,6"); if (sd != NULL) { - printf("expected: NULL got 0x%02X %" PRIu32 ": ",sd->flags, sd->ssize); + printf("expected: NULL got 0x%02X %" PRIu32 ": ", sd->flags, sd->du32.arg1); result = 0; DetectStreamSizeFree(NULL, sd); } @@ -390,24 +302,24 @@ static int DetectStreamSizeParseTest03 (void) memset(&f, 0, sizeof(Flow)); memset(&tcph, 0, sizeof(TCPHdr)); - sd = DetectStreamSizeParse(NULL, "client,>,8"); + sd = rs_detect_stream_size_parse("client,>,8"); if (sd != NULL) { - if (!(sd->flags & STREAM_SIZE_CLIENT)) { + if (!(sd->flags & StreamSizeClient)) { printf("sd->flags not STREAM_SIZE_CLIENT: "); DetectStreamSizeFree(NULL, sd); SCFree(p); return 0; } - if (sd->mode != DETECTSSIZE_GT) { + if (sd->du32.mode != DETECT_UINT_GT) { printf("sd->mode not DETECTSSIZE_GT: "); DetectStreamSizeFree(NULL, sd); SCFree(p); return 0; } - if (sd->ssize != 8) { - printf("sd->ssize is %"PRIu32", not 8: ", sd->ssize); + if (sd->du32.arg1 != 8) { + printf("sd->ssize is %" PRIu32 ", not 8: ", sd->du32.arg1); DetectStreamSizeFree(NULL, sd); SCFree(p); return 0; @@ -466,11 +378,12 @@ static int DetectStreamSizeParseTest04 (void) memset(&f, 0, sizeof(Flow)); memset(&ip4h, 0, sizeof(IPV4Hdr)); - sd = DetectStreamSizeParse(NULL, " client , > , 8 "); + sd = rs_detect_stream_size_parse(" client , > , 8 "); if (sd != NULL) { - if (!(sd->flags & STREAM_SIZE_CLIENT) && sd->mode != DETECTSSIZE_GT && sd->ssize != 8) { - SCFree(p); - return 0; + if (!(sd->flags & StreamSizeClient) && sd->du32.mode != DETECT_UINT_GT && + sd->du32.arg1 != 8) { + SCFree(p); + return 0; } } else { diff --git a/src/detect-stream_size.h b/src/detect-stream_size.h index 32f5c50b19..3a460bf5e2 100644 --- a/src/detect-stream_size.h +++ b/src/detect-stream_size.h @@ -24,24 +24,6 @@ #ifndef _DETECT_STREAM_SIZE_H #define _DETECT_STREAM_SIZE_H -#define DETECTSSIZE_LT 0 -#define DETECTSSIZE_LEQ 1 -#define DETECTSSIZE_EQ 2 -#define DETECTSSIZE_NEQ 3 -#define DETECTSSIZE_GT 4 -#define DETECTSSIZE_GEQ 5 - -#define STREAM_SIZE_SERVER 0x01 -#define STREAM_SIZE_CLIENT 0x02 -#define STREAM_SIZE_BOTH 0x04 -#define STREAM_SIZE_EITHER 0x08 - -typedef struct DetectStreamSizeData_ { - uint8_t flags; - uint8_t mode; - uint32_t ssize; -}DetectStreamSizeData; - void DetectStreamSizeRegister(void); #endif /* _DETECT_STREAM_SIZE_H */