From 0ed3635fe44b5ea9184143fa355756e21ef8de32 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 9 Jul 2024 13:46:27 -0400 Subject: [PATCH] RUST-1773 Merge duplicate extjson map parsing between OwnedOrBorrowedRawBsonVisitor and SeededVisitor (#480) --- src/datetime.rs | 4 + src/raw/serde.rs | 12 ++ src/raw/serde/bson_visitor.rs | 296 +++++++++++++++++--------------- src/raw/serde/seeded_visitor.rs | 211 ++++++++++------------- 4 files changed, 261 insertions(+), 262 deletions(-) diff --git a/src/datetime.rs b/src/datetime.rs index 4c4469dc..f52b16e9 100644 --- a/src/datetime.rs +++ b/src/datetime.rs @@ -405,6 +405,10 @@ impl crate::DateTime { self.checked_duration_since(earlier) .unwrap_or(Duration::ZERO) } + + pub(crate) fn as_le_bytes(&self) -> [u8; 8] { + self.0.to_le_bytes() + } } impl fmt::Debug for crate::DateTime { diff --git a/src/raw/serde.rs b/src/raw/serde.rs index 98341fbe..1f877c1d 100644 --- a/src/raw/serde.rs +++ b/src/raw/serde.rs @@ -33,6 +33,18 @@ pub(crate) enum OwnedOrBorrowedRawBson<'a> { Borrowed(RawBsonRef<'a>), } +impl<'a> OwnedOrBorrowedRawBson<'a> { + pub(crate) fn as_ref<'b>(&'b self) -> RawBsonRef<'b> + where + 'a: 'b, + { + match self { + Self::Borrowed(r) => *r, + Self::Owned(bson) => bson.as_raw_bson_ref(), + } + } +} + impl<'a, 'de: 'a> Deserialize<'de> for OwnedOrBorrowedRawBson<'a> { fn deserialize(deserializer: D) -> Result where diff --git a/src/raw/serde/bson_visitor.rs b/src/raw/serde/bson_visitor.rs index ac2eb043..24d7af35 100644 --- a/src/raw/serde/bson_visitor.rs +++ b/src/raw/serde/bson_visitor.rs @@ -44,6 +44,160 @@ use super::{ /// A visitor used to deserialize types backed by raw BSON. 
pub(crate) struct OwnedOrBorrowedRawBsonVisitor; +pub(super) enum MapParse<'de> { + Leaf(OwnedOrBorrowedRawBson<'de>), + Aggregate(CowStr<'de>), +} + +impl OwnedOrBorrowedRawBsonVisitor { + pub(super) fn parse_map<'de, A>(map: &mut A) -> Result, A::Error> + where + A: serde::de::MapAccess<'de>, + { + let first_key = match map.next_key::()? { + Some(k) => k, + None => { + return Ok(MapParse::Leaf( + RawBson::Document(RawDocumentBuf::new()).into(), + )) + } + }; + Ok(MapParse::Leaf(match first_key.0.as_ref() { + "$oid" => { + let oid: ObjectId = map.next_value()?; + RawBsonRef::ObjectId(oid).into() + } + "$symbol" => { + let s: CowStr = map.next_value()?; + match s.0 { + Cow::Borrowed(s) => RawBsonRef::Symbol(s).into(), + Cow::Owned(s) => RawBson::Symbol(s).into(), + } + } + "$numberDecimalBytes" => { + let bytes: ByteBuf = map.next_value()?; + RawBsonRef::Decimal128(Decimal128::deserialize_from_slice(bytes.as_ref())?).into() + } + "$regularExpression" => { + let body: BorrowedRegexBody = map.next_value()?; + + match (body.pattern, body.options) { + (Cow::Borrowed(p), Cow::Borrowed(o)) => { + RawBsonRef::RegularExpression(RawRegexRef { + pattern: p, + options: o, + }) + .into() + } + (p, o) => RawBson::RegularExpression(Regex { + pattern: p.into_owned(), + options: o.into_owned(), + }) + .into(), + } + } + "$undefined" => { + let _: bool = map.next_value()?; + RawBsonRef::Undefined.into() + } + "$binary" => { + let v: BorrowedBinaryBody = map.next_value()?; + + if let Cow::Borrowed(bytes) = v.bytes { + RawBsonRef::Binary(RawBinaryRef { + bytes, + subtype: v.subtype.into(), + }) + .into() + } else { + RawBson::Binary(Binary { + bytes: v.bytes.into_owned(), + subtype: v.subtype.into(), + }) + .into() + } + } + "$date" => { + let date: i64 = map.next_value()?; + RawBsonRef::DateTime(DateTime::from_millis(date)).into() + } + "$timestamp" => { + let timestamp: TimestampBody = map.next_value()?; + RawBsonRef::Timestamp(Timestamp { + time: timestamp.t, + increment: 
timestamp.i, + }) + .into() + } + "$minKey" => { + let _ = map.next_value::()?; + RawBsonRef::MinKey.into() + } + "$maxKey" => { + let _ = map.next_value::()?; + RawBsonRef::MaxKey.into() + } + "$code" => { + let code = map.next_value::()?; + if let Some(key) = map.next_key::()? { + if key.0.as_ref() == "$scope" { + let scope = map.next_value::()?; + match (code.0, scope) { + (Cow::Borrowed(code), OwnedOrBorrowedRawDocument::Borrowed(scope)) => { + RawBsonRef::JavaScriptCodeWithScope(RawJavaScriptCodeWithScopeRef { + code, + scope, + }) + .into() + } + (code, scope) => { + RawBson::JavaScriptCodeWithScope(RawJavaScriptCodeWithScope { + code: code.into_owned(), + scope: scope.into_owned(), + }) + .into() + } + } + } else { + return Err(SerdeError::unknown_field(&key.0, &["$scope"])); + } + } else if let Cow::Borrowed(code) = code.0 { + RawBsonRef::JavaScriptCode(code).into() + } else { + RawBson::JavaScriptCode(code.0.into_owned()).into() + } + } + "$dbPointer" => { + let db_pointer: BorrowedDbPointerBody = map.next_value()?; + if let Cow::Borrowed(ns) = db_pointer.ns.0 { + RawBsonRef::DbPointer(RawDbPointerRef { + namespace: ns, + id: db_pointer.id, + }) + .into() + } else { + RawBson::DbPointer(DbPointer { + namespace: db_pointer.ns.0.into_owned(), + id: db_pointer.id, + }) + .into() + } + } + RAW_DOCUMENT_NEWTYPE => { + let bson = map.next_value::<&[u8]>()?; + let doc = RawDocument::from_bytes(bson).map_err(SerdeError::custom)?; + RawBsonRef::Document(doc).into() + } + RAW_ARRAY_NEWTYPE => { + let bson = map.next_value::<&[u8]>()?; + let doc = RawDocument::from_bytes(bson).map_err(SerdeError::custom)?; + RawBsonRef::Array(RawArray::from_doc(doc)).into() + } + _ => return Ok(MapParse::Aggregate(first_key)), + })) + } +} + impl<'de> Visitor<'de> for OwnedOrBorrowedRawBsonVisitor { type Value = OwnedOrBorrowedRawBson<'de>; @@ -209,145 +363,9 @@ impl<'de> Visitor<'de> for OwnedOrBorrowedRawBsonVisitor { where A: serde::de::MapAccess<'de>, { - let first_key = match 
map.next_key::()? { - Some(k) => k, - None => return Ok(RawBson::Document(RawDocumentBuf::new()).into()), - }; - - match first_key.0.as_ref() { - "$oid" => { - let oid: ObjectId = map.next_value()?; - Ok(RawBsonRef::ObjectId(oid).into()) - } - "$symbol" => { - let s: CowStr = map.next_value()?; - match s.0 { - Cow::Borrowed(s) => Ok(RawBsonRef::Symbol(s).into()), - Cow::Owned(s) => Ok(RawBson::Symbol(s).into()), - } - } - "$numberDecimalBytes" => { - let bytes: ByteBuf = map.next_value()?; - return Ok(RawBsonRef::Decimal128(Decimal128::deserialize_from_slice( - bytes.as_ref(), - )?) - .into()); - } - "$regularExpression" => { - let body: BorrowedRegexBody = map.next_value()?; - - match (body.pattern, body.options) { - (Cow::Borrowed(p), Cow::Borrowed(o)) => { - Ok(RawBsonRef::RegularExpression(RawRegexRef { - pattern: p, - options: o, - }) - .into()) - } - (p, o) => Ok(RawBson::RegularExpression(Regex { - pattern: p.into_owned(), - options: o.into_owned(), - }) - .into()), - } - } - "$undefined" => { - let _: bool = map.next_value()?; - Ok(RawBsonRef::Undefined.into()) - } - "$binary" => { - let v: BorrowedBinaryBody = map.next_value()?; - - if let Cow::Borrowed(bytes) = v.bytes { - Ok(RawBsonRef::Binary(RawBinaryRef { - bytes, - subtype: v.subtype.into(), - }) - .into()) - } else { - Ok(RawBson::Binary(Binary { - bytes: v.bytes.into_owned(), - subtype: v.subtype.into(), - }) - .into()) - } - } - "$date" => { - let date: i64 = map.next_value()?; - Ok(RawBsonRef::DateTime(DateTime::from_millis(date)).into()) - } - "$timestamp" => { - let timestamp: TimestampBody = map.next_value()?; - Ok(RawBsonRef::Timestamp(Timestamp { - time: timestamp.t, - increment: timestamp.i, - }) - .into()) - } - "$minKey" => { - let _ = map.next_value::()?; - Ok(RawBsonRef::MinKey.into()) - } - "$maxKey" => { - let _ = map.next_value::()?; - Ok(RawBsonRef::MaxKey.into()) - } - "$code" => { - let code = map.next_value::()?; - if let Some(key) = map.next_key::()? 
{ - if key.0.as_ref() == "$scope" { - let scope = map.next_value::()?; - match (code.0, scope) { - (Cow::Borrowed(code), OwnedOrBorrowedRawDocument::Borrowed(scope)) => { - Ok(RawBsonRef::JavaScriptCodeWithScope( - RawJavaScriptCodeWithScopeRef { code, scope }, - ) - .into()) - } - (code, scope) => Ok(RawBson::JavaScriptCodeWithScope( - RawJavaScriptCodeWithScope { - code: code.into_owned(), - scope: scope.into_owned(), - }, - ) - .into()), - } - } else { - Err(SerdeError::unknown_field(&key.0, &["$scope"])) - } - } else if let Cow::Borrowed(code) = code.0 { - Ok(RawBsonRef::JavaScriptCode(code).into()) - } else { - Ok(RawBson::JavaScriptCode(code.0.into_owned()).into()) - } - } - "$dbPointer" => { - let db_pointer: BorrowedDbPointerBody = map.next_value()?; - if let Cow::Borrowed(ns) = db_pointer.ns.0 { - Ok(RawBsonRef::DbPointer(RawDbPointerRef { - namespace: ns, - id: db_pointer.id, - }) - .into()) - } else { - Ok(RawBson::DbPointer(DbPointer { - namespace: db_pointer.ns.0.into_owned(), - id: db_pointer.id, - }) - .into()) - } - } - RAW_DOCUMENT_NEWTYPE => { - let bson = map.next_value::<&[u8]>()?; - let doc = RawDocument::from_bytes(bson).map_err(SerdeError::custom)?; - Ok(RawBsonRef::Document(doc).into()) - } - RAW_ARRAY_NEWTYPE => { - let bson = map.next_value::<&[u8]>()?; - let doc = RawDocument::from_bytes(bson).map_err(SerdeError::custom)?; - Ok(RawBsonRef::Array(RawArray::from_doc(doc)).into()) - } - _ => { + match Self::parse_map(&mut map)? 
{ + MapParse::Leaf(value) => Ok(value), + MapParse::Aggregate(first_key) => { let mut buffer = CowByteBuffer::new(); let seeded_visitor = SeededVisitor::new(&mut buffer); seeded_visitor.iterate_map(first_key, map)?; diff --git a/src/raw/serde/seeded_visitor.rs b/src/raw/serde/seeded_visitor.rs index e03ef7a6..fa0fd31a 100644 --- a/src/raw/serde/seeded_visitor.rs +++ b/src/raw/serde/seeded_visitor.rs @@ -4,23 +4,15 @@ use serde::{ de::{DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor}, Deserializer, }; -use serde_bytes::ByteBuf; use crate::{ - de::MIN_BSON_DOCUMENT_SIZE, - extjson::models::{ - BorrowedBinaryBody, - BorrowedDbPointerBody, - BorrowedRegexBody, - TimestampBody, - }, - oid::ObjectId, - raw::{RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, + raw::RAW_BSON_NEWTYPE, spec::{BinarySubtype, ElementType}, - RawDocumentBuf, + RawBson, + RawBsonRef, }; -use super::CowStr; +use super::{CowStr, MapParse, OwnedOrBorrowedRawBson, OwnedOrBorrowedRawBsonVisitor}; /// A copy-on-write byte buffer containing raw BSON bytes. The inner value starts as `None` and /// transitions to either `Cow::Borrowed` or `Cow::Owned` depending upon the data visited. @@ -272,122 +264,95 @@ impl<'a, 'de> Visitor<'de> for SeededVisitor<'a, 'de> { where A: MapAccess<'de>, { - let first_key = match map.next_key::()? 
{ - Some(key) => key, - None => { - self.buffer - .append_bytes(&MIN_BSON_DOCUMENT_SIZE.to_le_bytes()); - self.buffer.push_byte(0); - return Ok(ElementType::EmbeddedDocument); - } - }; - - match first_key.0.as_ref() { - "$oid" => { - let oid: ObjectId = map.next_value()?; - self.buffer.append_bytes(&oid.bytes()); - Ok(ElementType::ObjectId) - } - "$symbol" => { - let s: &str = map.next_value()?; - self.append_string(s); - Ok(ElementType::Symbol) - } - "$numberDecimalBytes" => { - let bytes: ByteBuf = map.next_value()?; - self.buffer.append_bytes(&bytes.into_vec()); - Ok(ElementType::Decimal128) - } - "$regularExpression" => { - let regex: BorrowedRegexBody = map.next_value()?; - let pattern = regex.pattern.as_ref(); - let options = regex.options.as_ref(); - - self.append_cstring(pattern).map_err(SerdeError::custom)?; - self.append_cstring(options).map_err(SerdeError::custom)?; - - Ok(ElementType::RegularExpression) - } - "$undefined" => { - let _: bool = map.next_value()?; - Ok(ElementType::Undefined) - } - "$binary" => { - let binary: BorrowedBinaryBody = map.next_value()?; - match binary.bytes { - Cow::Borrowed(borrowed_bytes) => { - self.append_borrowed_binary(borrowed_bytes, binary.subtype); + match OwnedOrBorrowedRawBsonVisitor::parse_map(&mut map)? 
{ + MapParse::Leaf(bson) => { + match bson { + // Cases that need to distinguish owned and borrowed + OwnedOrBorrowedRawBson::Borrowed(RawBsonRef::Binary(b)) => { + self.append_borrowed_binary(b.bytes, b.subtype.into()); + Ok(ElementType::Binary) } - Cow::Owned(owned_bytes) => { - self.append_owned_binary(owned_bytes, binary.subtype); + OwnedOrBorrowedRawBson::Owned(RawBson::Binary(b)) => { + self.append_owned_binary(b.bytes, b.subtype.into()); + Ok(ElementType::Binary) } - } - - Ok(ElementType::Binary) - } - "$date" => { - let date: i64 = map.next_value()?; - self.buffer.append_bytes(&date.to_le_bytes()); - Ok(ElementType::DateTime) - } - "$timestamp" => { - let timestamp: TimestampBody = map.next_value()?; - self.buffer.append_bytes(&timestamp.i.to_le_bytes()); - self.buffer.append_bytes(&timestamp.t.to_le_bytes()); - Ok(ElementType::Timestamp) - } - "$minKey" => { - let _: i32 = map.next_value()?; - Ok(ElementType::MinKey) - } - "$maxKey" => { - let _: i32 = map.next_value()?; - Ok(ElementType::MaxKey) - } - "$code" => { - let code: CowStr = map.next_value()?; - if let Some(key) = map.next_key::()? 
{ - let key = key.0.as_ref(); - if key == "$scope" { - let length_index = self.pad_document_length(); - self.append_string(code.0.as_ref()); - - let scope: RawDocumentBuf = map.next_value()?; - self.buffer.append_bytes(scope.as_bytes()); - - let length_bytes = - ((self.buffer.len() - length_index) as i32).to_le_bytes(); - self.buffer - .copy_from_slice(length_index..length_index + 4, &length_bytes); - - Ok(ElementType::JavaScriptCodeWithScope) - } else { - Err(SerdeError::unknown_field(key, &["$scope"])) + OwnedOrBorrowedRawBson::Borrowed(RawBsonRef::Document(doc)) => { + self.buffer.append_borrowed_bytes(doc.as_bytes()); + Ok(ElementType::EmbeddedDocument) + } + OwnedOrBorrowedRawBson::Borrowed(RawBsonRef::Array(arr)) => { + self.buffer.append_borrowed_bytes(arr.as_bytes()); + Ok(ElementType::Array) } - } else { - self.append_string(code.0.as_ref()); - Ok(ElementType::JavaScriptCode) + // Cases that don't + _ => match bson.as_ref() { + RawBsonRef::ObjectId(oid) => { + self.buffer.append_bytes(&oid.bytes()); + Ok(ElementType::ObjectId) + } + RawBsonRef::Symbol(s) => { + self.append_string(s); + Ok(ElementType::Symbol) + } + RawBsonRef::Decimal128(d) => { + self.buffer.append_bytes(&d.bytes); + Ok(ElementType::Decimal128) + } + RawBsonRef::RegularExpression(re) => { + self.append_cstring(re.pattern) + .map_err(SerdeError::custom)?; + self.append_cstring(re.options) + .map_err(SerdeError::custom)?; + Ok(ElementType::RegularExpression) + } + RawBsonRef::Undefined => Ok(ElementType::Undefined), + RawBsonRef::DateTime(dt) => { + self.buffer.append_bytes(&dt.as_le_bytes()); + Ok(ElementType::DateTime) + } + RawBsonRef::Timestamp(ts) => { + self.buffer.append_bytes(&ts.increment.to_le_bytes()); + self.buffer.append_bytes(&ts.time.to_le_bytes()); + Ok(ElementType::Timestamp) + } + RawBsonRef::MinKey => Ok(ElementType::MinKey), + RawBsonRef::MaxKey => Ok(ElementType::MaxKey), + RawBsonRef::JavaScriptCode(s) => { + self.append_string(s); + Ok(ElementType::JavaScriptCode) + } 
+ RawBsonRef::JavaScriptCodeWithScope(jsc) => { + let length_index = self.pad_document_length(); + self.append_string(jsc.code); + self.buffer.append_bytes(jsc.scope.as_bytes()); + + let length_bytes = + ((self.buffer.len() - length_index) as i32).to_le_bytes(); + self.buffer + .copy_from_slice(length_index..length_index + 4, &length_bytes); + + Ok(ElementType::JavaScriptCodeWithScope) + } + RawBsonRef::DbPointer(dbp) => { + self.append_string(dbp.namespace); + self.buffer.append_bytes(&dbp.id.bytes()); + Ok(ElementType::DbPointer) + } + RawBsonRef::Double(d) => self.visit_f64(d), + RawBsonRef::String(s) => self.visit_str(s), + RawBsonRef::Boolean(b) => self.visit_bool(b), + RawBsonRef::Null => self.visit_none(), + RawBsonRef::Int32(i) => self.visit_i32(i), + RawBsonRef::Int64(i) => self.visit_i64(i), + // These are always borrowed and are handled + // at the top of the outer `match`. + RawBsonRef::Array(_) | RawBsonRef::Document(_) | RawBsonRef::Binary(_) => { + unreachable!() + } + }, } } - "$dbPointer" => { - let db_pointer: BorrowedDbPointerBody = map.next_value()?; - - self.append_string(db_pointer.ns.0.as_ref()); - self.buffer.append_bytes(&db_pointer.id.bytes()); - - Ok(ElementType::DbPointer) - } - RAW_DOCUMENT_NEWTYPE => { - let document_bytes: &[u8] = map.next_value()?; - self.buffer.append_borrowed_bytes(document_bytes); - Ok(ElementType::EmbeddedDocument) - } - RAW_ARRAY_NEWTYPE => { - let array_bytes: &[u8] = map.next_value()?; - self.buffer.append_borrowed_bytes(array_bytes); - Ok(ElementType::Array) - } - _ => { + MapParse::Aggregate(first_key) => { self.iterate_map(first_key, map)?; Ok(ElementType::EmbeddedDocument) }