From 2e8fb00cf8dcb0ce23de57d8c535ed44e03e4834 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Thu, 25 Jul 2024 10:19:51 -0400 Subject: [PATCH] RUST-1992 Factor raw bson encoding out of RawDocumentBuf (#486) --- src/raw/document_buf.rs | 104 ++--------------------------- src/raw/document_buf/raw_writer.rs | 96 ++++++++++++++++++++++++++ src/raw/serde/seeded_visitor.rs | 19 +----- src/ser/mod.rs | 15 ++--- src/ser/raw/mod.rs | 3 +- src/ser/raw/value_serializer.rs | 8 +-- 6 files changed, 118 insertions(+), 127 deletions(-) create mode 100644 src/raw/document_buf/raw_writer.rs diff --git a/src/raw/document_buf.rs b/src/raw/document_buf.rs index 6a508a35..5a99fa72 100644 --- a/src/raw/document_buf.rs +++ b/src/raw/document_buf.rs @@ -7,7 +7,7 @@ use std::{ use serde::{Deserialize, Serialize}; -use crate::{de::MIN_BSON_DOCUMENT_SIZE, spec::BinarySubtype, Document}; +use crate::{de::MIN_BSON_DOCUMENT_SIZE, Document}; use super::{ bson::RawBson, @@ -21,6 +21,8 @@ use super::{ Result, }; +mod raw_writer; + /// An owned BSON document (akin to [`std::path::PathBuf`]), backed by a buffer of raw BSON bytes. /// This can be created from a `Vec` or a [`crate::Document`]. /// @@ -221,103 +223,9 @@ impl RawDocumentBuf { /// /// If the provided key contains an interior null byte, this method will panic. pub fn append_ref<'a>(&mut self, key: impl AsRef, value: impl Into>) { - fn append_string(doc: &mut RawDocumentBuf, value: &str) { - doc.data - .extend(((value.as_bytes().len() + 1) as i32).to_le_bytes()); - doc.data.extend(value.as_bytes()); - doc.data.push(0); - } - - fn append_cstring(doc: &mut RawDocumentBuf, value: &str) { - if value.contains('\0') { - panic!("cstr includes interior null byte: {}", value) - } - doc.data.extend(value.as_bytes()); - doc.data.push(0); - } - - let original_len = self.data.len(); - - // write the key for the next value to the end - // the element type will replace the previous null byte terminator of the document - append_cstring(self, key.as_ref()); - - let value = value.into(); - let element_type = value.element_type(); - - match value { - RawBsonRef::Int32(i) => { - self.data.extend(i.to_le_bytes()); - } - RawBsonRef::String(s) => { - append_string(self, s); - } - RawBsonRef::Document(d) => { - self.data.extend(d.as_bytes()); - } - RawBsonRef::Array(a) => { - self.data.extend(a.as_bytes()); - } - RawBsonRef::Binary(b) => { - let len = b.len(); - self.data.extend(len.to_le_bytes()); - self.data.push(b.subtype.into()); - if let BinarySubtype::BinaryOld = b.subtype { - self.data.extend((len - 4).to_le_bytes()) - } - self.data.extend(b.bytes); - } - RawBsonRef::Boolean(b) => { - self.data.push(b as u8); - } - RawBsonRef::DateTime(dt) => { - self.data.extend(dt.timestamp_millis().to_le_bytes()); - } - RawBsonRef::DbPointer(dbp) => { - append_string(self, dbp.namespace); - self.data.extend(dbp.id.bytes()); - } - RawBsonRef::Decimal128(d) => { - self.data.extend(d.bytes()); - } - RawBsonRef::Double(d) => { - self.data.extend(d.to_le_bytes()); - } - RawBsonRef::Int64(i) => { - self.data.extend(i.to_le_bytes()); - } - RawBsonRef::RegularExpression(re) => { - append_cstring(self, re.pattern); - append_cstring(self, re.options); - } - RawBsonRef::JavaScriptCode(js) => { - append_string(self, js); - } - RawBsonRef::JavaScriptCodeWithScope(code_w_scope) => { - let len = code_w_scope.len(); - self.data.extend(len.to_le_bytes()); - append_string(self, code_w_scope.code); - self.data.extend(code_w_scope.scope.as_bytes()); - } - RawBsonRef::Timestamp(ts) => { - self.data.extend(ts.to_le_bytes()); - } - RawBsonRef::ObjectId(oid) => { - self.data.extend(oid.bytes()); - } - RawBsonRef::Symbol(s) => { - append_string(self, s); - } - RawBsonRef::Null | RawBsonRef::Undefined | RawBsonRef::MinKey | RawBsonRef::MaxKey => {} - } - - // update element type - self.data[original_len - 1] = element_type as u8; - // append trailing null byte - self.data.push(0); - // update length - let new_len = (self.data.len() as i32).to_le_bytes(); - self.data[0..4].copy_from_slice(&new_len); + raw_writer::RawWriter::new(&mut self.data) + .append(key.as_ref(), value.into()) + .expect("key should not contain interior null byte") } /// Convert this [`RawDocumentBuf`] to a [`Document`], returning an error diff --git a/src/raw/document_buf/raw_writer.rs b/src/raw/document_buf/raw_writer.rs new file mode 100644 index 00000000..1d6fae22 --- /dev/null +++ b/src/raw/document_buf/raw_writer.rs @@ -0,0 +1,96 @@ +use crate::{ + ser::{write_cstring, write_string}, + spec::BinarySubtype, + RawBsonRef, +}; + +pub(super) struct RawWriter<'a> { + data: &'a mut Vec, +} + +impl<'a> RawWriter<'a> { + pub(super) fn new(data: &'a mut Vec) -> Self { + Self { data } + } + + pub(super) fn append(&mut self, key: &str, value: RawBsonRef) -> crate::ser::Result<()> { + let original_len = self.data.len(); + self.data[original_len - 1] = value.element_type() as u8; + + write_cstring(self.data, key)?; + + match value { + RawBsonRef::Int32(i) => { + self.data.extend(i.to_le_bytes()); + } + RawBsonRef::String(s) => { + write_string(self.data, s); + } + RawBsonRef::Document(d) => { + self.data.extend(d.as_bytes()); + } + RawBsonRef::Array(a) => { + self.data.extend(a.as_bytes()); + } + RawBsonRef::Binary(b) => { + let len = b.len(); + self.data.extend(len.to_le_bytes()); + self.data.push(b.subtype.into()); + if let BinarySubtype::BinaryOld = b.subtype { + self.data.extend((len - 4).to_le_bytes()) + } + self.data.extend(b.bytes); + } + RawBsonRef::Boolean(b) => { + self.data.push(b as u8); + } + RawBsonRef::DateTime(dt) => { + self.data.extend(dt.timestamp_millis().to_le_bytes()); + } + RawBsonRef::DbPointer(dbp) => { + write_string(self.data, dbp.namespace); + self.data.extend(dbp.id.bytes()); + } + RawBsonRef::Decimal128(d) => { + self.data.extend(d.bytes()); + } + RawBsonRef::Double(d) => { + self.data.extend(d.to_le_bytes()); + } + RawBsonRef::Int64(i) => { + self.data.extend(i.to_le_bytes()); + } + RawBsonRef::RegularExpression(re) => { + write_cstring(self.data, re.pattern)?; + write_cstring(self.data, re.options)?; + } + RawBsonRef::JavaScriptCode(js) => { + write_string(self.data, js); + } + RawBsonRef::JavaScriptCodeWithScope(code_w_scope) => { + let len = code_w_scope.len(); + self.data.extend(len.to_le_bytes()); + write_string(self.data, code_w_scope.code); + self.data.extend(code_w_scope.scope.as_bytes()); + } + RawBsonRef::Timestamp(ts) => { + self.data.extend(ts.to_le_bytes()); + } + RawBsonRef::ObjectId(oid) => { + self.data.extend(oid.bytes()); + } + RawBsonRef::Symbol(s) => { + write_string(self.data, s); + } + RawBsonRef::Null | RawBsonRef::Undefined | RawBsonRef::MinKey | RawBsonRef::MaxKey => {} + } + + // append trailing null byte + self.data.push(0); + // update length + let new_len = (self.data.len() as i32).to_le_bytes(); + self.data[0..4].copy_from_slice(&new_len); + + Ok(()) + } +} diff --git a/src/raw/serde/seeded_visitor.rs b/src/raw/serde/seeded_visitor.rs index fa0fd31a..1d6f7ca9 100644 --- a/src/raw/serde/seeded_visitor.rs +++ b/src/raw/serde/seeded_visitor.rs @@ -7,6 +7,7 @@ use serde::{ use crate::{ raw::RAW_BSON_NEWTYPE, + ser::{write_cstring, write_string}, spec::{BinarySubtype, ElementType}, RawBson, RawBsonRef, @@ -119,26 +120,12 @@ impl<'a, 'de> SeededVisitor<'a, 'de> { /// Appends a cstring to the buffer. Returns an error if the given string contains a null byte. fn append_cstring(&mut self, key: &str) -> Result<(), String> { - let key_bytes = key.as_bytes(); - if key_bytes.contains(&0) { - return Err(format!("key contains interior null byte: {}", key)); - } - - self.buffer.append_bytes(key_bytes); - self.buffer.push_byte(0); - - Ok(()) + write_cstring(self.buffer.get_owned_buffer(), key).map_err(|e| e.to_string()) } /// Appends a string and its length to the buffer. fn append_string(&mut self, s: &str) { - let bytes = s.as_bytes(); - - // Add 1 to account for null byte. - self.append_length_bytes((bytes.len() + 1) as i32); - - self.buffer.append_bytes(bytes); - self.buffer.push_byte(0); + write_string(self.buffer.get_owned_buffer(), s) } /// Converts the given length into little-endian bytes and appends the bytes to the buffer. diff --git a/src/ser/mod.rs b/src/ser/mod.rs index 6adf87ec..b35e51cd 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -40,19 +40,18 @@ use crate::{ }; use ::serde::{ser::Error as SerdeError, Serialize}; -fn write_string(writer: &mut W, s: &str) -> Result<()> { - writer.write_all(&(s.len() as i32 + 1).to_le_bytes())?; - writer.write_all(s.as_bytes())?; - writer.write_all(b"\0")?; - Ok(()) +pub(crate) fn write_string(buf: &mut Vec, s: &str) { + buf.extend(&(s.len() as i32 + 1).to_le_bytes()); + buf.extend(s.as_bytes()); + buf.push(0); } -fn write_cstring(writer: &mut W, s: &str) -> Result<()> { +pub(crate) fn write_cstring(buf: &mut Vec, s: &str) -> Result<()> { if s.contains('\0') { return Err(Error::InvalidCString(s.into())); } - writer.write_all(s.as_bytes())?; - writer.write_all(b"\0")?; + buf.extend(s.as_bytes()); + buf.push(0); Ok(()) } diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index 2fb739c3..103627d2 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -199,7 +199,8 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_str(self, v: &str) -> Result { self.update_element_type(ElementType::String)?; - write_string(&mut self.bytes, v) + write_string(&mut self.bytes, v); + Ok(()) } #[inline] diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 6f7b9945..76d1a1c0 100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -265,7 +265,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { write_binary(&mut self.root_serializer.bytes, bytes.as_slice(), subtype)?; } SerializationStep::Symbol | SerializationStep::DbPointerRef => { - write_string(&mut self.root_serializer.bytes, v)?; + write_string(&mut self.root_serializer.bytes, v); } SerializationStep::RegExPattern => { write_cstring(&mut self.root_serializer.bytes, v)?; @@ -278,7 +278,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { write_cstring(&mut self.root_serializer.bytes, sorted.as_str())?; } SerializationStep::Code => { - write_string(&mut self.root_serializer.bytes, v)?; + write_string(&mut self.root_serializer.bytes, v); } SerializationStep::CodeWithScopeCode => { self.state = SerializationStep::CodeWithScopeScope { @@ -313,7 +313,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { scope: RawDocument::from_bytes(v).map_err(Error::custom)?, }; write_i32(&mut self.root_serializer.bytes, raw.len())?; - write_string(&mut self.root_serializer.bytes, code)?; + write_string(&mut self.root_serializer.bytes, code); self.root_serializer.bytes.write_all(v)?; self.state = SerializationStep::Done; Ok(()) @@ -590,7 +590,7 @@ impl<'a> CodeWithScopeSerializer<'a> { fn start(code: &str, rs: &'a mut Serializer) -> Result { let start = rs.bytes.len(); write_i32(&mut rs.bytes, 0)?; // placeholder length - write_string(&mut rs.bytes, code)?; + write_string(&mut rs.bytes, code); let doc = DocumentSerializer::start(rs)?; Ok(Self { start, doc })