diff --git a/rust/src/mqtt/parser.rs b/rust/src/mqtt/parser.rs index d252290733..9a547286c0 100644 --- a/rust/src/mqtt/parser.rs +++ b/rust/src/mqtt/parser.rs @@ -30,7 +30,7 @@ use nom7::sequence::tuple; use nom7::{Err, IResult, Needed}; use num_traits::FromPrimitive; -#[derive(Debug)] +#[derive(Copy, Clone, Debug)] pub struct FixedHeader { pub message_type: MQTTTypeCode, pub dup_flag: bool, @@ -222,94 +222,108 @@ pub fn parse_connect(i: &[u8]) -> IResult<&[u8], MQTTConnectData> { )) } -pub fn parse_connack(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTConnackData> { - let (i, topic_name_compression_response) = be_u8(i)?; - let (i, return_code) = be_u8(i)?; - let (i, properties) = parse_properties(i, protocol_version == 5)?; - Ok(( - i, - MQTTConnackData { - session_present: (topic_name_compression_response & 1) != 0, - return_code, - properties, - }, - )) +#[inline] +pub fn parse_connack(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTConnackData> +where +{ + move |i: &[u8]| { + let (i, topic_name_compression_response) = be_u8(i)?; + let (i, return_code) = be_u8(i)?; + let (i, properties) = parse_properties(i, protocol_version == 5)?; + Ok(( + i, + MQTTConnackData { + session_present: (topic_name_compression_response & 1) != 0, + return_code, + properties, + }, + )) + } } -pub fn parse_publish( - i: &[u8], protocol_version: u8, has_id: bool, -) -> IResult<&[u8], MQTTPublishData> { - let (i, topic) = parse_mqtt_string(i)?; - let (i, message_id) = cond(has_id, be_u16)(i)?; - let (message, properties) = parse_properties(i, protocol_version == 5)?; - Ok(( - i, - MQTTPublishData { - topic, - message_id, - message: message.to_vec(), - properties, - }, - )) +#[inline] +fn parse_publish( + protocol_version: u8, has_id: bool, +) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTPublishData> +where +{ + move |i: &[u8]| { + let (i, topic) = parse_mqtt_string(i)?; + let (i, message_id) = cond(has_id, be_u16)(i)?; + let (message, properties) = parse_properties(i, protocol_version == 5)?; + Ok(( + i, + MQTTPublishData { + topic, + message_id, + message: message.to_vec(), + properties, + }, + )) + } } #[inline] -fn parse_msgidonly(input: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTMessageIdOnly> { - if protocol_version < 5 { - // before v5 we don't even have to care about reason codes - // and properties, lucky us - return parse_msgidonly_v3(input); - } - let remaining_len = input.len(); - match be_u16(input) { - Ok((rem, message_id)) => { - if remaining_len == 2 { - // from the spec: " The Reason Code and Property Length can be - // omitted if the Reason Code is 0x00 (Success) and there are - // no Properties. In this case the message has a Remaining - // Length of 2." - return Ok(( - rem, - MQTTMessageIdOnly { - message_id, - reason_code: Some(0), - properties: None, - }, - )); - } - match be_u8(rem) { - Ok((rem, reason_code)) => { - // We are checking for 3 because in that case we have a - // header plus reason code, but no properties. - if remaining_len == 3 { - // no properties - return Ok(( - rem, - MQTTMessageIdOnly { - message_id, - reason_code: Some(reason_code), - properties: None, - }, - )); - } - match parse_properties(rem, true) { - Ok((rem, properties)) => { +pub fn parse_msgidonly(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTMessageIdOnly> +where +{ + move |input: &[u8]| { + if protocol_version < 5 { + // before v5 we don't even have to care about reason codes + // and properties, lucky us + return parse_msgidonly_v3(input); + } + let remaining_len = input.len(); + match be_u16(input) { + Ok((rem, message_id)) => { + if remaining_len == 2 { + // from the spec: " The Reason Code and Property Length can be + // omitted if the Reason Code is 0x00 (Success) and there are + // no Properties. In this case the message has a Remaining + // Length of 2." + return Ok(( + rem, + MQTTMessageIdOnly { + message_id, + reason_code: Some(0), + properties: None, + }, + )); + } + match be_u8(rem) { + Ok((rem, reason_code)) => { + // We are checking for 3 because in that case we have a + // header plus reason code, but no properties. + if remaining_len == 3 { + // no properties return Ok(( rem, MQTTMessageIdOnly { message_id, reason_code: Some(reason_code), - properties, + properties: None, }, )); } - Err(e) => return Err(e), + match parse_properties(rem, true) { + Ok((rem, properties)) => { + return Ok(( + rem, + MQTTMessageIdOnly { + message_id, + reason_code: Some(reason_code), + properties, + }, + )); + } + Err(e) => return Err(e), + } } + Err(e) => return Err(e), } - Err(e) => return Err(e), } + Err(e) => return Err(e), } - Err(e) => return Err(e), } } @@ -333,114 +347,140 @@ pub fn parse_subscribe_topic(i: &[u8]) -> IResult<&[u8], MQTTSubscribeTopicData> Ok((i, MQTTSubscribeTopicData { topic_name, qos })) } -pub fn parse_subscribe(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTSubscribeData> { - let (i, message_id) = be_u16(i)?; - let (i, properties) = parse_properties(i, protocol_version == 5)?; - let (i, topics) = many1(complete(parse_subscribe_topic))(i)?; - Ok(( - i, - MQTTSubscribeData { - message_id, - topics, - properties, - }, - )) +#[inline] +pub fn parse_subscribe(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTSubscribeData> +where +{ + move |i: &[u8]| { + let (i, message_id) = be_u16(i)?; + let (i, properties) = parse_properties(i, protocol_version == 5)?; + let (i, topics) = many1(complete(parse_subscribe_topic))(i)?; + Ok(( + i, + MQTTSubscribeData { + message_id, + topics, + properties, + }, + )) + } } -pub fn parse_suback(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTSubackData> { - let (i, message_id) = be_u16(i)?; - let (qoss, properties) = parse_properties(i, protocol_version == 5)?; - Ok(( - i, - MQTTSubackData { - message_id, - qoss: qoss.to_vec(), - properties, - }, - )) +#[inline] +pub fn parse_suback(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTSubackData> +where +{ + move |i: &[u8]| { + let (i, message_id) = be_u16(i)?; + let (qoss, properties) = parse_properties(i, protocol_version == 5)?; + Ok(( + i, + MQTTSubackData { + message_id, + qoss: qoss.to_vec(), + properties, + }, + )) + } } -pub fn parse_unsubscribe(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTUnsubscribeData> { - let (i, message_id) = be_u16(i)?; - let (i, properties) = parse_properties(i, protocol_version == 5)?; - let (i, topics) = many0(complete(parse_mqtt_string))(i)?; - Ok(( - i, - MQTTUnsubscribeData { - message_id, - topics, - properties, - }, - )) +#[inline] +pub fn parse_unsubscribe( + protocol_version: u8, +) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTUnsubscribeData> +where +{ + move |i: &[u8]| { + let (i, message_id) = be_u16(i)?; + let (i, properties) = parse_properties(i, protocol_version == 5)?; + let (i, topics) = many0(complete(parse_mqtt_string))(i)?; + Ok(( + i, + MQTTUnsubscribeData { + message_id, + topics, + properties, + }, + )) + } } -pub fn parse_unsuback(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTUnsubackData> { - let (i, message_id) = be_u16(i)?; - let (i, properties) = parse_properties(i, protocol_version == 5)?; - let (i, reason_codes) = many0(complete(be_u8))(i)?; - Ok(( - i, - MQTTUnsubackData { - message_id, - properties, - reason_codes: Some(reason_codes), - }, - )) +#[inline] +pub fn parse_unsuback(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTUnsubackData> +where +{ + move |i: &[u8]| { + let (i, message_id) = be_u16(i)?; + let (i, properties) = parse_properties(i, protocol_version == 5)?; + let (i, reason_codes) = many0(complete(be_u8))(i)?; + Ok(( + i, + MQTTUnsubackData { + message_id, + properties, + reason_codes: Some(reason_codes), + }, + )) + } } #[inline] fn parse_disconnect( - input: &[u8], remaining_len: usize, protocol_version: u8, -) -> IResult<&[u8], MQTTDisconnectData> { - if protocol_version < 5 { - return Ok(( - input, - MQTTDisconnectData { - reason_code: None, - properties: None, - }, - )); - } - if remaining_len == 0 { - // The Reason Code and Property Length can be omitted if the Reason - // Code is 0x00 (Normal disconnection) and there are no Properties. - // In this case the DISCONNECT has a Remaining Length of 0. - return Ok(( - input, - MQTTDisconnectData { - reason_code: Some(0), - properties: None, - }, - )); - } - match be_u8(input) { - Ok((rem, reason_code)) => { - // We are checking for 1 because in that case we have a - // header plus reason code, but no properties. - if remaining_len == 1 { - // no properties - return Ok(( - rem, - MQTTDisconnectData { - reason_code: Some(0), - properties: None, - }, - )); - } - match parse_properties(rem, true) { - Ok((rem, properties)) => { + remaining_len: usize, protocol_version: u8, +) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTDisconnectData> +where +{ + move |input: &[u8]| { + if protocol_version < 5 { + return Ok(( + input, + MQTTDisconnectData { + reason_code: None, + properties: None, + }, + )); + } + if remaining_len == 0 { + // The Reason Code and Property Length can be omitted if the Reason + // Code is 0x00 (Normal disconnection) and there are no Properties. + // In this case the DISCONNECT has a Remaining Length of 0. + return Ok(( + input, + MQTTDisconnectData { + reason_code: Some(0), + properties: None, + }, + )); + } + match be_u8(input) { + Ok((rem, reason_code)) => { + // We are checking for 1 because in that case we have a + // header plus reason code, but no properties. + if remaining_len == 1 { + // no properties return Ok(( rem, MQTTDisconnectData { - reason_code: Some(reason_code), - properties, + reason_code: Some(0), + properties: None, }, )); } - Err(e) => return Err(e), + match parse_properties(rem, true) { + Ok((rem, properties)) => { + return Ok(( + rem, + MQTTDisconnectData { + reason_code: Some(reason_code), + properties, + }, + )); + } + Err(e) => return Err(e), + } } + Err(e) => return Err(e), } - Err(e) => return Err(e), } } @@ -457,6 +497,151 @@ pub fn parse_auth(i: &[u8]) -> IResult<&[u8], MQTTAuthData> { )) } +#[inline] +fn parse_remaining_message<'a>( + full: &'a [u8], len: usize, skiplen: usize, header: FixedHeader, message_type: MQTTTypeCode, + protocol_version: u8, +) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], MQTTMessage> +where +{ + move |input: &'a [u8]| { + match message_type { + MQTTTypeCode::CONNECT => match parse_connect(input) { + Ok((_rem, conn)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::CONNECT(conn), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + MQTTTypeCode::CONNACK => match parse_connack(protocol_version)(input) { + Ok((_rem, connack)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::CONNACK(connack), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + MQTTTypeCode::PUBLISH => { + match parse_publish(protocol_version, header.qos_level > 0)(input) { + Ok((_rem, publish)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::PUBLISH(publish), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + } + } + MQTTTypeCode::PUBACK + | MQTTTypeCode::PUBREC + | MQTTTypeCode::PUBREL + | MQTTTypeCode::PUBCOMP => match parse_msgidonly(protocol_version)(input) { + Ok((_rem, msgidonly)) => { + let msg = MQTTMessage { + header, + op: match message_type { + MQTTTypeCode::PUBACK => MQTTOperation::PUBACK(msgidonly), + MQTTTypeCode::PUBREC => MQTTOperation::PUBREC(msgidonly), + MQTTTypeCode::PUBREL => MQTTOperation::PUBREL(msgidonly), + MQTTTypeCode::PUBCOMP => MQTTOperation::PUBCOMP(msgidonly), + _ => MQTTOperation::UNASSIGNED, + }, + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + MQTTTypeCode::SUBSCRIBE => match parse_subscribe(protocol_version)(input) { + Ok((_rem, subs)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::SUBSCRIBE(subs), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + MQTTTypeCode::SUBACK => match parse_suback(protocol_version)(input) { + Ok((_rem, suback)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::SUBACK(suback), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + MQTTTypeCode::UNSUBSCRIBE => match parse_unsubscribe(protocol_version)(input) { + Ok((_rem, unsub)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::UNSUBSCRIBE(unsub), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + MQTTTypeCode::UNSUBACK => match parse_unsuback(protocol_version)(input) { + Ok((_rem, unsuback)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::UNSUBACK(unsuback), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + MQTTTypeCode::PINGREQ | MQTTTypeCode::PINGRESP => { + let msg = MQTTMessage { + header, + op: match message_type { + MQTTTypeCode::PINGREQ => MQTTOperation::PINGREQ, + MQTTTypeCode::PINGRESP => MQTTOperation::PINGRESP, + _ => MQTTOperation::UNASSIGNED, + }, + }; + Ok((&full[skiplen + len..], msg)) + } + MQTTTypeCode::DISCONNECT => match parse_disconnect(len, protocol_version)(input) { + Ok((_rem, disco)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::DISCONNECT(disco), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + MQTTTypeCode::AUTH => match parse_auth(input) { + Ok((_rem, auth)) => { + let msg = MQTTMessage { + header, + op: MQTTOperation::AUTH(auth), + }; + Ok((&full[skiplen + len..], msg)) + } + Err(e) => Err(e), + }, + // Unassigned message type code. Unlikely to happen with + // regular traffic, might be an indication for broken or + // crafted MQTT traffic. + _ => { + let msg = MQTTMessage { + header, + op: MQTTOperation::UNASSIGNED, + }; + return Ok((&full[skiplen + len..], msg)); + } + } + } +} + pub fn parse_message( input: &[u8], protocol_version: u8, max_msg_size: usize, ) -> IResult<&[u8], MQTTMessage> { @@ -500,143 +685,22 @@ pub fn parse_message( // Parse the contents of the buffer into a single message. // We reslice the remainder into the portion that we are interested // in, according to the length we just parsed. This helps with the - // complete! parsers, where we would otherwise need to keep track + // complete() parsers, where we would otherwise need to keep track // of the already parsed length. let rem = &fullrem[..len]; - match message_type { - MQTTTypeCode::CONNECT => match parse_connect(rem) { - Ok((_rem, conn)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::CONNECT(conn), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - MQTTTypeCode::CONNACK => match parse_connack(rem, protocol_version) { - Ok((_rem, connack)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::CONNACK(connack), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - MQTTTypeCode::PUBLISH => { - match parse_publish(rem, protocol_version, header.qos_level > 0) { - Ok((_rem, publish)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::PUBLISH(publish), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - } - } - MQTTTypeCode::PUBACK - | MQTTTypeCode::PUBREC - | MQTTTypeCode::PUBREL - | MQTTTypeCode::PUBCOMP => match parse_msgidonly(rem, protocol_version) { - Ok((_rem, msgidonly)) => { - let msg = MQTTMessage { - header, - op: match message_type { - MQTTTypeCode::PUBACK => MQTTOperation::PUBACK(msgidonly), - MQTTTypeCode::PUBREC => MQTTOperation::PUBREC(msgidonly), - MQTTTypeCode::PUBREL => MQTTOperation::PUBREL(msgidonly), - MQTTTypeCode::PUBCOMP => MQTTOperation::PUBCOMP(msgidonly), - _ => MQTTOperation::UNASSIGNED, - }, - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - MQTTTypeCode::SUBSCRIBE => match parse_subscribe(rem, protocol_version) { - Ok((_rem, subs)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::SUBSCRIBE(subs), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - MQTTTypeCode::SUBACK => match parse_suback(rem, protocol_version) { - Ok((_rem, suback)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::SUBACK(suback), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - MQTTTypeCode::UNSUBSCRIBE => match parse_unsubscribe(rem, protocol_version) { - Ok((_rem, unsub)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::UNSUBSCRIBE(unsub), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - MQTTTypeCode::UNSUBACK => match parse_unsuback(rem, protocol_version) { - Ok((_rem, unsuback)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::UNSUBACK(unsuback), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - MQTTTypeCode::PINGREQ | MQTTTypeCode::PINGRESP => { - let msg = MQTTMessage { - header, - op: match message_type { - MQTTTypeCode::PINGREQ => MQTTOperation::PINGREQ, - MQTTTypeCode::PINGRESP => MQTTOperation::PINGRESP, - _ => MQTTOperation::UNASSIGNED, - }, - }; - return Ok((&input[skiplen + len..], msg)); - } - MQTTTypeCode::DISCONNECT => match parse_disconnect(rem, len, protocol_version) { - Ok((_rem, disco)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::DISCONNECT(disco), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - MQTTTypeCode::AUTH => match parse_auth(rem) { - Ok((_rem, auth)) => { - let msg = MQTTMessage { - header, - op: MQTTOperation::AUTH(auth), - }; - Ok((&input[skiplen + len..], msg)) - } - Err(e) => Err(e), - }, - // Unassigned message type code. Unlikely to happen with - // regular traffic, might be an indication for broken or - // crafted MQTT traffic. - _ => { - let msg = MQTTMessage { - header, - op: MQTTOperation::UNASSIGNED, - }; - return Ok((&rem[len..], msg)); - } - } + + // Parse remaining message in buffer. We use complete() to ensure + // we do not request additional content in case of incomplete + // parsing, but raise an error instead as we should have all the + // data asked for in the header. + return complete(parse_remaining_message( + input, + len, + skiplen, + header, + message_type, + protocol_version, + ))(rem); } Err(err) => { return Err(err);