mqtt: ensure we do not request extra data after buffering

This addresses Redmine bug #5018 by ensuring that the parser
never requests additional data via the Incomplete error, but to
raise an actual parse error, since it is supposed to have all
the data as specified by the message length in the header already.
pull/7223/head
Sascha Steinbiss 4 years ago committed by Victor Julien
parent e3180e3248
commit 5618273ef4

@ -30,7 +30,7 @@ use nom7::sequence::tuple;
use nom7::{Err, IResult, Needed}; use nom7::{Err, IResult, Needed};
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
#[derive(Debug)] #[derive(Copy, Clone, Debug)]
pub struct FixedHeader { pub struct FixedHeader {
pub message_type: MQTTTypeCode, pub message_type: MQTTTypeCode,
pub dup_flag: bool, pub dup_flag: bool,
@ -222,7 +222,11 @@ pub fn parse_connect(i: &[u8]) -> IResult<&[u8], MQTTConnectData> {
)) ))
} }
pub fn parse_connack(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTConnackData> { #[inline]
pub fn parse_connack(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTConnackData>
where
{
move |i: &[u8]| {
let (i, topic_name_compression_response) = be_u8(i)?; let (i, topic_name_compression_response) = be_u8(i)?;
let (i, return_code) = be_u8(i)?; let (i, return_code) = be_u8(i)?;
let (i, properties) = parse_properties(i, protocol_version == 5)?; let (i, properties) = parse_properties(i, protocol_version == 5)?;
@ -234,11 +238,16 @@ pub fn parse_connack(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTConna
properties, properties,
}, },
)) ))
}
} }
pub fn parse_publish( #[inline]
i: &[u8], protocol_version: u8, has_id: bool, fn parse_publish(
) -> IResult<&[u8], MQTTPublishData> { protocol_version: u8, has_id: bool,
) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTPublishData>
where
{
move |i: &[u8]| {
let (i, topic) = parse_mqtt_string(i)?; let (i, topic) = parse_mqtt_string(i)?;
let (i, message_id) = cond(has_id, be_u16)(i)?; let (i, message_id) = cond(has_id, be_u16)(i)?;
let (message, properties) = parse_properties(i, protocol_version == 5)?; let (message, properties) = parse_properties(i, protocol_version == 5)?;
@ -251,10 +260,14 @@ pub fn parse_publish(
properties, properties,
}, },
)) ))
}
} }
#[inline] #[inline]
fn parse_msgidonly(input: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTMessageIdOnly> { pub fn parse_msgidonly(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTMessageIdOnly>
where
{
move |input: &[u8]| {
if protocol_version < 5 { if protocol_version < 5 {
// before v5 we don't even have to care about reason codes // before v5 we don't even have to care about reason codes
// and properties, lucky us // and properties, lucky us
@ -311,6 +324,7 @@ fn parse_msgidonly(input: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTMes
} }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
}
} }
#[inline] #[inline]
@ -333,7 +347,11 @@ pub fn parse_subscribe_topic(i: &[u8]) -> IResult<&[u8], MQTTSubscribeTopicData>
Ok((i, MQTTSubscribeTopicData { topic_name, qos })) Ok((i, MQTTSubscribeTopicData { topic_name, qos }))
} }
pub fn parse_subscribe(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTSubscribeData> { #[inline]
pub fn parse_subscribe(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTSubscribeData>
where
{
move |i: &[u8]| {
let (i, message_id) = be_u16(i)?; let (i, message_id) = be_u16(i)?;
let (i, properties) = parse_properties(i, protocol_version == 5)?; let (i, properties) = parse_properties(i, protocol_version == 5)?;
let (i, topics) = many1(complete(parse_subscribe_topic))(i)?; let (i, topics) = many1(complete(parse_subscribe_topic))(i)?;
@ -345,9 +363,14 @@ pub fn parse_subscribe(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTSub
properties, properties,
}, },
)) ))
}
} }
pub fn parse_suback(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTSubackData> { #[inline]
pub fn parse_suback(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTSubackData>
where
{
move |i: &[u8]| {
let (i, message_id) = be_u16(i)?; let (i, message_id) = be_u16(i)?;
let (qoss, properties) = parse_properties(i, protocol_version == 5)?; let (qoss, properties) = parse_properties(i, protocol_version == 5)?;
Ok(( Ok((
@ -358,9 +381,16 @@ pub fn parse_suback(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTSuback
properties, properties,
}, },
)) ))
}
} }
pub fn parse_unsubscribe(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTUnsubscribeData> { #[inline]
pub fn parse_unsubscribe(
protocol_version: u8,
) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTUnsubscribeData>
where
{
move |i: &[u8]| {
let (i, message_id) = be_u16(i)?; let (i, message_id) = be_u16(i)?;
let (i, properties) = parse_properties(i, protocol_version == 5)?; let (i, properties) = parse_properties(i, protocol_version == 5)?;
let (i, topics) = many0(complete(parse_mqtt_string))(i)?; let (i, topics) = many0(complete(parse_mqtt_string))(i)?;
@ -372,9 +402,14 @@ pub fn parse_unsubscribe(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTU
properties, properties,
}, },
)) ))
}
} }
pub fn parse_unsuback(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTUnsubackData> { #[inline]
pub fn parse_unsuback(protocol_version: u8) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTUnsubackData>
where
{
move |i: &[u8]| {
let (i, message_id) = be_u16(i)?; let (i, message_id) = be_u16(i)?;
let (i, properties) = parse_properties(i, protocol_version == 5)?; let (i, properties) = parse_properties(i, protocol_version == 5)?;
let (i, reason_codes) = many0(complete(be_u8))(i)?; let (i, reason_codes) = many0(complete(be_u8))(i)?;
@ -386,12 +421,16 @@ pub fn parse_unsuback(i: &[u8], protocol_version: u8) -> IResult<&[u8], MQTTUnsu
reason_codes: Some(reason_codes), reason_codes: Some(reason_codes),
}, },
)) ))
}
} }
#[inline] #[inline]
fn parse_disconnect( fn parse_disconnect(
input: &[u8], remaining_len: usize, protocol_version: u8, remaining_len: usize, protocol_version: u8,
) -> IResult<&[u8], MQTTDisconnectData> { ) -> impl Fn(&[u8]) -> IResult<&[u8], MQTTDisconnectData>
where
{
move |input: &[u8]| {
if protocol_version < 5 { if protocol_version < 5 {
return Ok(( return Ok((
input, input,
@ -442,6 +481,7 @@ fn parse_disconnect(
} }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
}
} }
#[inline] #[inline]
@ -457,81 +497,43 @@ pub fn parse_auth(i: &[u8]) -> IResult<&[u8], MQTTAuthData> {
)) ))
} }
pub fn parse_message( #[inline]
input: &[u8], protocol_version: u8, max_msg_size: usize, fn parse_remaining_message<'a>(
) -> IResult<&[u8], MQTTMessage> { full: &'a [u8], len: usize, skiplen: usize, header: FixedHeader, message_type: MQTTTypeCode,
// Parse the fixed header first. This is identical across versions and can protocol_version: u8,
// be between 2 and 5 bytes long. ) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], MQTTMessage>
match parse_fixed_header(input) { where
Ok((fullrem, header)) => { {
let len = header.remaining_length as usize; move |input: &'a [u8]| {
// This is the length of the fixed header that we need to skip
// before returning the remainder. It is the sum of the length
// of the flag byte (1) and the length of the message length
// varint.
let skiplen = input.len() - fullrem.len();
let message_type = header.message_type;
// If the remaining length (message length) exceeds the specified
// limit, we return a special truncation message type, containing
// no parsed metadata but just the skipped length and the message
// type.
if len > max_msg_size {
let msg = MQTTMessage {
header,
op: MQTTOperation::TRUNCATED(MQTTTruncatedData {
original_message_type: message_type,
skipped_length: len + skiplen,
}),
};
// In this case we return the full input buffer, since this is
// what the skipped_length value also refers to: header _and_
// remaining length.
return Ok((input, msg));
}
// We have not exceeded the maximum length limit, but still do not
// have enough data in the input buffer to handle the full
// message. Signal this by returning an Incomplete IResult value.
if fullrem.len() < len {
return Err(Err::Incomplete(Needed::new(len - fullrem.len())));
}
// Parse the contents of the buffer into a single message.
// We reslice the remainder into the portion that we are interested
// in, according to the length we just parsed. This helps with the
// complete! parsers, where we would otherwise need to keep track
// of the already parsed length.
let rem = &fullrem[..len];
match message_type { match message_type {
MQTTTypeCode::CONNECT => match parse_connect(rem) { MQTTTypeCode::CONNECT => match parse_connect(input) {
Ok((_rem, conn)) => { Ok((_rem, conn)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::CONNECT(conn), op: MQTTOperation::CONNECT(conn),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
MQTTTypeCode::CONNACK => match parse_connack(rem, protocol_version) { MQTTTypeCode::CONNACK => match parse_connack(protocol_version)(input) {
Ok((_rem, connack)) => { Ok((_rem, connack)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::CONNACK(connack), op: MQTTOperation::CONNACK(connack),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
MQTTTypeCode::PUBLISH => { MQTTTypeCode::PUBLISH => {
match parse_publish(rem, protocol_version, header.qos_level > 0) { match parse_publish(protocol_version, header.qos_level > 0)(input) {
Ok((_rem, publish)) => { Ok((_rem, publish)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::PUBLISH(publish), op: MQTTOperation::PUBLISH(publish),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
} }
@ -539,7 +541,7 @@ pub fn parse_message(
MQTTTypeCode::PUBACK MQTTTypeCode::PUBACK
| MQTTTypeCode::PUBREC | MQTTTypeCode::PUBREC
| MQTTTypeCode::PUBREL | MQTTTypeCode::PUBREL
| MQTTTypeCode::PUBCOMP => match parse_msgidonly(rem, protocol_version) { | MQTTTypeCode::PUBCOMP => match parse_msgidonly(protocol_version)(input) {
Ok((_rem, msgidonly)) => { Ok((_rem, msgidonly)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
@ -551,47 +553,47 @@ pub fn parse_message(
_ => MQTTOperation::UNASSIGNED, _ => MQTTOperation::UNASSIGNED,
}, },
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
MQTTTypeCode::SUBSCRIBE => match parse_subscribe(rem, protocol_version) { MQTTTypeCode::SUBSCRIBE => match parse_subscribe(protocol_version)(input) {
Ok((_rem, subs)) => { Ok((_rem, subs)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::SUBSCRIBE(subs), op: MQTTOperation::SUBSCRIBE(subs),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
MQTTTypeCode::SUBACK => match parse_suback(rem, protocol_version) { MQTTTypeCode::SUBACK => match parse_suback(protocol_version)(input) {
Ok((_rem, suback)) => { Ok((_rem, suback)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::SUBACK(suback), op: MQTTOperation::SUBACK(suback),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
MQTTTypeCode::UNSUBSCRIBE => match parse_unsubscribe(rem, protocol_version) { MQTTTypeCode::UNSUBSCRIBE => match parse_unsubscribe(protocol_version)(input) {
Ok((_rem, unsub)) => { Ok((_rem, unsub)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::UNSUBSCRIBE(unsub), op: MQTTOperation::UNSUBSCRIBE(unsub),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
MQTTTypeCode::UNSUBACK => match parse_unsuback(rem, protocol_version) { MQTTTypeCode::UNSUBACK => match parse_unsuback(protocol_version)(input) {
Ok((_rem, unsuback)) => { Ok((_rem, unsuback)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::UNSUBACK(unsuback), op: MQTTOperation::UNSUBACK(unsuback),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
@ -604,25 +606,25 @@ pub fn parse_message(
_ => MQTTOperation::UNASSIGNED, _ => MQTTOperation::UNASSIGNED,
}, },
}; };
return Ok((&input[skiplen + len..], msg)); Ok((&full[skiplen + len..], msg))
} }
MQTTTypeCode::DISCONNECT => match parse_disconnect(rem, len, protocol_version) { MQTTTypeCode::DISCONNECT => match parse_disconnect(len, protocol_version)(input) {
Ok((_rem, disco)) => { Ok((_rem, disco)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::DISCONNECT(disco), op: MQTTOperation::DISCONNECT(disco),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
MQTTTypeCode::AUTH => match parse_auth(rem) { MQTTTypeCode::AUTH => match parse_auth(input) {
Ok((_rem, auth)) => { Ok((_rem, auth)) => {
let msg = MQTTMessage { let msg = MQTTMessage {
header, header,
op: MQTTOperation::AUTH(auth), op: MQTTOperation::AUTH(auth),
}; };
Ok((&input[skiplen + len..], msg)) Ok((&full[skiplen + len..], msg))
} }
Err(e) => Err(e), Err(e) => Err(e),
}, },
@ -634,9 +636,71 @@ pub fn parse_message(
header, header,
op: MQTTOperation::UNASSIGNED, op: MQTTOperation::UNASSIGNED,
}; };
return Ok((&rem[len..], msg)); return Ok((&full[skiplen + len..], msg));
}
}
}
}
pub fn parse_message(
input: &[u8], protocol_version: u8, max_msg_size: usize,
) -> IResult<&[u8], MQTTMessage> {
// Parse the fixed header first. This is identical across versions and can
// be between 2 and 5 bytes long.
match parse_fixed_header(input) {
Ok((fullrem, header)) => {
let len = header.remaining_length as usize;
// This is the length of the fixed header that we need to skip
// before returning the remainder. It is the sum of the length
// of the flag byte (1) and the length of the message length
// varint.
let skiplen = input.len() - fullrem.len();
let message_type = header.message_type;
// If the remaining length (message length) exceeds the specified
// limit, we return a special truncation message type, containing
// no parsed metadata but just the skipped length and the message
// type.
if len > max_msg_size {
let msg = MQTTMessage {
header,
op: MQTTOperation::TRUNCATED(MQTTTruncatedData {
original_message_type: message_type,
skipped_length: len + skiplen,
}),
};
// In this case we return the full input buffer, since this is
// what the skipped_length value also refers to: header _and_
// remaining length.
return Ok((input, msg));
} }
// We have not exceeded the maximum length limit, but still do not
// have enough data in the input buffer to handle the full
// message. Signal this by returning an Incomplete IResult value.
if fullrem.len() < len {
return Err(Err::Incomplete(Needed::new(len - fullrem.len())));
} }
// Parse the contents of the buffer into a single message.
// We reslice the remainder into the portion that we are interested
// in, according to the length we just parsed. This helps with the
// complete() parsers, where we would otherwise need to keep track
// of the already parsed length.
let rem = &fullrem[..len];
// Parse remaining message in buffer. We use complete() to ensure
// we do not request additional content in case of incomplete
// parsing, but raise an error instead as we should have all the
// data asked for in the header.
return complete(parse_remaining_message(
input,
len,
skiplen,
header,
message_type,
protocol_version,
))(rem);
} }
Err(err) => { Err(err) => {
return Err(err); return Err(err);

Loading…
Cancel
Save