diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..c77aac4 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,27 @@ +# 0.2 (unreleased) + +This is a fairly large release with API fixes and improvements, bug fixes, and much better test +coverage and documentation. + +## API changes + +* We now return a dedicated error type instead of abusing `std::io::Error`. +* `PacketIdentifier` was renamed to `Pid`. It now avoids the illegal value 0, wraps around automatically, and can be hashed. +* `Publish.qos` and `Publish.pid` have been merged together, avoiding accidental illegal combinations. +* `Connect.password` and `Connect.will.payload` can now contain binary data. +* The `Protocol` enum doesn't carry extra data anymore. +* All public structs/enum/functions are now (re)exported from the crate root, and the rest has been made private. +* The letter-casing of packet types is more consistent. +* Packet subtypes can be converted to `Packet` using `.into()`. + +## Other changes + +* Much improved documentation. See it with `cargo doc --open`. +* More thorough unittesting, including exhaustive and random value ranges testing. +* Lots of corner-case bugfixes, particularly when decoding partial or corrupted data. +* The minimum rust version is now 1.32. +* Raised `mqttrs`'s bus factor to 2 ;) + +# 0.1.4 (2019-09-16) + +* Fix issue #8: Decoding an incomplete packet still consumes bytes from the buffer. diff --git a/Cargo.toml b/Cargo.toml index 8337346..fedacc8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,8 @@ -cargo-features = ["edition"] - [package] name = "mqttrs" version = "0.1.4" -authors = ["00imvj00 "] +authors = ["00imvj00 ", + "Vincent de Phily "] edition = "2018" description = "mqttrs is encoding & decoding library for mqtt protocol, it can work with both sync as well as async apps" homepage = "https://github.com/00imvj00/mqttrs" @@ -11,7 +10,6 @@ repository = "https://github.com/00imvj00/mqttrs" keywords = ["mqtt", "encoding", "decoding", "async", "async-mqtt"] license = "Apache-2.0" - [dependencies] bytes = "0.4" diff --git a/README.md b/README.md index b677a9d..9dee2e6 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,45 @@ -# Rust Mqtt Encoding & Decoding - -### What is Mqtt? -MQTT is an ISO standard publish-subscribe-based messaging protocol. It works on top of the TCP/IP protocol. - -### What is Rust? -Rust is a multi-paradigm systems programming language focused on safety, especially safe concurrency. Rust is syntactically similar to C++, but is designed to provide better memory safety while maintaining high performance. - -### What is mqttrs? - -It is library which can be used in any rust projects where you need to transform valid mqtt bytes buffer to Mqtt types and vice versa. - -In short it is encoding/decoding library which you can use it in sync as well as async environment. - -The way it works is, It will take byte buffer as input and then will try to read the header of the mqtt packet, if the packet is not completely received as it happens in async networking, the library function will return `None` and will not remove any bytes from buffer. - -Once, the whole mqtt packet is received, mqttrs will convert the bytes into appropriate mqtt packet type and return as well as remove all bytes from the beginning which belongs to already received packet. - -So, in this way, this library can be used for sync tcp streams as well as async streams like tokio tcp streams. - +# Rust Mqtt Encoding & Decoding [![Crates.io](https://img.shields.io/crates/l/mqttrs)](LICENSE) [![Docs.rs](https://docs.rs/mqttrs/badge.svg)](https://docs.rs/mqttrs/*/mqttrs/) + +`Mqttrs` is a [Rust](https://www.rust-lang.org/) crate (library) to write [MQTT +protocol](https://mqtt.org/) clients and servers. + +It is a codec-only library with [very few dependencies](Cargo.toml) and a [straightworward and +composable API](https://docs.rs/mqttrs/*/mqttrs/), usable with rust's standard library or with async +frameworks like [tokio](https://tokio.rs/). + +`Mqttrs` currently requires [Rust >= 1.32](https://www.rust-lang.org/learn/get-started) and supports +[MQTT 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html). Support for [MQTT +5](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html) is planned for a future version. + +## Usage + +Add `mqttrs = "0.2"` to your `Cargo.toml`. + +```rust +use mqttrs::*; +use bytes::BytesMut; + +// Allocate write buffer. +let mut buf = BytesMut::with_capacity(1024); + +// Encode an MQTT Connect packet. +let pkt = Packet::Connect(Connect { protocol: Protocol::MQTT311, + keep_alive: 30, + client_id: "doc_client".into(), + clean_session: true, + last_will: None, + username: None, + password: None }); +assert!(encode(&pkt, &mut buf).is_ok()); +assert_eq!(&buf[14..], "doc_client".as_bytes()); +let mut encoded = buf.clone(); + +// Decode one packet. The buffer will advance to the next packet. +assert_eq!(Ok(Some(pkt)), decode(&mut buf)); + +// Example decode failures. +let mut incomplete = encoded.split_to(10); +assert_eq!(Ok(None), decode(&mut incomplete)); +let mut garbage = BytesMut::from(vec![0u8,0,0,0]); +assert_eq!(Err(Error::InvalidHeader), decode(&mut garbage)); +``` \ No newline at end of file diff --git a/src/codec_test.rs b/src/codec_test.rs index 0d8cf31..526a26a 100644 --- a/src/codec_test.rs +++ b/src/codec_test.rs @@ -1,4 +1,5 @@ use crate::*; +use bytes::BufMut; use bytes::BytesMut; use proptest::{bool, collection::vec, num::*, prelude::*}; @@ -13,6 +14,11 @@ prop_compose! { QoS::from_u8(qos).unwrap() } } +prop_compose! { + fn stg_pid()(pid in 1..std::u16::MAX) -> Pid { + Pid::try_from(pid).unwrap() + } +} prop_compose! { fn stg_subtopic()(topic_path in stg_topic(), qos in stg_qos()) -> SubscribeTopic { SubscribeTopic { topic_path, qos } @@ -40,13 +46,13 @@ prop_compose! { clean_session in bool::ANY, username in stg_optstr(), password in stg_optstr()) -> Packet { - Packet::Connect(Connect { protocol: Protocol::MQTT(4), + Packet::Connect(Connect { protocol: Protocol::MQTT311, keep_alive, client_id, clean_session, last_will: None, username, - password }) + password: password.map(|p| p.as_bytes().to_vec()) }) } } prop_compose! { @@ -57,67 +63,70 @@ prop_compose! { } prop_compose! { fn stg_publish()(dup in bool::ANY, - qos in 0u8..3, - pid in u16::ANY, + qos in stg_qos(), + pid in stg_pid(), retain in bool::ANY, topic_name in stg_topic(), payload in vec(0u8..255u8, 1..300)) -> Packet { Packet::Publish(Publish{dup, - qos: QoS::from_u8(qos).unwrap(), + qospid: match qos { + QoS::AtMostOnce => QosPid::AtMostOnce, + QoS::AtLeastOnce => QosPid::AtLeastOnce(pid), + QoS::ExactlyOnce => QosPid::ExactlyOnce(pid), + }, retain, topic_name, - pid: if qos == 0 { None } else { Some(PacketIdentifier(pid)) }, payload}) } } prop_compose! { - fn stg_puback()(pid in u16::ANY) -> Packet { - Packet::Puback(PacketIdentifier(pid)) + fn stg_puback()(pid in stg_pid()) -> Packet { + Packet::Puback(pid) } } prop_compose! { - fn stg_pubrec()(pid in u16::ANY) -> Packet { - Packet::Pubrec(PacketIdentifier(pid)) + fn stg_pubrec()(pid in stg_pid()) -> Packet { + Packet::Pubrec(pid) } } prop_compose! { - fn stg_pubrel()(pid in u16::ANY) -> Packet { - Packet::Puback(PacketIdentifier(pid)) + fn stg_pubrel()(pid in stg_pid()) -> Packet { + Packet::Pubrel(pid) } } prop_compose! { - fn stg_pubcomp()(pid in u16::ANY) -> Packet { - Packet::PubComp(PacketIdentifier(pid)) + fn stg_pubcomp()(pid in stg_pid()) -> Packet { + Packet::Pubcomp(pid) } } prop_compose! { - fn stg_subscribe()(pid in u16::ANY, topics in vec(stg_subtopic(), 0..20)) -> Packet { - Packet::Subscribe(Subscribe{pid: PacketIdentifier(pid), topics}) + fn stg_subscribe()(pid in stg_pid(), topics in vec(stg_subtopic(), 0..20)) -> Packet { + Packet::Subscribe(Subscribe{pid: pid, topics}) } } prop_compose! { - fn stg_suback()(pid in u16::ANY, return_codes in vec(stg_subretcode(), 0..300)) -> Packet { - Packet::SubAck(Suback{pid: PacketIdentifier(pid), return_codes}) + fn stg_suback()(pid in stg_pid(), return_codes in vec(stg_subretcode(), 0..300)) -> Packet { + Packet::Suback(Suback{pid: pid, return_codes}) } } prop_compose! { - fn stg_unsubscribe()(pid in u16::ANY, topics in vec(stg_topic(), 0..20)) -> Packet { - Packet::UnSubscribe(Unsubscribe{pid:PacketIdentifier(pid), topics}) + fn stg_unsubscribe()(pid in stg_pid(), topics in vec(stg_topic(), 0..20)) -> Packet { + Packet::Unsubscribe(Unsubscribe{pid:pid, topics}) } } prop_compose! { - fn stg_unsuback()(pid in u16::ANY) -> Packet { - Packet::UnSubAck(PacketIdentifier(pid)) + fn stg_unsuback()(pid in stg_pid()) -> Packet { + Packet::Unsuback(pid) } } prop_compose! { fn stg_pingreq()(_ in bool::ANY) -> Packet { - Packet::PingReq + Packet::Pingreq } } prop_compose! { fn stg_pingresp()(_ in bool::ANY) -> Packet { - Packet::PingResp + Packet::Pingresp } } prop_compose! { @@ -136,14 +145,14 @@ macro_rules! impl_proptests { fn $name(pkt in $stg()) { // Encode the packet let mut buf = BytesMut::with_capacity(10240); - let res = encoder::encode(&pkt.clone(), &mut buf); + let res = encode(&pkt, &mut buf); prop_assert!(res.is_ok(), "encode({:?}) -> {:?}", pkt, res); prop_assert!(buf.len() >= 2, "buffer too small: {:?}", buf); //PING is 2 bytes prop_assert!(buf[0] >> 4 > 0 && buf[0] >> 4 < 16, "bad packet type {:?}", buf); // Check that decoding returns the original - let mut encoded = buf.clone(); - let decoded = decoder::decode(&mut buf); + let encoded = buf.clone(); + let decoded = decode(&mut buf); let ok = match &decoded { Ok(Some(p)) if *p == pkt => true, _other => false, @@ -152,9 +161,23 @@ macro_rules! impl_proptests { prop_assert!(buf.is_empty(), "Buffer not empty: {:?}", buf); // Check that decoding a partial packet returns Ok(None) - encoded.split_off(encoded.len() - 1); - let decoded = decoder::decode(&mut encoded).unwrap(); + let decoded = decode(&mut encoded.clone().split_off(encoded.len() - 1)).unwrap(); prop_assert!(decoded.is_none(), "partial decode {:?} -> {:?}", encoded, decoded); + + // Check that encoding into a small buffer fails cleanly + buf.clear(); + buf.split_off(encoded.len()); + prop_assert!(encoded.len() == buf.remaining_mut() && buf.is_empty(), + "Wrong buffer init1 {}/{}/{}", encoded.len(), buf.remaining_mut(), buf.is_empty()); + prop_assert!(encode(&pkt, &mut buf).is_ok(), "exact buffer capacity {}", buf.capacity()); + for l in (0..encoded.len()).rev() { + buf.clear(); + buf.split_to(1); + prop_assert!(l == buf.remaining_mut() && buf.is_empty(), + "Wrong buffer init2 {}/{}/{}", l, buf.remaining_mut(), buf.is_empty()); + prop_assert_eq!(Err(Error::WriteZero), encode(&pkt, &mut buf), + "small buffer capacity {}/{}", buf.capacity(), encoded.len()); + } } } }; diff --git a/src/connect.rs b/src/connect.rs index dd6983b..75cb7f7 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,7 +1,100 @@ -use crate::{encoder, utils, ConnectReturnCode, LastWill, Protocol, QoS}; +use crate::{decoder::*, encoder::*, *}; use bytes::{Buf, BufMut, BytesMut, IntoBuf}; -use std::io; +/// Protocol version. +/// +/// Sent in [`Connect`] packet. +/// +/// [`Connect`]: struct.Connect.html +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Protocol { + /// [MQTT 3.1.1] is the most commonly implemented version. [MQTT 5] isn't yet supported my by + /// `mqttrs`. + /// + /// [MQTT 3.1.1]: https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html + /// [MQTT 5]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html + MQTT311, + /// MQIsdp, aka SCADA are pre-standardisation names of MQTT. It should mostly conform to MQTT + /// 3.1.1, but you should watch out for implementation discrepancies. `Mqttrs` handles it like + /// standard MQTT 3.1.1. + MQIsdp, +} +impl Protocol { + pub(crate) fn new(name: &str, level: u8) -> Result { + match (name, level) { + ("MQIsdp", 3) => Ok(Protocol::MQIsdp), + ("MQTT", 4) => Ok(Protocol::MQTT311), + _ => Err(Error::InvalidProtocol(name.into(), level)), + } + } + pub(crate) fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), Error> { + match self { + Protocol::MQTT311 => { + Ok(buffer.put_slice(&[0u8, 4, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 4])) + } + Protocol::MQIsdp => Ok(buffer.put_slice(&[ + 0u8, 4, 'M' as u8, 'Q' as u8, 'i' as u8, 's' as u8, 'd' as u8, 'p' as u8, 4, + ])), + } + } +} + +/// Message that the server should publish when the client disconnects. +/// +/// Sent by the client in the [Connect] packet. [MQTT 3.1.3.3]. +/// +/// [Connect]: struct.Connect.html +/// [MQTT 3.1.3.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031 +#[derive(Debug, Clone, PartialEq)] +pub struct LastWill { + pub topic: String, + pub message: Vec, + pub qos: QoS, + pub retain: bool, +} + +/// Sucess value of a [Connack] packet. +/// +/// See [MQTT 3.2.2.3] for interpretations. +/// +/// [Connack]: struct.Connack.html +/// [MQTT 3.2.2.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718035 +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ConnectReturnCode { + Accepted, + RefusedProtocolVersion, + RefusedIdentifierRejected, + ServerUnavailable, + BadUsernamePassword, + NotAuthorized, +} +impl ConnectReturnCode { + fn to_u8(&self) -> u8 { + match *self { + ConnectReturnCode::Accepted => 0, + ConnectReturnCode::RefusedProtocolVersion => 1, + ConnectReturnCode::RefusedIdentifierRejected => 2, + ConnectReturnCode::ServerUnavailable => 3, + ConnectReturnCode::BadUsernamePassword => 4, + ConnectReturnCode::NotAuthorized => 5, + } + } + pub(crate) fn from_u8(byte: u8) -> Result { + match byte { + 0 => Ok(ConnectReturnCode::Accepted), + 1 => Ok(ConnectReturnCode::RefusedProtocolVersion), + 2 => Ok(ConnectReturnCode::RefusedIdentifierRejected), + 3 => Ok(ConnectReturnCode::ServerUnavailable), + 4 => Ok(ConnectReturnCode::BadUsernamePassword), + 5 => Ok(ConnectReturnCode::NotAuthorized), + n => Err(Error::InvalidConnectReturnCode(n)), + } + } +} + +/// Connect packet ([MQTT 3.1]). +/// +/// [MQTT 3.1]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718028 #[derive(Debug, Clone, PartialEq)] pub struct Connect { pub protocol: Protocol, @@ -10,9 +103,12 @@ pub struct Connect { pub clean_session: bool, pub last_will: Option, pub username: Option, - pub password: Option, + pub password: Option>, } +/// Connack packet ([MQTT 3.2]). +/// +/// [MQTT 3.2]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033 #[derive(Debug, Clone, Copy, PartialEq)] pub struct Connack { pub session_present: bool, @@ -20,19 +116,19 @@ pub struct Connack { } impl Connect { - pub fn from_buffer(buffer: &mut BytesMut) -> Result { - let protocol_name = utils::read_string(buffer); + pub(crate) fn from_buffer(buffer: &mut BytesMut) -> Result { + let protocol_name = read_string(buffer)?; let protocol_level = buffer.split_to(1).into_buf().get_u8(); let protocol = Protocol::new(&protocol_name, protocol_level).unwrap(); let connect_flags = buffer.split_to(1).into_buf().get_u8(); let keep_alive = buffer.split_to(2).into_buf().get_u16_be(); - let client_id = utils::read_string(buffer); + let client_id = read_string(buffer)?; let last_will = if connect_flags & 0b100 != 0 { - let will_topic = utils::read_string(buffer); - let will_message = utils::read_string(buffer); + let will_topic = read_string(buffer)?; + let will_message = read_bytes(buffer)?; let will_qod = QoS::from_u8((connect_flags & 0b11000) >> 3).unwrap(); Some(LastWill { topic: will_topic, @@ -45,13 +141,13 @@ impl Connect { }; let username = if connect_flags & 0b10000000 != 0 { - Some(utils::read_string(buffer)) + Some(read_string(buffer)?) } else { None }; let password = if connect_flags & 0b01000000 != 0 { - Some(utils::read_string(buffer)) + Some(read_bytes(buffer)?) } else { None }; @@ -68,7 +164,7 @@ impl Connect { clean_session, }) } - pub fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), io::Error> { + pub(crate) fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), Error> { let header_u8: u8 = 0b00010000; let mut length: usize = 6 + 1 + 1; //NOTE: protocol_name(6) + protocol_level(1) + flags(1); let mut connect_flags: u8 = 0b00000000; @@ -97,26 +193,26 @@ impl Connect { length += last_will.topic.len(); length += 4; }; + check_remaining(buffer, length + 1)?; //NOTE: putting data into buffer. buffer.put(header_u8); - encoder::write_length(length, buffer)?; - encoder::write_string(self.protocol.name(), buffer)?; - buffer.put(self.protocol.level()); + write_length(length, buffer)?; + self.protocol.to_buffer(buffer)?; buffer.put(connect_flags); buffer.put_u16_be(self.keep_alive); - encoder::write_string(self.client_id.as_ref(), buffer)?; + write_string(self.client_id.as_ref(), buffer)?; if let Some(last_will) = &self.last_will { - encoder::write_string(last_will.topic.as_ref(), buffer)?; - encoder::write_string(last_will.message.as_ref(), buffer)?; + write_string(last_will.topic.as_ref(), buffer)?; + write_bytes(&last_will.message, buffer)?; }; if let Some(username) = &self.username { - encoder::write_string(username.as_ref(), buffer)?; + write_string(username.as_ref(), buffer)?; }; if let Some(password) = &self.password { - encoder::write_string(password.as_ref(), buffer)?; + write_bytes(password, buffer)?; }; //NOTE: END Ok(()) @@ -124,7 +220,7 @@ impl Connect { } impl Connack { - pub fn from_buffer(buffer: &mut BytesMut) -> Result { + pub(crate) fn from_buffer(buffer: &mut BytesMut) -> Result { let flags = buffer.split_to(1).into_buf().get_u8(); let return_code = buffer.split_to(1).into_buf().get_u8(); Ok(Connack { @@ -132,7 +228,8 @@ impl Connack { code: ConnectReturnCode::from_u8(return_code)?, }) } - pub fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), io::Error> { + pub(crate) fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), Error> { + check_remaining(buffer, 4)?; let header_u8 = 0b00100000 as u8; let length = 2 as u8; let mut flags = 0b00000000 as u8; diff --git a/src/decoder.rs b/src/decoder.rs index cd8e3c4..d8acbc5 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -1,98 +1,241 @@ -use crate::MULTIPLIER; use crate::*; use bytes::{Buf, BytesMut, IntoBuf}; -use std::io; -#[allow(dead_code)] -pub fn decode(buffer: &mut BytesMut) -> Result, io::Error> { - if let Some((header, header_size)) = read_header(buffer) { - if buffer.len() >= header.len() + header_size { - //NOTE: Check if buffer has, header bytes + remaining length bytes in buffer. - buffer.split_to(header_size); //NOTE: Remove header bytes from buffer. - let p = read_packet(header, buffer)?; //NOTE: Read remaining packet. - Ok(Some(p)) - } else { - Ok(None) - } +/// Decode bytes from a [BytesMut] buffer as a [Packet] enum. +/// +/// ``` +/// # use mqttrs::*; +/// # use bytes::*; +/// // Fill a buffer with encoded data (probably from a `TcpStream`). +/// let mut buf = BytesMut::from(vec![0b00110000, 11, +/// 0, 4, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, +/// 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8]); +/// +/// // Parse the bytes and check the result. +/// match decode(&mut buf) { +/// Ok(Some(Packet::Publish(p))) => { +/// assert_eq!(p.payload, "hello".as_bytes().to_vec()); +/// }, +/// // In real code you probably don't want to panic like that ;) +/// Ok(None) => panic!("not enough data"), +/// other => panic!("unexpected {:?}", other), +/// } +/// ``` +/// +/// [Packet]: ../enum.Packet.html +/// [BytesMut]: https://docs.rs/bytes/0.4.12/bytes/struct.BytesMut.html +pub fn decode(buffer: &mut BytesMut) -> Result, Error> { + if let Some((header, remaining_len)) = read_header(buffer)? { + // Advance the buffer position to the next packet, and parse the current packet + let p = &mut buffer.split_to(remaining_len); + Ok(Some(read_packet(header, p)?)) } else { + // Don't have a full packet Ok(None) } } -fn read_packet(header: Header, buffer: &mut BytesMut) -> Result { - let t = header.packet(); - match t { - PacketType::PingReq => Ok(Packet::PingReq), - PacketType::PingResp => Ok(Packet::PingResp), - PacketType::Disconnect => Ok(Packet::Disconnect), - PacketType::Connect => Ok(Packet::Connect(Connect::from_buffer( - &mut buffer.split_to(header.len()), - )?)), - PacketType::Connack => Ok(Packet::Connack(Connack::from_buffer( - &mut buffer.split_to(header.len()), - )?)), - PacketType::Publish => Ok(Packet::Publish(Publish::from_buffer( - &header, - &mut buffer.split_to(header.len()), - )?)), - PacketType::Puback => Ok(Packet::Puback(PacketIdentifier( - buffer.split_to(2).into_buf().get_u16_be(), - ))), - PacketType::Pubrec => Ok(Packet::Pubrec(PacketIdentifier( - buffer.split_to(2).into_buf().get_u16_be(), - ))), - PacketType::Pubrel => Ok(Packet::Pubrel(PacketIdentifier( - buffer.split_to(2).into_buf().get_u16_be(), - ))), - PacketType::PubComp => Ok(Packet::PubComp(PacketIdentifier( - buffer.split_to(2).into_buf().get_u16_be(), - ))), - PacketType::Subscribe => Ok(Packet::Subscribe(Subscribe::from_buffer( - &mut buffer.split_to(header.len()), - )?)), - PacketType::SubAck => Ok(Packet::SubAck(Suback::from_buffer( - &mut buffer.split_to(header.len()), - )?)), - PacketType::UnSubscribe => Ok(Packet::UnSubscribe(Unsubscribe::from_buffer( - &mut buffer.split_to(header.len()), - )?)), - PacketType::UnSubAck => Ok(Packet::UnSubAck(PacketIdentifier( - buffer.split_to(2).into_buf().get_u16_be(), - ))), - } +fn read_packet(header: Header, buffer: &mut BytesMut) -> Result { + Ok(match header.typ { + PacketType::Pingreq => Packet::Pingreq, + PacketType::Pingresp => Packet::Pingresp, + PacketType::Disconnect => Packet::Disconnect, + PacketType::Connect => Connect::from_buffer(buffer)?.into(), + PacketType::Connack => Connack::from_buffer(buffer)?.into(), + PacketType::Publish => Publish::from_buffer(&header, buffer)?.into(), + PacketType::Puback => Packet::Puback(Pid::from_buffer(buffer)?), + PacketType::Pubrec => Packet::Pubrec(Pid::from_buffer(buffer)?), + PacketType::Pubrel => Packet::Pubrel(Pid::from_buffer(buffer)?), + PacketType::Pubcomp => Packet::Pubcomp(Pid::from_buffer(buffer)?), + PacketType::Subscribe => Subscribe::from_buffer(buffer)?.into(), + PacketType::Suback => Suback::from_buffer(buffer)?.into(), + PacketType::Unsubscribe => Unsubscribe::from_buffer(buffer)?.into(), + PacketType::Unsuback => Packet::Unsuback(Pid::from_buffer(buffer)?), + }) } -/* This will read the header of the stream */ -fn read_header(buffer: &mut BytesMut) -> Option<(Header, usize)> { - if buffer.len() > 1 { - let header_u8 = buffer.get(0).unwrap(); - if let Some((length, size)) = read_length(buffer, 1) { - let header = Header::new(*header_u8, length).unwrap(); - Some((header, size + 1)) + +/// Read the parsed header and remaining_len from the buffer. Only return Some() and advance the +/// buffer position if there is enough data in th ebuffer to read the full packet. +fn read_header(buffer: &mut BytesMut) -> Result, Error> { + let mut len: usize = 0; + for pos in 0..=3 { + if let Some(&byte) = buffer.get(pos + 1) { + len += (byte as usize & 0x7F) << (pos * 7); + if (byte & 0x80) == 0 { + // Continuation bit == 0, length is parsed + if buffer.len() < 2 + pos + len { + // Won't be able to read full packet + return Ok(None); + } + // Parse header byte, skip past the header, and return + let header = Header::new(*buffer.get(0).unwrap())?; + buffer.advance(pos + 2); + return Ok(Some((header, len))); + } } else { - None + // Couldn't read full length + return Ok(None); } + } + // Continuation byte == 1 four times, that's illegal. + Err(Error::InvalidHeader) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct Header { + pub typ: PacketType, + pub dup: bool, + pub qos: QoS, + pub retain: bool, +} +impl Header { + pub fn new(hd: u8) -> Result { + let (typ, flags_ok) = match hd >> 4 { + 1 => (PacketType::Connect, hd & 0b1111 == 0), + 2 => (PacketType::Connack, hd & 0b1111 == 0), + 3 => (PacketType::Publish, true), + 4 => (PacketType::Puback, hd & 0b1111 == 0), + 5 => (PacketType::Pubrec, hd & 0b1111 == 0), + 6 => (PacketType::Pubrel, hd & 0b1111 == 0b0010), + 7 => (PacketType::Pubcomp, hd & 0b1111 == 0), + 8 => (PacketType::Subscribe, hd & 0b1111 == 0b0010), + 9 => (PacketType::Suback, hd & 0b1111 == 0), + 10 => (PacketType::Unsubscribe, hd & 0b1111 == 0b0010), + 11 => (PacketType::Unsuback, hd & 0b1111 == 0), + 12 => (PacketType::Pingreq, hd & 0b1111 == 0), + 13 => (PacketType::Pingresp, hd & 0b1111 == 0), + 14 => (PacketType::Disconnect, hd & 0b1111 == 0), + _ => (PacketType::Connect, false), + }; + if !flags_ok { + return Err(Error::InvalidHeader); + } + Ok(Header { + typ, + dup: hd & 0b1000 != 0, + qos: QoS::from_u8((hd & 0b110) >> 1)?, + retain: hd & 1 == 1, + }) + } +} + +pub(crate) fn read_string(buffer: &mut BytesMut) -> Result { + String::from_utf8(read_bytes(buffer)?).map_err(|e| Error::InvalidString(e.utf8_error())) +} + +pub(crate) fn read_bytes(buffer: &mut BytesMut) -> Result, Error> { + let len = buffer.split_to(2).into_buf().get_u16_be() as usize; + if len > buffer.len() { + Err(Error::InvalidLength) } else { - None + Ok(buffer.split_to(len).to_vec()) } } -fn read_length(buffer: &BytesMut, mut pos: usize) -> Option<(usize, usize)> { - let mut mult: usize = 1; - let mut len: usize = 0; - let mut done = false; +#[cfg(test)] +mod test { + use crate::decoder::*; + use bytes::BytesMut; - while !done { - let byte = (*buffer.get(pos).unwrap()) as usize; - len += (byte & 0x7F) * mult; - mult *= 0x80; - if mult > MULTIPLIER { - return None; + macro_rules! header { + ($t:ident, $d:expr, $q:ident, $r:expr) => { + Header { + typ: PacketType::$t, + dup: $d, + qos: QoS::$q, + retain: $r, + } + }; + } + + /// Test all possible header first byte, using remaining_len=0. + #[test] + fn header_firstbyte() { + let valid = vec![ + (0b0001_0000, header!(Connect, false, AtMostOnce, false)), + (0b0010_0000, header!(Connack, false, AtMostOnce, false)), + (0b0011_0000, header!(Publish, false, AtMostOnce, false)), + (0b0011_0001, header!(Publish, false, AtMostOnce, true)), + (0b0011_0010, header!(Publish, false, AtLeastOnce, false)), + (0b0011_0011, header!(Publish, false, AtLeastOnce, true)), + (0b0011_0100, header!(Publish, false, ExactlyOnce, false)), + (0b0011_0101, header!(Publish, false, ExactlyOnce, true)), + (0b0011_1000, header!(Publish, true, AtMostOnce, false)), + (0b0011_1001, header!(Publish, true, AtMostOnce, true)), + (0b0011_1010, header!(Publish, true, AtLeastOnce, false)), + (0b0011_1011, header!(Publish, true, AtLeastOnce, true)), + (0b0011_1100, header!(Publish, true, ExactlyOnce, false)), + (0b0011_1101, header!(Publish, true, ExactlyOnce, true)), + (0b0100_0000, header!(Puback, false, AtMostOnce, false)), + (0b0101_0000, header!(Pubrec, false, AtMostOnce, false)), + (0b0110_0010, header!(Pubrel, false, AtLeastOnce, false)), + (0b0111_0000, header!(Pubcomp, false, AtMostOnce, false)), + (0b1000_0010, header!(Subscribe, false, AtLeastOnce, false)), + (0b1001_0000, header!(Suback, false, AtMostOnce, false)), + (0b1010_0010, header!(Unsubscribe, false, AtLeastOnce, false)), + (0b1011_0000, header!(Unsuback, false, AtMostOnce, false)), + (0b1100_0000, header!(Pingreq, false, AtMostOnce, false)), + (0b1101_0000, header!(Pingresp, false, AtMostOnce, false)), + (0b1110_0000, header!(Disconnect, false, AtMostOnce, false)), + ]; + for n in 0..=255 { + let res = match valid.iter().find(|(byte, _)| *byte == n) { + Some((_, header)) => Ok(Some((*header, 0))), + None if ((n & 0b110) == 0b110) && (n >> 4 == 3) => Err(Error::InvalidQos(3)), + None => Err(Error::InvalidHeader), + }; + let buf = &mut BytesMut::from(vec![n, 0]); + assert_eq!(res, read_header(buf), "{:08b}", n); } - if (byte & 0x80) == 0 { - done = true; - } else { - pos += 1; + } + + /// Test decoding of length and actual buffer len. + #[rustfmt::skip] + #[test] + fn header_len() { + let h = header!(Connect, false, AtMostOnce, false); + for (res, bytes, buflen) in vec![ + (Ok(Some((h, 0))), vec![1 << 4, 0], 2), + (Ok(None), vec![1 << 4, 127], 128), + (Ok(Some((h, 127))), vec![1 << 4, 127], 129), + (Ok(None), vec![1 << 4, 0x80], 2), + (Ok(Some((h, 0))), vec![1 << 4, 0x80, 0], 3), //Weird encoding for "0" buf matches spec + (Ok(Some((h, 128))), vec![1 << 4, 0x80, 1], 131), + (Ok(None), vec![1 << 4, 0x80+16, 78], 10002), + (Ok(Some((h, 10000))), vec![1 << 4, 0x80+16, 78], 10003), + (Err(Error::InvalidHeader), vec![1 << 4, 0x80, 0x80, 0x80, 0x80], 10), + ] { + let mut buf = BytesMut::from(bytes); + buf.resize(buflen, 0); + assert_eq!(res, read_header(&mut buf)); } } - Some((len as usize, pos)) + + #[test] + fn non_utf8_string() { + let mut data = BytesMut::from(vec![ + 0b00110000, 10, // type=Publish, remaining_len=10 + 0x00, 0x03, 'a' as u8, '/' as u8, 0xc0 as u8, // Topic with Invalid utf8 + 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8, // payload + ]); + assert!(match decode(&mut data) { + Err(Error::InvalidString(_)) => true, + _ => false, + }); + } + + /// Validity of remaining_len is tested exhaustively elsewhere, this is for inner lengths, which + /// are rarer. + #[test] + fn inner_length_too_long() { + let mut data = BytesMut::from(vec![ + 0b00010000, 20, // Connect packet, remaining_len=20 + 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, + 0b01000000, // +password + 0x00, 0x0a, // keepalive 10 sec + 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id + 0x00, 0x03, 'm' as u8, 'q' as u8, // password with invalid length + ]); + assert_eq!(Err(Error::InvalidLength), decode(&mut data)); + } } diff --git a/src/decoder_test.rs b/src/decoder_test.rs index 7b01b5b..45c39e7 100644 --- a/src/decoder_test.rs +++ b/src/decoder_test.rs @@ -1,9 +1,4 @@ -#![allow(unused_imports)] - -use crate::{ - decoder, Connack, ConnectReturnCode, Packet, PacketIdentifier, QoS, SubscribeReturnCodes, - SubscribeTopic, -}; +use crate::*; use bytes::BytesMut; #[test] @@ -20,10 +15,8 @@ fn test_half_connect() { // 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' // 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' ]); - let length = data.len(); - let d = decoder::decode(&mut data).unwrap(); - assert_eq!(d, None); - assert_eq!(length, 12); + assert_eq!(Ok(None), decode(&mut data)); + assert_eq!(12, data.len()); } #[test] @@ -39,8 +32,21 @@ fn test_connect() { 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' ]); - let d = decoder::decode(&mut data).unwrap(); - assert_ne!(d, None); + let pkt = Connect { + protocol: Protocol::MQTT311, + keep_alive: 10, + client_id: "test".into(), + clean_session: true, + last_will: Some(LastWill { + topic: "/a".into(), + message: "offline".into(), + qos: QoS::AtLeastOnce, + retain: false, + }), + username: Some("rust".into()), + password: Some("mq".into()), + }; + assert_eq!(Ok(Some(pkt.into())), decode(&mut data)); assert_eq!(data.len(), 0); } @@ -65,22 +71,19 @@ fn test_connack() { #[test] fn test_ping_req() { let mut data = BytesMut::from(vec![0b11000000, 0b00000000]); - let d = decoder::decode(&mut data).unwrap(); - assert_eq!(d, Some(Packet::PingReq)); + assert_eq!(Ok(Some(Packet::Pingreq)), decode(&mut data)); } #[test] fn test_ping_resp() { let mut data = BytesMut::from(vec![0b11010000, 0b00000000]); - let d = decoder::decode(&mut data).unwrap(); - assert_eq!(d, Some(Packet::PingResp)); + assert_eq!(Ok(Some(Packet::Pingresp)), decode(&mut data)); } #[test] fn test_disconnect() { let mut data = BytesMut::from(vec![0b11100000, 0b00000000]); - let d = decoder::decode(&mut data).unwrap(); - assert_eq!(d, Some(Packet::Disconnect)); + assert_eq!(Ok(Some(Packet::Disconnect)), decode(&mut data)); } #[test] @@ -93,90 +96,73 @@ fn test_publish() { 0b00111101, 12, 0x00, 0x03, 'a' as u8, '/' as u8, 'b' as u8, 0 as u8, 10 as u8, 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8, ]); - let d1 = decoder::decode(&mut data).unwrap(); - let d2 = decoder::decode(&mut data).unwrap(); - let d3 = decoder::decode(&mut data).unwrap(); - println!("{:?}", d1); - match d1 { - Some(Packet::Publish(p)) => { + match decode(&mut data) { + Ok(Some(Packet::Publish(p))) => { assert_eq!(p.dup, false); assert_eq!(p.retain, false); - assert_eq!(p.qos, QoS::AtMostOnce); + assert_eq!(p.qospid, QosPid::AtMostOnce); assert_eq!(p.topic_name, "a/b"); assert_eq!(String::from_utf8(p.payload).unwrap(), "hello"); } - _ => panic!("Should not be None"), + other => panic!("Failed decode: {:?}", other), } - match d2 { - Some(Packet::Publish(p)) => { + match decode(&mut data) { + Ok(Some(Packet::Publish(p))) => { assert_eq!(p.dup, true); assert_eq!(p.retain, false); - assert_eq!(p.qos, QoS::AtMostOnce); + assert_eq!(p.qospid, QosPid::AtMostOnce); assert_eq!(p.topic_name, "a/b"); assert_eq!(String::from_utf8(p.payload).unwrap(), "hello"); } - _ => panic!("Should not be None"), + other => panic!("Failed decode: {:?}", other), } - match d3 { - Some(Packet::Publish(p)) => { + match decode(&mut data) { + Ok(Some(Packet::Publish(p))) => { assert_eq!(p.dup, true); assert_eq!(p.retain, true); - assert_eq!(p.qos, QoS::ExactlyOnce); + assert_eq!(p.qospid, QosPid::from_u8u16(2, 10)); assert_eq!(p.topic_name, "a/b"); - assert_eq!(p.pid.unwrap(), PacketIdentifier(10)); assert_eq!(String::from_utf8(p.payload).unwrap(), "hello"); } - _ => panic!("Should not be None"), + other => panic!("Failed decode: {:?}", other), } } #[test] fn test_pub_ack() { let mut data = BytesMut::from(vec![0b01000000, 0b00000010, 0 as u8, 10 as u8]); - let d = decoder::decode(&mut data).unwrap(); - match d { - Some(Packet::Puback(a)) => { - assert_eq!(a.0, 10); - } - _ => panic!(), - } + match decode(&mut data) { + Ok(Some(Packet::Puback(a))) => assert_eq!(a.get(), 10), + other => panic!("Failed decode: {:?}", other), + }; } #[test] fn test_pub_rec() { let mut data = BytesMut::from(vec![0b01010000, 0b00000010, 0 as u8, 10 as u8]); - let d = decoder::decode(&mut data).unwrap(); - match d { - Some(Packet::Pubrec(a)) => { - assert_eq!(a.0, 10); - } - _ => panic!(), - } + match decode(&mut data) { + Ok(Some(Packet::Pubrec(a))) => assert_eq!(a.get(), 10), + other => panic!("Failed decode: {:?}", other), + }; } #[test] fn test_pub_rel() { let mut data = BytesMut::from(vec![0b01100010, 0b00000010, 0 as u8, 10 as u8]); - let d = decoder::decode(&mut data).unwrap(); - match d { - Some(Packet::Pubrel(a)) => { - assert_eq!(a.0, 10); - } - _ => panic!(), - } + match decode(&mut data) { + Ok(Some(Packet::Pubrel(a))) => assert_eq!(a.get(), 10), + other => panic!("Failed decode: {:?}", other), + }; } #[test] fn test_pub_comp() { let mut data = BytesMut::from(vec![0b01110000, 0b00000010, 0 as u8, 10 as u8]); - let d = decoder::decode(&mut data).unwrap(); - match d { - Some(Packet::PubComp(a)) => { - assert_eq!(a.0, 10); - } - _ => panic!(), - } + match decode(&mut data) { + Ok(Some(Packet::Pubcomp(a))) => assert_eq!(a.get(), 10), + other => panic!("Failed decode: {:?}", other), + }; } #[test] @@ -185,17 +171,16 @@ fn test_subscribe() { 0b10000010, 8, 0 as u8, 10 as u8, 0 as u8, 3 as u8, 'a' as u8, '/' as u8, 'b' as u8, 0 as u8, ]); - let d = decoder::decode(&mut data).unwrap(); - match d { - Some(Packet::Subscribe(s)) => { - assert_eq!(s.pid.0, 10); + match decode(&mut data) { + Ok(Some(Packet::Subscribe(s))) => { + assert_eq!(s.pid.get(), 10); let t = SubscribeTopic { topic_path: "a/b".to_string(), qos: QoS::AtMostOnce, }; assert_eq!(s.topics[0], t); } - _ => panic!(), + other => panic!("Failed decode: {:?}", other), } } @@ -203,16 +188,15 @@ fn test_subscribe() { fn test_suback() { let mut data = BytesMut::from(vec![0b10010000, 3, 0 as u8, 10 as u8, 0b00000010]); - let d = decoder::decode(&mut data).unwrap(); - match d { - Some(Packet::SubAck(s)) => { - assert_eq!(s.pid.0, 10); + match decode(&mut data) { + Ok(Some(Packet::Suback(s))) => { + assert_eq!(s.pid.get(), 10); assert_eq!( s.return_codes[0], SubscribeReturnCodes::Success(QoS::ExactlyOnce) ); } - _ => panic!(), + other => panic!("Failed decode: {:?}", other), } } @@ -221,24 +205,22 @@ fn test_unsubscribe() { let mut data = BytesMut::from(vec![ 0b10100010, 5, 0 as u8, 10 as u8, 0 as u8, 1 as u8, 'a' as u8, ]); - let d = decoder::decode(&mut data).unwrap(); - match d { - Some(Packet::UnSubscribe(a)) => { - assert_eq!(a.pid.0, 10); + match decode(&mut data) { + Ok(Some(Packet::Unsubscribe(a))) => { + assert_eq!(a.pid.get(), 10); assert_eq!(a.topics[0], 'a'.to_string()); } - _ => panic!(), + other => panic!("Failed decode: {:?}", other), } } #[test] fn test_unsub_ack() { let mut data = BytesMut::from(vec![0b10110000, 2, 0 as u8, 10 as u8]); - let d = decoder::decode(&mut data).unwrap(); - match d { - Some(Packet::UnSubAck(p)) => { - assert_eq!(p.0, 10); + match decode(&mut data) { + Ok(Some(Packet::Unsuback(p))) => { + assert_eq!(p.get(), 10); } - _ => panic!(), + other => panic!("Failed decode: {:?}", other), } } diff --git a/src/encoder.rs b/src/encoder.rs index 96ab7d5..ea5a74a 100644 --- a/src/encoder.rs +++ b/src/encoder.rs @@ -1,67 +1,94 @@ -use crate::{Packet, MAX_PAYLOAD_SIZE}; +use crate::{Error, Packet}; use bytes::{BufMut, BytesMut}; -use std::io; -#[allow(dead_code)] -pub fn encode(packet: &Packet, buffer: &mut BytesMut) -> Result<(), io::Error> { +/// Encode a [Packet] enum into a [BytesMut] buffer. +/// +/// ``` +/// # use mqttrs::*; +/// # use bytes::*; +/// // Instantiate a `Packet` to encode. +/// let packet = Publish { +/// dup: false, +/// qospid: QosPid::AtMostOnce, +/// retain: false, +/// topic_name: "test".into(), +/// payload: "hello".into(), +/// }.into(); +/// +/// // Allocate a appropriately-sized buffer. +/// let mut buf = BytesMut::with_capacity(1024); +/// +/// // Write bytes corresponding to `&Packet` into the `BytesMut`. +/// encode(&packet, &mut buf).expect("failed encoding"); +/// assert_eq!(&*buf, &[0b00110000, 11, +/// 0, 4, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, +/// 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8]); +/// ``` +/// +/// [Packet]: ../enum.Packet.html +/// [BytesMut]: https://docs.rs/bytes/0.4.12/bytes/struct.BytesMut.html +pub fn encode(packet: &Packet, buffer: &mut BytesMut) -> Result<(), Error> { match packet { Packet::Connect(connect) => connect.to_buffer(buffer), Packet::Connack(connack) => connack.to_buffer(buffer), Packet::Publish(publish) => publish.to_buffer(buffer), Packet::Puback(pid) => { + check_remaining(buffer, 4)?; let header_u8 = 0b01000000 as u8; let length = 0b00000010 as u8; buffer.put(header_u8); buffer.put(length); - buffer.put_u16_be(pid.0); - Ok(()) + pid.to_buffer(buffer) } Packet::Pubrec(pid) => { + check_remaining(buffer, 4)?; let header_u8 = 0b01010000 as u8; let length = 0b00000010 as u8; buffer.put(header_u8); buffer.put(length); - buffer.put_u16_be(pid.0); - Ok(()) + pid.to_buffer(buffer) } Packet::Pubrel(pid) => { - let header_u8 = 0b01100000 as u8; + check_remaining(buffer, 4)?; + let header_u8 = 0b01100010 as u8; let length = 0b00000010 as u8; buffer.put(header_u8); buffer.put(length); - buffer.put_u16_be(pid.0); - Ok(()) + pid.to_buffer(buffer) } - Packet::PubComp(pid) => { + Packet::Pubcomp(pid) => { + check_remaining(buffer, 4)?; let header_u8 = 0b01110000 as u8; let length = 0b00000010 as u8; buffer.put(header_u8); buffer.put(length); - buffer.put_u16_be(pid.0); - Ok(()) + pid.to_buffer(buffer) } Packet::Subscribe(subscribe) => subscribe.to_buffer(buffer), - Packet::SubAck(suback) => suback.to_buffer(buffer), - Packet::UnSubscribe(unsub) => unsub.to_buffer(buffer), - Packet::UnSubAck(pid) => { + Packet::Suback(suback) => suback.to_buffer(buffer), + Packet::Unsubscribe(unsub) => unsub.to_buffer(buffer), + Packet::Unsuback(pid) => { + check_remaining(buffer, 4)?; let header_u8 = 0b10110000 as u8; let length = 0b00000010 as u8; buffer.put(header_u8); buffer.put(length); - buffer.put_u16_be(pid.0); - Ok(()) + pid.to_buffer(buffer) } - Packet::PingReq => { + Packet::Pingreq => { + check_remaining(buffer, 2)?; buffer.put(0b11000000 as u8); buffer.put(0b00000000 as u8); Ok(()) } - Packet::PingResp => { + Packet::Pingresp => { + check_remaining(buffer, 2)?; buffer.put(0b11010000 as u8); buffer.put(0b00000000 as u8); Ok(()) } Packet::Disconnect => { + check_remaining(buffer, 2)?; buffer.put(0b11100000 as u8); buffer.put(0b00000000 as u8); Ok(()) @@ -69,13 +96,25 @@ pub fn encode(packet: &Packet, buffer: &mut BytesMut) -> Result<(), io::Error> { } } -pub fn write_length(len: usize, buffer: &mut BytesMut) -> Result<(), io::Error> { - if len > MAX_PAYLOAD_SIZE { - return Err(io::Error::new( - io::ErrorKind::PermissionDenied, - "data size too big", - )); - }; +/// Check wether buffer has `len` bytes of write capacity left. Use this to return a clean +/// Result::Err instead of panicking. +pub(crate) fn check_remaining(buffer: &BytesMut, len: usize) -> Result<(), Error> { + if buffer.remaining_mut() < len { + Err(Error::WriteZero) + } else { + Ok(()) + } +} + +/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718023 +pub(crate) fn write_length(len: usize, buffer: &mut BytesMut) -> Result<(), Error> { + match len { + 0..=127 => check_remaining(buffer, len + 1)?, + 128..=16383 => check_remaining(buffer, len + 2)?, + 16384..=2097151 => check_remaining(buffer, len + 3)?, + 2097152..=268435455 => check_remaining(buffer, len + 4)?, + _ => return Err(Error::InvalidLength), + } let mut done = false; let mut x = len; while !done { @@ -90,8 +129,12 @@ pub fn write_length(len: usize, buffer: &mut BytesMut) -> Result<(), io::Error> Ok(()) } -pub fn write_string(string: &str, buffer: &mut BytesMut) -> Result<(), io::Error> { - buffer.put_u16_be(string.len() as u16); - buffer.put_slice(string.as_bytes()); +pub(crate) fn write_bytes(bytes: &[u8], buffer: &mut BytesMut) -> Result<(), Error> { + buffer.put_u16_be(bytes.len() as u16); + buffer.put_slice(bytes); Ok(()) } + +pub(crate) fn write_string(string: &str, buffer: &mut BytesMut) -> Result<(), Error> { + write_bytes(string.as_bytes(), buffer) +} diff --git a/src/encoder_test.rs b/src/encoder_test.rs index 26aa7e9..4bab58e 100644 --- a/src/encoder_test.rs +++ b/src/encoder_test.rs @@ -1,10 +1,4 @@ -#[allow(unused_imports)] -use crate::{ - decoder, encoder, Connack, Connect, ConnectReturnCode, Packet, PacketIdentifier, Protocol, - Publish, QoS, Suback, Subscribe, SubscribeReturnCodes, SubscribeTopic, Unsubscribe, -}; - -#[allow(unused_imports)] +use crate::*; use bytes::BytesMut; #[test] @@ -19,13 +13,10 @@ fn test_connect() { password: None, }; let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::Connect(packet), &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::Connect(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&packet.into(), &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Connect(_))) => assert!(true), + err => assert!(false, err), } } @@ -36,13 +27,10 @@ fn test_connack() { code: ConnectReturnCode::Accepted, }; let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::Connack(packet), &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::Connack(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&packet.into(), &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Connack(_))) => assert!(true), + err => assert!(false, err), } } @@ -50,78 +38,60 @@ fn test_connack() { fn test_publish() { let packet = Publish { dup: false, - qos: QoS::ExactlyOnce, + qospid: QosPid::from_u8u16(2, 10), retain: true, topic_name: "asdf".to_string(), - pid: Some(PacketIdentifier(10)), payload: vec!['h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8], }; let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::Publish(packet), &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - println!("{:?}", decoded); - match decoded { - Some(Packet::Publish(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&packet.into(), &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Publish(_))) => assert!(true), + err => assert!(false, err), } } #[test] fn test_puback() { - let packet = Packet::Puback(PacketIdentifier(19)); + let packet = Packet::Puback(Pid::try_from(19).unwrap()); let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&packet, &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::Puback(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&packet, &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Puback(_))) => assert!(true), + err => assert!(false, err), } } #[test] fn test_pubrec() { - let packet = Packet::Pubrec(PacketIdentifier(19)); + let packet = Packet::Pubrec(Pid::try_from(19).unwrap()); let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&packet, &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::Pubrec(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&packet, &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Pubrec(_))) => assert!(true), + err => assert!(false, err), } } #[test] fn test_pubrel() { - let packet = Packet::Pubrel(PacketIdentifier(19)); + let packet = Packet::Pubrel(Pid::try_from(19).unwrap()); let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&packet, &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - println!("{:?}", decoded); - match decoded { - Some(Packet::Pubrel(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&packet, &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Pubrel(_))) => assert!(true), + err => assert!(false, err), } } #[test] fn test_pubcomp() { - let packet = Packet::PubComp(PacketIdentifier(19)); + let packet = Packet::Pubcomp(Pid::try_from(19).unwrap()); let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&packet, &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::PubComp(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&packet, &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Pubcomp(_))) => assert!(true), + err => assert!(false, err), } } @@ -132,17 +102,14 @@ fn test_subscribe() { qos: QoS::ExactlyOnce, }; let packet = Subscribe { - pid: PacketIdentifier(345), + pid: Pid::try_from(345).unwrap(), topics: vec![stopic], }; let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::Subscribe(packet), &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::Subscribe(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&Packet::Subscribe(packet), &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Subscribe(_))) => assert!(true), + err => assert!(false, err), } } @@ -150,86 +117,68 @@ fn test_subscribe() { fn test_suback() { let return_code = SubscribeReturnCodes::Success(QoS::ExactlyOnce); let packet = Suback { - pid: PacketIdentifier(12321), + pid: Pid::try_from(12321).unwrap(), return_codes: vec![return_code], }; let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::SubAck(packet), &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::SubAck(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&Packet::Suback(packet), &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Suback(_))) => assert!(true), + err => assert!(false, err), } } #[test] fn test_unsubscribe() { let packet = Unsubscribe { - pid: PacketIdentifier(12321), + pid: Pid::try_from(12321).unwrap(), topics: vec!["a/b".to_string()], }; let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::UnSubscribe(packet), &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::UnSubscribe(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&Packet::Unsubscribe(packet), &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Unsubscribe(_))) => assert!(true), + err => assert!(false, err), } } #[test] fn test_unsuback() { - let packet = Packet::UnSubAck(PacketIdentifier(19)); + let packet = Packet::Unsuback(Pid::try_from(19).unwrap()); let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&packet, &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::UnSubAck(_c)) => { - assert!(true); - } - _ => assert!(false), + encode(&packet, &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Unsuback(_))) => assert!(true), + err => assert!(false, err), } } #[test] fn test_ping_req() { let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::PingReq, &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::PingReq) => { - assert!(true); - } - _ => assert!(false), + encode(&Packet::Pingreq, &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Pingreq)) => assert!(true), + err => assert!(false, err), } } #[test] fn test_ping_resp() { let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::PingResp, &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::PingResp) => { - assert!(true); - } - _ => assert!(false), + encode(&Packet::Pingresp, &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Pingresp)) => assert!(true), + err => assert!(false, err), } } #[test] fn test_disconnect() { let mut buffer = BytesMut::with_capacity(1024); - let _ = encoder::encode(&Packet::Disconnect, &mut buffer); - let decoded = decoder::decode(&mut buffer).unwrap(); - match decoded { - Some(Packet::Disconnect) => { - assert!(true); - } - _ => assert!(false), + encode(&Packet::Disconnect, &mut buffer).unwrap(); + match decode(&mut buffer) { + Ok(Some(Packet::Disconnect)) => assert!(true), + err => assert!(false, err), } } diff --git a/src/header.rs b/src/header.rs deleted file mode 100644 index d8195b4..0000000 --- a/src/header.rs +++ /dev/null @@ -1,113 +0,0 @@ -use crate::QoS; -use std::io; - -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum PacketType { - Connect, - Connack, - Publish, - Puback, - Pubrec, - Pubrel, - PubComp, - Subscribe, - SubAck, - UnSubscribe, - UnSubAck, - PingReq, - PingResp, - Disconnect, -} -impl PacketType { - #[inline] - pub fn from_hd(hd: u8) -> Result { - Self::from_u8(hd >> 4) - } - pub fn to_u8(&self) -> u8 { - match *self { - PacketType::Connect => 1, - PacketType::Connack => 2, - PacketType::Publish => 3, - PacketType::Puback => 4, - PacketType::Pubrec => 5, - PacketType::Pubrel => 6, - PacketType::PubComp => 7, - PacketType::Subscribe => 8, - PacketType::SubAck => 9, - PacketType::UnSubscribe => 10, - PacketType::UnSubAck => 11, - PacketType::PingReq => 12, - PacketType::PingResp => 13, - PacketType::Disconnect => 14, - } - } - pub fn from_u8(byte: u8) -> Result { - match byte { - 1 => Ok(PacketType::Connect), - 2 => Ok(PacketType::Connack), - 3 => Ok(PacketType::Publish), - 4 => Ok(PacketType::Puback), - 5 => Ok(PacketType::Pubrec), - 6 => Ok(PacketType::Pubrel), - 7 => Ok(PacketType::PubComp), - 8 => Ok(PacketType::Subscribe), - 9 => Ok(PacketType::SubAck), - 10 => Ok(PacketType::UnSubscribe), - 11 => Ok(PacketType::UnSubAck), - 12 => Ok(PacketType::PingReq), - 13 => Ok(PacketType::PingResp), - 14 => Ok(PacketType::Disconnect), - _ => Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Unsupported packet type".to_string(), - )), - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Header { - hd: u8, - packet_type: PacketType, - len: usize, -} - -impl Header { - pub fn new(hd: u8, len: usize) -> Result { - Ok(Header { - hd, - len, - packet_type: PacketType::from_hd(hd)?, - }) - } - pub fn packet(&self) -> PacketType { - self.packet_type - } - #[inline] - pub fn len(&self) -> usize { - self.len - } - #[inline] - pub fn dup(&self) -> bool { - (self.hd & 0b1000) != 0 - } - #[inline] - pub fn qos(&self) -> Result { - QoS::from_hd(self.hd) - } - #[inline] - pub fn retain(&self) -> bool { - (self.hd & 1) != 0 - } -} - -/* TESTS */ -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn header() { - let h = Header::new(0b00010000, 0).unwrap(); - assert_eq!(h.packet(), PacketType::Connect) - } -} diff --git a/src/lib.rs b/src/lib.rs index 4ad13f3..3dc15cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,52 @@ +//! `mqttrs` is a codec for the MQTT protocol. +//! +//! The API aims to be straightforward and composable, usable with plain `std` or with a framework +//! like [tokio]. The decoded packet is help in a [Packet] struct, and the encoded bytes in a +//! [bytes::BytesMut] struct. Convert between the two using [encode()] and [decode()]. Almost all +//! struct fields can be accessed directly, to create or read packets. +//! +//! It currently targets [MQTT 3.1], with [MQTT 5] support planned. +//! +//! ``` +//! use mqttrs::*; +//! use bytes::BytesMut; +//! +//! // Allocate buffer. +//! let mut buf = BytesMut::with_capacity(1024); +//! +//! // Encode an MQTT Connect packet. +//! let pkt = Packet::Connect(Connect { protocol: Protocol::MQTT311, +//! keep_alive: 30, +//! client_id: "doc_client".into(), +//! clean_session: true, +//! last_will: None, +//! username: None, +//! password: None }); +//! assert!(encode(&pkt, &mut buf).is_ok()); +//! assert_eq!(&buf[14..], "doc_client".as_bytes()); +//! let mut encoded = buf.clone(); +//! +//! // Decode one packet. The buffer will advance to the next packet. +//! assert_eq!(Ok(Some(pkt)), decode(&mut buf)); +//! +//! // Example decode failures. +//! let mut incomplete = encoded.split_to(10); +//! assert_eq!(Ok(None), decode(&mut incomplete)); +//! let mut garbage = BytesMut::from(vec![0u8,0,0,0]); +//! assert_eq!(Err(Error::InvalidHeader), decode(&mut garbage)); +//! ``` +//! +//! [MQTT 3.1]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html +//! [MQTT 5]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html +//! [tokio]: https://tokio.rs/ +//! [Packet]: enum.Packet.html +//! [encode()]: fn.encode.html +//! [decode()]: fn.decode.html +//! [bytes::BytesMut]: https://docs.rs/bytes/0.4.12/bytes/struct.BytesMut.html + mod connect; -mod header; +mod decoder; +mod encoder; mod packet; mod publish; mod subscribe; @@ -12,17 +59,12 @@ mod decoder_test; #[cfg(test)] mod encoder_test; -pub mod decoder; -pub mod encoder; - pub use crate::{ - connect::{Connack, Connect}, - header::{Header, PacketType}, - packet::Packet, + connect::{Connack, Connect, ConnectReturnCode, LastWill, Protocol}, + decoder::decode, + encoder::encode, + packet::{Packet, PacketType}, publish::Publish, subscribe::{Suback, Subscribe, SubscribeReturnCodes, SubscribeTopic, Unsubscribe}, - utils::{ConnectReturnCode, LastWill, PacketIdentifier, Protocol, QoS}, + utils::{Error, Pid, QoS, QosPid}, }; - -const MULTIPLIER: usize = 0x80 * 0x80 * 0x80 * 0x80; -const MAX_PAYLOAD_SIZE: usize = 268435455; diff --git a/src/packet.rs b/src/packet.rs index 29b7a25..afdd6a5 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -1,19 +1,111 @@ -use crate::{Connack, Connect, PacketIdentifier, Publish, Suback, Subscribe, Unsubscribe}; +use crate::*; +/// Base enum for all MQTT packet types. +/// +/// This is the main type you'll be interacting with, as an output of [`decode()`] and an input of +/// [`encode()`]. Most variants can be constructed directly without using methods. +/// +/// ``` +/// # use mqttrs::*; +/// // Simplest form +/// let pkt = Packet::Connack(Connack { session_present: false, +/// code: ConnectReturnCode::Accepted }); +/// // Using `Into` trait +/// let publish = Publish { dup: false, +/// qospid: QosPid::AtMostOnce, +/// retain: false, +/// topic_name: "to/pic".into(), +/// payload: "payload".into() }; +/// let pkt: Packet = publish.into(); +/// // Identifyer-only packets +/// let pkt = Packet::Puback(Pid::try_from(42).unwrap()); +/// ``` +/// +/// [`encode()`]: fn.encode.html +/// [`decode()`]: fn.decode.html #[derive(Debug, Clone, PartialEq)] pub enum Packet { + /// [MQTT 3.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718028) Connect(Connect), + /// [MQTT 3.2](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033) Connack(Connack), + /// [MQTT 3.3](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037) Publish(Publish), - Puback(PacketIdentifier), - Pubrec(PacketIdentifier), - Pubrel(PacketIdentifier), - PubComp(PacketIdentifier), + /// [MQTT 3.4](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718043) + Puback(Pid), + /// [MQTT 3.5](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718048) + Pubrec(Pid), + /// [MQTT 3.6](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718053) + Pubrel(Pid), + /// [MQTT 3.7](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718058) + Pubcomp(Pid), + /// [MQTT 3.8](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718063) Subscribe(Subscribe), - SubAck(Suback), - UnSubscribe(Unsubscribe), - UnSubAck(PacketIdentifier), - PingReq, - PingResp, + /// [MQTT 3.9](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718068) + Suback(Suback), + /// [MQTT 3.10](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718072) + Unsubscribe(Unsubscribe), + /// [MQTT 3.11](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718077) + Unsuback(Pid), + /// [MQTT 3.12](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718081) + Pingreq, + /// [MQTT 3.13](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718086) + Pingresp, + /// [MQTT 3.14](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718090) + Disconnect, +} +impl Packet { + /// Return the packet type variant. + /// + /// This can be used for matching, categorising, debuging, etc. Most users will match directly + /// on `Packet` instead. + pub fn get_type(&self) -> PacketType { + match self { + Packet::Connect(_) => PacketType::Connect, + Packet::Connack(_) => PacketType::Connack, + Packet::Publish(_) => PacketType::Publish, + Packet::Puback(_) => PacketType::Puback, + Packet::Pubrec(_) => PacketType::Pubrec, + Packet::Pubrel(_) => PacketType::Pubrel, + Packet::Pubcomp(_) => PacketType::Pubcomp, + Packet::Subscribe(_) => PacketType::Subscribe, + Packet::Suback(_) => PacketType::Suback, + Packet::Unsubscribe(_) => PacketType::Unsubscribe, + Packet::Unsuback(_) => PacketType::Unsuback, + Packet::Pingreq => PacketType::Pingreq, + Packet::Pingresp => PacketType::Pingresp, + Packet::Disconnect => PacketType::Disconnect, + } + } +} +macro_rules! packet_from { + ($($t:ident),+) => { + $( + impl From<$t> for Packet { + fn from(p: $t) -> Self { + Packet::$t(p) + } + } + )+ + } +} +packet_from!(Connect, Connack, Publish, Subscribe, Suback, Unsubscribe); + +/// Packet type variant, without the associated data. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum PacketType { + Connect, + Connack, + Publish, + Puback, + Pubrec, + Pubrel, + Pubcomp, + Subscribe, + Suback, + Unsubscribe, + Unsuback, + Pingreq, + Pingresp, Disconnect, } diff --git a/src/publish.rs b/src/publish.rs index 6357ce8..d204af3 100644 --- a/src/publish.rs +++ b/src/publish.rs @@ -1,64 +1,70 @@ -use crate::{encoder, utils, Header, PacketIdentifier, QoS}; -use bytes::{Buf, BufMut, BytesMut, IntoBuf}; -use std::io; +use crate::{decoder::*, encoder::*, *}; +use bytes::{BufMut, BytesMut}; +/// Publish packet ([MQTT 3.3]). +/// +/// [MQTT 3.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037 #[derive(Debug, Clone, PartialEq)] pub struct Publish { pub dup: bool, - pub qos: QoS, + pub qospid: QosPid, pub retain: bool, pub topic_name: String, - pub pid: Option, pub payload: Vec, } impl Publish { - pub fn from_buffer(header: &Header, buffer: &mut BytesMut) -> Result { - let topic_name = utils::read_string(buffer); + pub(crate) fn from_buffer(header: &Header, buffer: &mut BytesMut) -> Result { + let topic_name = read_string(buffer)?; - let pid = if header.qos()? == QoS::AtMostOnce { - None - } else { - Some(PacketIdentifier(buffer.split_to(2).into_buf().get_u16_be())) + let qospid = match header.qos { + QoS::AtMostOnce => QosPid::AtMostOnce, + QoS::AtLeastOnce => QosPid::AtLeastOnce(Pid::from_buffer(buffer)?), + QoS::ExactlyOnce => QosPid::ExactlyOnce(Pid::from_buffer(buffer)?), }; let payload = buffer.to_vec(); Ok(Publish { - dup: header.dup(), - qos: header.qos()?, - retain: header.retain(), + dup: header.dup, + qospid, + retain: header.retain, topic_name, - pid, payload, }) } - pub fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), io::Error> { + pub(crate) fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), Error> { // Header - let mut header_u8: u8 = 0b00110000 as u8; - header_u8 |= (self.qos.to_u8()) << 1; + let mut header_u8: u8 = match self.qospid { + QosPid::AtMostOnce => 0b00110000, + QosPid::AtLeastOnce(_) => 0b00110010, + QosPid::ExactlyOnce(_) => 0b00110100, + }; if self.dup { header_u8 |= 0b00001000 as u8; }; if self.retain { header_u8 |= 0b00000001 as u8; }; + check_remaining(buffer, 1)?; buffer.put(header_u8); // Length: topic (2+len) + pid (0/2) + payload (len) let length = self.topic_name.len() - + match self.qos { - QoS::AtMostOnce => 2, + + match self.qospid { + QosPid::AtMostOnce => 2, _ => 4, } + self.payload.len(); - encoder::write_length(length, buffer)?; + write_length(length, buffer)?; // Topic - encoder::write_string(self.topic_name.as_ref(), buffer)?; + write_string(self.topic_name.as_ref(), buffer)?; // Pid - if self.qos != QoS::AtMostOnce { - buffer.put_u16_be(self.pid.unwrap().0 as u16); + match self.qospid { + QosPid::AtMostOnce => (), + QosPid::AtLeastOnce(pid) => pid.to_buffer(buffer)?, + QosPid::ExactlyOnce(pid) => pid.to_buffer(buffer)?, } // Payload diff --git a/src/subscribe.rs b/src/subscribe.rs index 7838a6c..de8c58e 100644 --- a/src/subscribe.rs +++ b/src/subscribe.rs @@ -1,20 +1,29 @@ -use crate::{encoder, utils, PacketIdentifier, QoS}; +use crate::{decoder::*, encoder::*, *}; use bytes::{Buf, BufMut, BytesMut, IntoBuf}; -use std::io; +/// Subscribe topic. +/// +/// [Subscribe] packets contain a `Vec` of those. +/// +/// [Subscribe]: struct.Subscribe.html #[derive(Debug, Clone, PartialEq)] pub struct SubscribeTopic { pub topic_path: String, pub qos: QoS, } +/// Subscribe return value. +/// +/// [Suback] packets contain a `Vec` of those. +/// +/// [Suback]: struct.Subscribe.html #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SubscribeReturnCodes { Success(QoS), Failure, } impl SubscribeReturnCodes { - pub fn to_u8(&self) -> u8 { + pub(crate) fn to_u8(&self) -> u8 { match *self { SubscribeReturnCodes::Failure => 0x80, SubscribeReturnCodes::Success(qos) => qos.to_u8(), @@ -22,30 +31,39 @@ impl SubscribeReturnCodes { } } +/// Subscribe packet ([MQTT 3.8]). +/// +/// [MQTT 3.8]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718063 #[derive(Debug, Clone, PartialEq)] pub struct Subscribe { - pub pid: PacketIdentifier, + pub pid: Pid, pub topics: Vec, } +/// Subsack packet ([MQTT 3.9]). +/// +/// [MQTT 3.9]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718068 #[derive(Debug, Clone, PartialEq)] pub struct Suback { - pub pid: PacketIdentifier, + pub pid: Pid, pub return_codes: Vec, } +/// Unsubscribe packet ([MQTT 3.10]). +/// +/// [MQTT 3.10]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718072 #[derive(Debug, Clone, PartialEq)] pub struct Unsubscribe { - pub pid: PacketIdentifier, + pub pid: Pid, pub topics: Vec, } impl Subscribe { - pub fn from_buffer(buffer: &mut BytesMut) -> Result { - let pid = PacketIdentifier(buffer.split_to(2).into_buf().get_u16_be()); + pub(crate) fn from_buffer(buffer: &mut BytesMut) -> Result { + let pid = Pid::from_buffer(buffer)?; let mut topics: Vec = Vec::new(); while buffer.len() != 0 { - let topic_path = utils::read_string(buffer); + let topic_path = read_string(buffer)?; let qos = QoS::from_u8(buffer.split_to(1).into_buf().get_u8())?; let topic = SubscribeTopic { topic_path, qos }; topics.push(topic); @@ -53,8 +71,9 @@ impl Subscribe { Ok(Subscribe { pid, topics }) } - pub fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), io::Error> { + pub(crate) fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), Error> { let header_u8: u8 = 0b10000010; + check_remaining(buffer, 1)?; buffer.put(header_u8); // Length: pid(2) + topic.for_each(2+len + qos(1)) @@ -62,14 +81,14 @@ impl Subscribe { for topic in &self.topics { length += topic.topic_path.len() + 2 + 1; } - encoder::write_length(length, buffer)?; + write_length(length, buffer)?; // Pid - buffer.put_u16_be(self.pid.0); + self.pid.to_buffer(buffer)?; // Topics for topic in &self.topics { - encoder::write_string(topic.topic_path.as_ref(), buffer)?; + write_string(topic.topic_path.as_ref(), buffer)?; buffer.put(topic.qos.to_u8()); } @@ -78,37 +97,37 @@ impl Subscribe { } impl Unsubscribe { - pub fn from_buffer(buffer: &mut BytesMut) -> Result { - let pid = PacketIdentifier(buffer.split_to(2).into_buf().get_u16_be()); + pub(crate) fn from_buffer(buffer: &mut BytesMut) -> Result { + let pid = Pid::from_buffer(buffer)?; let mut topics: Vec = Vec::new(); while buffer.len() != 0 { - let topic_path = utils::read_string(buffer); + let topic_path = read_string(buffer)?; topics.push(topic_path); } Ok(Unsubscribe { pid, topics }) } - pub fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), io::Error>{ - let header_u8 : u8 = 0b10100010; - let PacketIdentifier(pid) = self.pid; + pub(crate) fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let header_u8: u8 = 0b10100010; let mut length = 2; - for topic in &self.topics{ + for topic in &self.topics { length += 2 + topic.len(); } - + check_remaining(buffer, 1)?; buffer.put(header_u8); - encoder::write_length(length, buffer)?; - buffer.put_u16_be(pid as u16); - for topic in&self.topics{ - encoder::write_string(topic.as_ref(), buffer)?; + + write_length(length, buffer)?; + self.pid.to_buffer(buffer)?; + for topic in &self.topics { + write_string(topic.as_ref(), buffer)?; } Ok(()) } } impl Suback { - pub fn from_buffer(buffer: &mut BytesMut) -> Result { - let pid = PacketIdentifier(buffer.split_to(2).into_buf().get_u16_be()); + pub(crate) fn from_buffer(buffer: &mut BytesMut) -> Result { + let pid = Pid::from_buffer(buffer)?; let mut return_codes: Vec = Vec::new(); while buffer.len() != 0 { let code = buffer.split_to(1).into_buf().get_u8(); @@ -121,14 +140,14 @@ impl Suback { } Ok(Suback { return_codes, pid }) } - pub fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), io::Error> { + pub(crate) fn to_buffer(&self, buffer: &mut BytesMut) -> Result<(), Error> { let header_u8: u8 = 0b10010000; - let PacketIdentifier(pid) = self.pid; let length = 2 + self.return_codes.len(); - + check_remaining(buffer, 1)?; buffer.put(header_u8); - encoder::write_length(length, buffer)?; - buffer.put_u16_be(pid); + + write_length(length, buffer)?; + self.pid.to_buffer(buffer)?; for rc in &self.return_codes { buffer.put(rc.to_u8()); } diff --git a/src/utils.rs b/src/utils.rs index 554ded3..4438c2a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,121 +1,194 @@ -use bytes::{Buf, BytesMut, IntoBuf}; -use std::io; +use bytes::{Buf, BufMut, BytesMut, IntoBuf}; +use std::{ + error::Error as ErrorTrait, + fmt, + io::{Error as IoError, ErrorKind}, + num::NonZeroU16, +}; -/// Packet Identifier, for ack purposes. +/// Errors returned by [`encode()`] and [`decode()`]. /// -/// Note that the spec disallows a pid of 0 ([MQTT-2.3.1-1] for mqtt3, [MQTT-2.2.1-3] for mqtt5). -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct PacketIdentifier(pub u16); +/// [`encode()`]: fn.encode.html +/// [`decode()`]: fn.decode.html +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Error { + /// Not enough space in the write buffer. + /// + /// It is the caller's responsiblity to pass a big enough buffer to `encode()`. + WriteZero, + /// Tried to encode or decode a ProcessIdentifier==0. + InvalidPid, + /// Tried to decode a QoS > 2. + InvalidQos(u8), + /// Tried to decode a ConnectReturnCode > 5. + InvalidConnectReturnCode(u8), + /// Tried to decode an unknown protocol. + InvalidProtocol(String, u8), + /// Tried to decode an invalid fixed header (packet type, flags, or remaining_length). + InvalidHeader, + /// Trying to encode/decode an invalid length. + /// + /// The difference with `BufferFull`/`BufferIncomplete` is that it refers to an invalid/corrupt + /// length rather than a buffer size issue. + InvalidLength, + /// Trying to decode a non-utf8 string. + InvalidString(std::str::Utf8Error), + /// Catch-all error when converting from `std::io::Error`. + /// + /// You'll hopefully never see this. + IoError(ErrorKind, String), +} +impl ErrorTrait for Error {} +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}", self) + } +} +impl From for IoError { + fn from(err: Error) -> IoError { + match err { + Error::WriteZero => IoError::new(ErrorKind::WriteZero, err), + _ => IoError::new(ErrorKind::InvalidData, err), + } + } +} +impl From for Error { + fn from(err: IoError) -> Error { + match err.kind() { + ErrorKind::WriteZero => Error::WriteZero, + k => Error::IoError(k, format!("{}",err)), + } + } +} -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Protocol { - MQIsdp(u8), - MQTT(u8), +/// Packet Identifier. +/// +/// For packets with [`QoS::AtLeastOne` or `QoS::ExactlyOnce`] delivery. +/// +/// ```rust +/// # use mqttrs::{Pid, Packet}; +/// let pid = Pid::try_from(42).expect("illegal pid value"); +/// let next_pid = pid + 1; +/// let pending_acks = std::collections::HashMap::::new(); +/// ``` +/// +/// The spec ([MQTT-2.3.1-1], [MQTT-2.2.1-3]) disallows a pid of 0. +/// +/// [`QoS::AtLeastOne` or `QoS::ExactlyOnce`]: enum.QoS.html +/// [MQTT-2.3.1-1]: https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718025 +/// [MQTT-2.2.1-3]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901026 +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Pid(NonZeroU16); +impl Pid { + /// Returns a new `Pid` with value `1`. + pub fn new() -> Self { + Pid(NonZeroU16::new(1).unwrap()) + } + /// Returns a new `Pid` with specified value. + // Not using std::convert::TryFrom so that don't have to depend on rust 1.34. + pub fn try_from(u: u16) -> Result { + match NonZeroU16::new(u) { + Some(nz) => Ok(Pid(nz)), + None => Err(Error::InvalidPid), + } + } + /// Get the `Pid` as a raw `u16`. + pub fn get(self) -> u16 { + self.0.get() + } + pub(crate) fn from_buffer(buf: &mut BytesMut) -> Result { + Self::try_from(buf.split_to(2).into_buf().get_u16_be()) + } + pub(crate) fn to_buffer(self, buf: &mut BytesMut) -> Result<(), Error> { + Ok(buf.put_u16_be(self.get())) + } +} +impl std::ops::Add for Pid { + type Output = Pid; + fn add(self, u: u16) -> Pid { + let n = self.get().wrapping_add(u); + Pid(NonZeroU16::new(if n == 0 { 1 } else { n }).unwrap()) + } +} +impl std::ops::Sub for Pid { + type Output = Pid; + fn sub(self, u: u16) -> Pid { + let n = self.get().wrapping_sub(u); + Pid(NonZeroU16::new(if n == 0 { std::u16::MAX } else { n }).unwrap()) + } } +/// Packet delivery [Quality of Service] level. +/// +/// [Quality of Service]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718099 #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum QoS { + /// `QoS 0`. No ack needed. AtMostOnce, + /// `QoS 1`. One ack needed. AtLeastOnce, + /// `QoS 2`. Two acks needed. ExactlyOnce, } impl QoS { - pub fn to_u8(&self) -> u8 { + pub(crate) fn to_u8(&self) -> u8 { match *self { QoS::AtMostOnce => 0, QoS::AtLeastOnce => 1, QoS::ExactlyOnce => 2, } } - pub fn from_u8(byte: u8) -> Result { + pub(crate) fn from_u8(byte: u8) -> Result { match byte { 0 => Ok(QoS::AtMostOnce), 1 => Ok(QoS::AtLeastOnce), 2 => Ok(QoS::ExactlyOnce), - _ => Err(io::Error::new(io::ErrorKind::InvalidData, "")), + n => Err(Error::InvalidQos(n)), } } - #[inline] - pub fn from_hd(hd: u8) -> Result { - Self::from_u8((hd & 0b110) >> 1) - } -} - -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum ConnectReturnCode { - Accepted, - RefusedProtocolVersion, - RefusedIdentifierRejected, - ServerUnavailable, - BadUsernamePassword, - NotAuthorized, } -#[derive(Debug, Clone, PartialEq)] -pub struct LastWill { - pub topic: String, - pub message: String, - pub qos: QoS, - pub retain: bool, +/// Combined [`QoS`]/[`Pid`]. +/// +/// Used only in [`Publish`] packets. +/// +/// [`Publish`]: struct.Publish.html +/// [`QoS`]: enum.QoS.html +/// [`Pid`]: struct.Pid.html +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QosPid { + AtMostOnce, + AtLeastOnce(Pid), + ExactlyOnce(Pid), } - -impl Protocol { - pub fn new(name: &str, level: u8) -> Result { - match name { - "MQIsdp" => match level { - 3 => Ok(Protocol::MQIsdp(3)), - _ => Err(io::Error::new(io::ErrorKind::InvalidData, "")), - }, - "MQTT" => match level { - 4 => Ok(Protocol::MQTT(4)), - _ => Err(io::Error::new(io::ErrorKind::InvalidData, "")), - }, - _ => Err(io::Error::new(io::ErrorKind::InvalidData, "")), +impl QosPid { + #[cfg(test)] + pub(crate) fn from_u8u16(qos: u8, pid: u16) -> Self { + match qos { + 0 => QosPid::AtMostOnce, + 1 => QosPid::AtLeastOnce(Pid::try_from(pid).expect("pid == 0")), + 2 => QosPid::ExactlyOnce(Pid::try_from(pid).expect("pid == 0")), + _ => panic!("Qos > 2"), } } - - pub fn name(&self) -> &'static str { + /// Extract the [`Pid`] from a `QosPid`, if any. + /// + /// [`Pid`]: struct.Pid.html + pub fn pid(self) -> Option { match self { - &Protocol::MQIsdp(_) => "MQIsdp", - &Protocol::MQTT(_) => "MQTT", + QosPid::AtMostOnce => None, + QosPid::AtLeastOnce(p) => Some(p), + QosPid::ExactlyOnce(p) => Some(p), } } - - pub fn level(&self) -> u8 { + /// Extract the [`QoS`] from a `QosPid`. + /// + /// [`QoS`]: enum.QoS.html + pub fn qos(self) -> QoS { match self { - &Protocol::MQIsdp(level) => level, - &Protocol::MQTT(level) => level, - } - } -} - -impl ConnectReturnCode { - pub fn to_u8(&self) -> u8 { - match *self { - ConnectReturnCode::Accepted => 0, - ConnectReturnCode::RefusedProtocolVersion => 1, - ConnectReturnCode::RefusedIdentifierRejected => 2, - ConnectReturnCode::ServerUnavailable => 3, - ConnectReturnCode::BadUsernamePassword => 4, - ConnectReturnCode::NotAuthorized => 5, + QosPid::AtMostOnce => QoS::AtMostOnce, + QosPid::AtLeastOnce(_) => QoS::AtLeastOnce, + QosPid::ExactlyOnce(_) => QoS::ExactlyOnce, } } - - pub fn from_u8(byte: u8) -> Result { - match byte { - 0 => Ok(ConnectReturnCode::Accepted), - 1 => Ok(ConnectReturnCode::RefusedProtocolVersion), - 2 => Ok(ConnectReturnCode::RefusedIdentifierRejected), - 3 => Ok(ConnectReturnCode::ServerUnavailable), - 4 => Ok(ConnectReturnCode::BadUsernamePassword), - 5 => Ok(ConnectReturnCode::NotAuthorized), - _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "")), - } - } -} - -pub fn read_string(buffer: &mut BytesMut) -> String { - let length = buffer.split_to(2).into_buf().get_u16_be(); - let byts = buffer.split_to(length as usize); - return String::from_utf8(byts.to_vec()).unwrap().to_string(); }