From ae5c69ff42b8b94e2ea4c9cbe2905af19b3a2a0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Drtina?= Date: Thu, 18 Jul 2024 22:32:13 +0200 Subject: [PATCH] Macro for generating extensions (#1) --- .../tests/directory_scimple/users_endpoint.rs | 20 +-- scim-protocol/src/endpoint/users.rs | 19 +-- scim-protocol/src/protocol/extensions.rs | 52 ------- .../src/protocol/extensions/generate.rs | 140 ++++++++++++++++++ scim-protocol/src/protocol/extensions/mod.rs | 11 ++ .../rfc7643_section8/sec_8_1_minimal_user.rs | 2 +- .../rfc7643_section8/sec_8_2_full_user.rs | 2 +- .../tests/rfc7643_section8/sec_8_4_group.rs | 2 +- 8 files changed, 168 insertions(+), 80 deletions(-) delete mode 100644 scim-protocol/src/protocol/extensions.rs create mode 100644 scim-protocol/src/protocol/extensions/generate.rs create mode 100644 scim-protocol/src/protocol/extensions/mod.rs diff --git a/scim-client/tests/directory_scimple/users_endpoint.rs b/scim-client/tests/directory_scimple/users_endpoint.rs index cb8a450..a6c5231 100644 --- a/scim-client/tests/directory_scimple/users_endpoint.rs +++ b/scim-client/tests/directory_scimple/users_endpoint.rs @@ -1,8 +1,7 @@ -use scim_protocol::generate_endpoint; -use scim_protocol::protocol::Extensions; use scim_protocol::resource::enterprise_user::EnterpriseUser; use scim_protocol::resource::user::User; use scim_protocol::resource::ScimSchema; +use scim_protocol::{generate_endpoint, generate_extension}; generate_endpoint!( path = "/Users", @@ -13,17 +12,12 @@ generate_endpoint!( extensions = UserExtensions, ); -// TODO: generate even this -#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)] -pub struct UserExtensions { - #[serde(rename = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User")] - pub enterprise_user: EnterpriseUser, - #[serde(rename = "urn:mem:params:scim:schemas:extension:LuckyNumberExtension")] - pub lucky_number: LuckyNumber, -} -impl Extensions for UserExtensions { - const SCHEMA: &'static [&'static str] = &[EnterpriseUser::SCHEMA, LuckyNumber::SCHEMA]; -} +generate_extension!( + extension UserExtensions { + enterprise_user: EnterpriseUser, + lucky_number: LuckyNumber, + } +); #[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[serde(rename_all = "camelCase")] diff --git a/scim-protocol/src/endpoint/users.rs b/scim-protocol/src/endpoint/users.rs index cd8cf00..7d0cc68 100644 --- a/scim-protocol/src/endpoint/users.rs +++ b/scim-protocol/src/endpoint/users.rs @@ -1,8 +1,7 @@ -use crate::generate_endpoint; -use crate::protocol::{Extensions, NoExtensions}; +use crate::protocol::NoExtensions; use crate::resource::enterprise_user::EnterpriseUser; use crate::resource::user::User; -use crate::resource::ScimSchema; +use crate::{generate_endpoint, generate_extension}; generate_endpoint!( path = "/Users", @@ -22,12 +21,8 @@ generate_endpoint!( extensions = EnterpriseUserExtensions, ); -// TODO: generate even this -#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)] -pub struct EnterpriseUserExtensions { - #[serde(rename = "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User")] - pub enterprise_user: EnterpriseUser, -} -impl Extensions for EnterpriseUserExtensions { - const SCHEMA: &'static [&'static str] = &[EnterpriseUser::SCHEMA]; -} +generate_extension!( + extension EnterpriseUserExtensions { + enterprise_user: EnterpriseUser, + } +); diff --git a/scim-protocol/src/protocol/extensions.rs b/scim-protocol/src/protocol/extensions.rs deleted file mode 100644 index bcc23dc..0000000 --- a/scim-protocol/src/protocol/extensions.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::fmt::{self, Formatter}; - -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use serde_json::Value; - -pub trait Extensions { - const SCHEMA: &'static [&'static str]; -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct NoExtensions; -impl Extensions for NoExtensions { - const SCHEMA: &'static [&'static str] = &[]; -} - -impl<'de> Deserialize<'de> for NoExtensions { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct Visitor; - - impl<'de> serde::de::Visitor<'de> for Visitor { - type Value = NoExtensions; - fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { - formatter.write_str("NoExtensions struct") - } - fn visit_unit(self) -> Result { - Ok(NoExtensions) - } - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - if map.next_entry::()?.is_some() { - Err(serde::de::Error::custom( - "Unexpected data, NoExtenstions struct has no data", - )) - } else { - Ok(NoExtensions) - } - } - } - deserializer.deserialize_map(Visitor) - } -} - -impl Serialize for NoExtensions { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_unit() - } -} diff --git a/scim-protocol/src/protocol/extensions/generate.rs b/scim-protocol/src/protocol/extensions/generate.rs new file mode 100644 index 0000000..f2acecb --- /dev/null +++ b/scim-protocol/src/protocol/extensions/generate.rs @@ -0,0 +1,140 @@ +#[macro_export] +macro_rules! generate_extension { + (extension $name:ident { + $($var:ident: $type:ty),* + $(,)? + }) => { + #[derive(Debug, Clone, PartialEq, Eq)] + pub struct $name { + $(pub $var: $type),* + } + impl $crate::protocol::Extensions for $name { + const SCHEMA: &'static [&'static str] = &[$( + <$type as $crate::resource::ScimSchema>::SCHEMA + ),*]; + } + + impl<'de> serde::Deserialize<'de> for $name { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[allow(non_camel_case_types)] + enum Field { + $($var),* + // Ignore, + } + #[doc(hidden)] + struct FieldVisitor; + impl<'de> serde::de::Visitor<'de> for FieldVisitor { + type Value = Field; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Formatter::write_str(formatter, "field identifier") + } + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + // TODO: should this be case insensitive check? + $(<$type as $crate::resource::ScimSchema>::SCHEMA => { + Ok(Field::$var) + })* + _ => { + return Err(E::unknown_field(value, FIELDS)); + } + //_ => Ok(Field::Ignore), + } + } + fn visit_bytes(self, value: &[u8]) -> Result + where + E: serde::de::Error, + { + match value { + // TODO: should this be case insensitive check? + $(_ if value == <$type as $crate::resource::ScimSchema>::SCHEMA.as_bytes() => { + Ok(Field::$var) + })* + _ => { + return Err(E::unknown_field(&String::from_utf8_lossy(value), FIELDS)); + } + //_ => Ok(Field::Ignore), + } + } + } + impl<'de> serde::Deserialize<'de> for Field { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + serde::Deserializer::deserialize_identifier(deserializer, FieldVisitor) + } + } + struct Visitor; + impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = $name; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Formatter::write_str(formatter, concat!("struct ", stringify!($name))) + } + + #[inline] + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + $(let mut $var: Option<$type> = None;)* + while let Some(key) = serde::de::MapAccess::next_key::(&mut map)? { + match key { + $(Field::$var => { + if Option::is_some(&$var) { + return Err(::duplicate_field( + <$type as $crate::resource::ScimSchema>::SCHEMA, + )); + } + $var = Some(serde::de::MapAccess::next_value::<$type>(&mut map)?); + })* + /*Field::Ignore => { + let _ = serde::de::MapAccess::next_value::(&mut map)?; + }*/ + } + } + $(let Some($var) = $var else { + return Err(::missing_field( + <$type as $crate::resource::ScimSchema>::SCHEMA, + )); + };)* + + Ok($name { + $($var),* + }) + } + } + const FIELDS: &'static [&'static str] = &[ + $(<$type as $crate::resource::ScimSchema>::SCHEMA),* + ]; + serde::Deserializer::deserialize_struct(deserializer, stringify!($name), FIELDS, Visitor) + } + } + + impl serde::Serialize for $name { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + #[allow(unused_mut)] + let mut state = serde::Serializer::serialize_struct( + serializer, + stringify!($name), + <$name as $crate::protocol::Extensions>::SCHEMA.len(), + )?; + $(serde::ser::SerializeStruct::serialize_field( + &mut state, + <$type as $crate::resource::ScimSchema>::SCHEMA, + &self.$var, + )?;)* + serde::ser::SerializeStruct::end(state) + } + } + }; +} diff --git a/scim-protocol/src/protocol/extensions/mod.rs b/scim-protocol/src/protocol/extensions/mod.rs new file mode 100644 index 0000000..2efd6c3 --- /dev/null +++ b/scim-protocol/src/protocol/extensions/mod.rs @@ -0,0 +1,11 @@ +mod generate; + +use crate::generate_extension; + +pub trait Extensions { + const SCHEMA: &'static [&'static str]; +} + +generate_extension!( + extension NoExtensions {} +); diff --git a/scim-protocol/tests/rfc7643_section8/sec_8_1_minimal_user.rs b/scim-protocol/tests/rfc7643_section8/sec_8_1_minimal_user.rs index f26f421..683c844 100644 --- a/scim-protocol/tests/rfc7643_section8/sec_8_1_minimal_user.rs +++ b/scim-protocol/tests/rfc7643_section8/sec_8_1_minimal_user.rs @@ -12,7 +12,7 @@ fn test_response() { external_id: None, meta: meta(), resource: user(), - extensions: NoExtensions, + extensions: NoExtensions {}, }; assert_eq!(expected, actual); } diff --git a/scim-protocol/tests/rfc7643_section8/sec_8_2_full_user.rs b/scim-protocol/tests/rfc7643_section8/sec_8_2_full_user.rs index 7e55bb6..40ea2e7 100644 --- a/scim-protocol/tests/rfc7643_section8/sec_8_2_full_user.rs +++ b/scim-protocol/tests/rfc7643_section8/sec_8_2_full_user.rs @@ -21,7 +21,7 @@ fn test_response() { external_id: Some("701984".to_string()), meta: meta(id), resource: user(), - extensions: NoExtensions, + extensions: NoExtensions {}, }; assert_eq!(expected, actual); } diff --git a/scim-protocol/tests/rfc7643_section8/sec_8_4_group.rs b/scim-protocol/tests/rfc7643_section8/sec_8_4_group.rs index d1485e9..0e40d63 100644 --- a/scim-protocol/tests/rfc7643_section8/sec_8_4_group.rs +++ b/scim-protocol/tests/rfc7643_section8/sec_8_4_group.rs @@ -47,6 +47,6 @@ fn expected() -> GroupResponse { }, ], }, - extensions: NoExtensions, + extensions: NoExtensions {}, } }