diff --git a/Cargo.lock b/Cargo.lock index dc94b2568..e87dc1ad2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1340,6 +1340,9 @@ name = "nonempty" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "303e8749c804ccd6ca3b428de7fe0d86cb86bc7606bc15291f100fd487960bb8" +dependencies = [ + "serde", +] [[package]] name = "normalize-line-endings" diff --git a/cedar-policy-core/Cargo.toml b/cedar-policy-core/Cargo.toml index 18dabf538..d85e79230 100644 --- a/cedar-policy-core/Cargo.toml +++ b/cedar-policy-core/Cargo.toml @@ -27,7 +27,7 @@ smol_str = { version = "0.3", features = ["serde"] } stacker = "0.1.15" arbitrary = { version = "1", features = ["derive"], optional = true } miette = { version = "7.5.0", features = ["serde"] } -nonempty = "0.10.0" +nonempty = { version = "0.10.0", features = ["serialize"] } educe = "0.6.0" # decimal extension requires regex diff --git a/cedar-policy-core/src/entities.rs b/cedar-policy-core/src/entities.rs index 37fc11da8..edf225c9d 100644 --- a/cedar-policy-core/src/entities.rs +++ b/cedar-policy-core/src/entities.rs @@ -2163,6 +2163,7 @@ mod schema_based_parsing_tests { use crate::extensions::Extensions; use crate::test_utils::*; use cool_asserts::assert_matches; + use nonempty::NonEmpty; use serde_json::json; use smol_str::SmolStr; use std::collections::HashSet; @@ -2253,6 +2254,9 @@ mod schema_based_parsing_tests { /// Mock schema impl for the `Employee` type used in most of these tests struct MockEmployeeDescription; impl EntityTypeDescription for MockEmployeeDescription { + fn enum_entity_eids(&self) -> Option> { + None + } fn entity_type(&self) -> EntityType { EntityType::from(Name::parse_unqualified_name("Employee").expect("valid")) } @@ -3502,6 +3506,9 @@ mod schema_based_parsing_tests { struct MockEmployeeDescription; impl EntityTypeDescription for MockEmployeeDescription { + fn enum_entity_eids(&self) -> Option> { + None + } fn entity_type(&self) -> EntityType { "XYZCorp::Employee".parse().expect("valid") } @@ -3630,6 +3637,109 @@ mod schema_based_parsing_tests { ); }); } + + #[test] + fn enumerated_entities() { + struct MockSchema; + struct StarTypeDescription; + impl EntityTypeDescription for StarTypeDescription { + fn entity_type(&self) -> EntityType { + "Star".parse().unwrap() + } + + fn attr_type(&self, _attr: &str) -> Option { + None + } + + fn tag_type(&self) -> Option { + None + } + + fn required_attrs<'s>(&'s self) -> Box + 's> { + Box::new(std::iter::empty()) + } + + fn allowed_parent_types(&self) -> Arc> { + Arc::new(HashSet::new()) + } + + fn open_attributes(&self) -> bool { + false + } + + fn enum_entity_eids(&self) -> Option> { + Some(nonempty::nonempty![Eid::new("🌎"), Eid::new("🌕"),]) + } + } + impl Schema for MockSchema { + type EntityTypeDescription = StarTypeDescription; + + type ActionEntityIterator = std::iter::Empty>; + + fn entity_type(&self, entity_type: &EntityType) -> Option { + if entity_type == &"Star".parse::().unwrap() { + Some(StarTypeDescription) + } else { + None + } + } + + fn action(&self, _action: &EntityUID) -> Option> { + None + } + + fn entity_types_with_basename<'a>( + &'a self, + basename: &'a UnreservedId, + ) -> Box + 'a> { + if basename == &"Star".parse::().unwrap() { + Box::new(std::iter::once("Star".parse::().unwrap())) + } else { + Box::new(std::iter::empty()) + } + } + + fn action_entities(&self) -> Self::ActionEntityIterator { + std::iter::empty() + } + } + + let eparser = EntityJsonParser::new( + Some(&MockSchema), + Extensions::none(), + TCComputation::ComputeNow, + ); + + assert_matches!( + eparser.from_json_value(serde_json::json!([ + { + "uid": { "type": "Star", "id": "🌎" }, + "attrs": {}, + "parents": [], + } + ])), + Ok(_) + ); + + let entitiesjson = serde_json::json!([ + { + "uid": { "type": "Star", "id": "🪐" }, + "attrs": {}, + "parents": [], + } + ]); + assert_matches!(eparser.from_json_value(entitiesjson.clone()), + Err(e) => { + expect_err( + &entitiesjson, + &miette::Report::new(e), + &ExpectedErrorMessageBuilder::error("entity does not conform to the schema") + .source(r#"entity `Star::"🪐"` is of an enumerated entity type, but `"🪐"` is not declared as a valid eid"#) + .help(r#"valid entity eids: "🌎", "🌕""#) + .build() + ); + }); + } } #[cfg(feature = "protobufs")] diff --git a/cedar-policy-core/src/entities/conformance.rs b/cedar-policy-core/src/entities/conformance.rs index 528390ca7..92092e875 100644 --- a/cedar-policy-core/src/entities/conformance.rs +++ b/cedar-policy-core/src/entities/conformance.rs @@ -17,6 +17,7 @@ use std::collections::BTreeMap; use super::{json::err::TypeMismatchError, EntityTypeDescription, Schema, SchemaType}; +use super::{Eid, EntityUID, Literal}; use crate::ast::{ BorrowedRestrictedExpr, Entity, PartialValue, PartialValueToRestrictedExprError, RestrictedExpr, }; @@ -27,7 +28,7 @@ use smol_str::SmolStr; use thiserror::Error; pub mod err; -use err::{EntitySchemaConformanceError, UnexpectedEntityTypeError}; +use err::{EntitySchemaConformanceError, InvalidEnumEntityError, UnexpectedEntityTypeError}; /// Struct used to check whether entities conform to a schema #[derive(Debug, Clone)] @@ -61,6 +62,8 @@ impl<'a, S: Schema> EntitySchemaConformanceChecker<'a, S> { )); } } else { + validate_euid(self.schema, uid) + .map_err(|e| EntitySchemaConformanceError::InvalidEnumEntity(e.into()))?; let schema_etype = self.schema.entity_type(etype).ok_or_else(|| { let suggested_types = self .schema @@ -120,10 +123,14 @@ impl<'a, S: Schema> EntitySchemaConformanceChecker<'a, S> { } } } + validate_euids_in_partial_value(self.schema, val) + .map_err(|e| EntitySchemaConformanceError::InvalidEnumEntity(e.into()))?; } // For each ancestor that actually appears in `entity`, ensure the // ancestor type is allowed by the schema for ancestor_euid in entity.ancestors() { + validate_euid(self.schema, ancestor_euid) + .map_err(|e| EntitySchemaConformanceError::InvalidEnumEntity(e.into()))?; let ancestor_type = ancestor_euid.entity_type(); if schema_etype.allowed_parent_types().contains(ancestor_type) { // note that `allowed_parent_types()` was transitively @@ -137,11 +144,70 @@ impl<'a, S: Schema> EntitySchemaConformanceChecker<'a, S> { )); } } + + for (_, val) in entity.tags() { + validate_euids_in_partial_value(self.schema, val) + .map_err(|e| EntitySchemaConformanceError::InvalidEnumEntity(e.into()))?; + } } Ok(()) } } +/// Return an [`InvalidEnumEntityError`] if `uid`'s eid is not among valid `choices` +pub fn is_valid_enumerated_entity( + choices: &[Eid], + uid: &EntityUID, +) -> Result<(), InvalidEnumEntityError> { + choices + .iter() + .find(|id| uid.eid() == *id) + .ok_or(InvalidEnumEntityError { + uid: uid.clone(), + choices: choices.to_vec(), + }) + .map(|_| ()) +} + +/// Validate if `euid` is valid, provided that it's of an enumerated type +pub(crate) fn validate_euid( + schema: &impl Schema, + euid: &EntityUID, +) -> Result<(), InvalidEnumEntityError> { + if let Some(desc) = schema.entity_type(euid.entity_type()) { + if let Some(choices) = desc.enum_entity_eids() { + is_valid_enumerated_entity(&Vec::from(choices), euid)?; + } + } + Ok(()) +} + +fn validate_euids_in_subexpressions<'a>( + exprs: impl Iterator, + schema: &impl Schema, +) -> std::result::Result<(), InvalidEnumEntityError> { + exprs + .map(|e| match e.expr_kind() { + ExprKind::Lit(Literal::EntityUID(euid)) => validate_euid(schema, &euid), + _ => Ok(()), + }) + .collect::>() +} + +/// Validate if enumerated entities in `val` are valid +pub fn validate_euids_in_partial_value( + schema: &impl Schema, + val: &PartialValue, +) -> Result<(), InvalidEnumEntityError> { + match val { + PartialValue::Value(val) => validate_euids_in_subexpressions( + RestrictedExpr::from(val.clone()).subexpressions(), + schema, + ), + PartialValue::Residual(e) => validate_euids_in_subexpressions(e.subexpressions(), schema), + } +} + /// Check whether the given `PartialValue` typechecks with the given `SchemaType`. /// If the typecheck passes, return `Ok(())`. /// If the typecheck fails, return an appropriate `Err`. diff --git a/cedar-policy-core/src/entities/conformance/err.rs b/cedar-policy-core/src/entities/conformance/err.rs index 666a611a3..cea1eff5c 100644 --- a/cedar-policy-core/src/entities/conformance/err.rs +++ b/cedar-policy-core/src/entities/conformance/err.rs @@ -15,8 +15,10 @@ */ //! This module cotnains errors around entities not conforming to schemas use super::TypeMismatchError; -use crate::ast::{EntityType, EntityUID}; +use crate::ast::{Eid, EntityType, EntityUID}; use crate::extensions::ExtensionFunctionLookupError; +use crate::impl_diagnostic_from_method_on_field; +use itertools::Itertools; use miette::Diagnostic; use smol_str::SmolStr; use thiserror::Error; @@ -70,6 +72,10 @@ pub enum EntitySchemaConformanceError { #[error(transparent)] #[diagnostic(transparent)] ExtensionFunctionLookup(ExtensionFunctionLookup), + /// Returned when an entity is of an enumerated entity type but has invalid EID + #[error(transparent)] + #[diagnostic(transparent)] + InvalidEnumEntity(#[from] InvalidEnumEntity), } impl EntitySchemaConformanceError { @@ -277,3 +283,45 @@ impl Diagnostic for UnexpectedEntityTypeError { } } } + +/// Returned when an entity is of an enumerated entity type but has invalid EID +// +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Error, Diagnostic)] +#[error(transparent)] +#[diagnostic(transparent)] +pub struct InvalidEnumEntity { + err: InvalidEnumEntityError, +} + +impl From for InvalidEnumEntity { + fn from(value: InvalidEnumEntityError) -> Self { + Self { err: value } + } +} + +/// Returned when an entity is of an enumerated entity type but has invalid EID +#[derive(Debug, Error, Clone, PartialEq, Eq, Hash)] +#[error("entity `{uid}` is of an enumerated entity type, but `\"{}\"` is not declared as a valid eid", uid.eid().escaped())] +pub struct InvalidEnumEntityError { + /// Entity where the error occurred + pub uid: EntityUID, + /// Name of the attribute where the error occurred + pub choices: Vec, +} + +impl Diagnostic for InvalidEnumEntityError { + impl_diagnostic_from_method_on_field!(uid, loc); + + fn help<'a>(&'a self) -> Option> { + Some(Box::new(format!( + "valid entity eids: {}", + self.choices + .iter() + .map(|e| format!("\"{}\"", e.escaped())) + .join(", ") + ))) + } +} diff --git a/cedar-policy-core/src/entities/json/schema.rs b/cedar-policy-core/src/entities/json/schema.rs index e0a08a38b..42cd2a75a 100644 --- a/cedar-policy-core/src/entities/json/schema.rs +++ b/cedar-policy-core/src/entities/json/schema.rs @@ -15,8 +15,9 @@ */ use super::SchemaType; -use crate::ast::{Entity, EntityType, EntityUID}; +use crate::ast::{Eid, Entity, EntityType, EntityUID}; use crate::entities::{Name, UnreservedId}; +use nonempty::NonEmpty; use smol_str::SmolStr; use std::collections::HashSet; use std::sync::Arc; @@ -137,6 +138,10 @@ pub trait EntityTypeDescription { /// May entities with this type have attributes other than those specified /// in the schema fn open_attributes(&self) -> bool; + + /// Return valid EID choices if the entity type is enumerated otherwise + /// return `None` + fn enum_entity_eids(&self) -> Option>; } /// Simple type that implements `EntityTypeDescription` by expecting no @@ -165,6 +170,9 @@ impl EntityTypeDescription for NullEntityTypeDescription { fn open_attributes(&self) -> bool { false } + fn enum_entity_eids(&self) -> Option> { + None + } } impl NullEntityTypeDescription { /// Create a new [`NullEntityTypeDescription`] for the given entity typename diff --git a/cedar-policy-validator/protobuf_schema/Validator.proto b/cedar-policy-validator/protobuf_schema/Validator.proto index fef7e3eac..21eb5cd3b 100644 --- a/cedar-policy-validator/protobuf_schema/Validator.proto +++ b/cedar-policy-validator/protobuf_schema/Validator.proto @@ -41,6 +41,7 @@ message ValidatorEntityType { Attributes attributes = 3; OpenTag open_attributes = 4; Tag tags = 5; + repeated string enums = 6; } message ValidatorActionId { diff --git a/cedar-policy-validator/src/cedar_schema/ast.rs b/cedar-policy-validator/src/cedar_schema/ast.rs index 02d1d632d..9262fa59f 100644 --- a/cedar-policy-validator/src/cedar_schema/ast.rs +++ b/cedar-policy-validator/src/cedar_schema/ast.rs @@ -249,9 +249,24 @@ impl Decl for TypeDecl { } } +#[derive(Debug, Clone)] +pub enum EntityDecl { + Standard(StandardEntityDecl), + Enum(EnumEntityDecl), +} + +impl EntityDecl { + pub fn names(&self) -> impl Iterator> + '_ { + match self { + Self::Enum(d) => d.names.iter(), + Self::Standard(d) => d.names.iter(), + } + } +} + /// Declaration of an entity type #[derive(Debug, Clone)] -pub struct EntityDecl { +pub struct StandardEntityDecl { /// Entity Type Names bound by this declaration. /// More than one name can be bound if they have the same definition, for convenience pub names: NonEmpty>, @@ -263,6 +278,13 @@ pub struct EntityDecl { pub tags: Option>, } +/// Declaration of an entity type +#[derive(Debug, Clone)] +pub struct EnumEntityDecl { + pub names: NonEmpty>, + pub choices: NonEmpty>, +} + /// Type definitions #[derive(Debug, Clone)] pub enum Type { diff --git a/cedar-policy-validator/src/cedar_schema/err.rs b/cedar-policy-validator/src/cedar_schema/err.rs index d5b65b2bf..fbd283251 100644 --- a/cedar-policy-validator/src/cedar_schema/err.rs +++ b/cedar-policy-validator/src/cedar_schema/err.rs @@ -89,6 +89,7 @@ lazy_static! { ("SET", "`Set`"), ("IDENTIFIER", "identifier"), ("TAGS", "`tags`"), + ("ENUM", "`enum`"), ]), impossible_tokens: HashSet::new(), special_identifier_tokens: HashSet::from([ @@ -106,6 +107,7 @@ lazy_static! { "LONG", "STRING", "BOOL", + "ENUM", ]), identifier_sentinel: "IDENTIFIER", first_set_identifier_tokens: HashSet::from(["SET"]), diff --git a/cedar-policy-validator/src/cedar_schema/fmt.rs b/cedar-policy-validator/src/cedar_schema/fmt.rs index cbb182c16..8a108c70f 100644 --- a/cedar-policy-validator/src/cedar_schema/fmt.rs +++ b/cedar-policy-validator/src/cedar_schema/fmt.rs @@ -107,6 +107,22 @@ fn fmt_non_empty_slice( } impl Display for json_schema::EntityType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + json_schema::EntityTypeKind::Standard(ty) => ty.fmt(f), + json_schema::EntityTypeKind::Enum { choices } => write!( + f, + "[{}]", + choices + .iter() + .map(|e| format!("\"{}\"", e.escape_debug())) + .join(", ") + ), + } + } +} + +impl Display for json_schema::StandardEntityType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(non_empty) = self.member_of_types.split_first() { write!(f, " in ")?; diff --git a/cedar-policy-validator/src/cedar_schema/grammar.lalrpop b/cedar-policy-validator/src/cedar_schema/grammar.lalrpop index 99af8ab18..8d56b616a 100644 --- a/cedar-policy-validator/src/cedar_schema/grammar.lalrpop +++ b/cedar-policy-validator/src/cedar_schema/grammar.lalrpop @@ -24,6 +24,8 @@ use smol_str::ToSmolStr; use crate::cedar_schema::ast::{ Path, EntityDecl, + StandardEntityDecl, + EnumEntityDecl, Declaration, Namespace, Schema as ASchema, @@ -76,6 +78,7 @@ match { "Long" => LONG, "String" => STRING, "Bool" => BOOL, + "enum" => ENUM, // data input r"[_a-zA-Z][_a-zA-Z0-9]*" => IDENTIFIER, @@ -128,12 +131,16 @@ Decl: Node = { // Entity := 'entity' Idents ['in' EntTypes] [['='] RecType] ';' Entity: Node = { ENTITY )?> "}")?> )?> ";" - => Node::with_source_loc(Declaration::Entity(EntityDecl { + => Node::with_source_loc(Declaration::Entity(EntityDecl::Standard(StandardEntityDecl { names: ets, member_of_types: ps.unwrap_or_default(), attrs: Node::with_source_loc(ds.map(|ds| ds.unwrap_or_default()).unwrap_or_default(), Loc::new(l2..r2, Arc::clone(src))), tags: ts, - }), Loc::new(l1..r1, Arc::clone(src))), + })), Loc::new(l1..r1, Arc::clone(src))), + ENTITY ENUM "[" "]" ";" => Node::with_source_loc(Declaration::Entity(EntityDecl::Enum(EnumEntityDecl { + names: ets, + choices, + })), Loc::new(l..r, Arc::clone(src))), } // Action := 'action' Names ['in' QualNameOrNames] @@ -262,6 +269,8 @@ AnyIdent: Node = { => Node::with_source_loc("type".parse().unwrap(), Loc::new(l..r, Arc::clone(src))), IN => Node::with_source_loc("in".parse().unwrap(), Loc::new(l..r, Arc::clone(src))), + ENUM + => Node::with_source_loc("enum".parse().unwrap(), Loc::new(l..r, Arc::clone(src))), => Node::with_source_loc(i.parse().unwrap(), Loc::new(l..r, Arc::clone(src))), } @@ -296,9 +305,9 @@ PathInline: Path = { NonEmptyComma: NonEmpty = { => NonEmpty::singleton(e), - ",")+> => { - let mut all = NonEmpty::singleton(e); - all.append(&mut es); + ",")+> => { + let mut all = NonEmpty::from_vec(es).unwrap(); + all.push(e); all }, } @@ -312,6 +321,9 @@ Names: NonEmpty> = NonEmptyComma; // Qualnames := Qualname {',' Qualname } QualNames : NonEmpty> = NonEmptyComma; +// STRs := STR {',' STR} +STRs: NonEmpty> = NonEmptyComma; + PrincipalOrResource: Node = { PRINCIPAL => Node::with_source_loc(PR::Principal, Loc::new(l..r, Arc::clone(src))), RESOURCE => Node::with_source_loc(PR::Resource, Loc::new(l..r, Arc::clone(src))), diff --git a/cedar-policy-validator/src/cedar_schema/test.rs b/cedar-policy-validator/src/cedar_schema/test.rs index 6607248b5..04820443d 100644 --- a/cedar-policy-validator/src/cedar_schema/test.rs +++ b/cedar-policy-validator/src/cedar_schema/test.rs @@ -40,7 +40,7 @@ mod demo_tests { ast::PR, err::{ToJsonSchemaError, NO_PR_HELP_MSG}, }, - json_schema, + json_schema::{self, EntityType, EntityTypeKind}, schema::test::utils::collect_warnings, CedarSchemaError, RawName, }; @@ -422,15 +422,14 @@ namespace Baz {action "Foo" appliesTo { let namespace = json_schema::NamespaceDefinition::new( [( "a".parse().unwrap(), - json_schema::EntityType:: { + json_schema::StandardEntityType:: { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), )], - [( + BTreeMap::from([( "j".to_smolstr(), json_schema::ActionType:: { attributes: None, @@ -443,7 +442,7 @@ namespace Baz {action "Foo" appliesTo { annotations: Annotations::new(), loc: None, }, - )], + )]), ); let fragment = json_schema::Fragment(BTreeMap::from([(None, namespace)])); let src = fragment.to_cedarschema().unwrap(); @@ -527,30 +526,30 @@ namespace Baz {action "Foo" appliesTo { .get(&Some("GitHub".parse().unwrap())) .expect("`Github` name space did not exist"); // User - let user = github + assert_matches!(github .entity_types .get(&"User".parse().unwrap()) - .expect("No `User`"); + .expect("No `User`"), EntityType { kind: EntityTypeKind::Standard(user), ..} => { assert_empty_record(user); assert_eq!( &user.member_of_types, &vec!["UserGroup".parse().unwrap(), "Team".parse().unwrap()] - ); + );}); // UserGroup - let usergroup = &github + assert_matches!(github .entity_types .get(&"UserGroup".parse().unwrap()) - .expect("No `UserGroup`"); + .expect("No `UserGroup`"), EntityType { kind: EntityTypeKind::Standard(usergroup), ..} => { assert_empty_record(usergroup); assert_eq!( &usergroup.member_of_types, &vec!["UserGroup".parse().unwrap()] - ); + );}); // Repository - let repo = github + assert_matches!(github .entity_types .get(&"Repository".parse().unwrap()) - .expect("No `Repository`"); + .expect("No `Repository`"), EntityType {kind: EntityTypeKind::Standard(repo), ..} => { assert!(repo.member_of_types.is_empty()); let groups = ["readers", "writers", "triagers", "admins", "maintainers"]; for group in groups { @@ -565,11 +564,11 @@ namespace Baz {action "Foo" appliesTo { let attribute = attributes.get(group).expect("No attribute `{group}`"); assert_has_type(attribute, &expected); }); - } - let issue = github + }}); + assert_matches!(github .entity_types .get(&"Issue".parse().unwrap()) - .expect("No `Issue`"); + .expect("No `Issue`"), EntityType {kind: EntityTypeKind::Standard(issue), .. } => { assert!(issue.member_of_types.is_empty()); assert_matches!(&issue.shape, json_schema::AttributesOrContext(json_schema::Type::Type { ty: json_schema::TypeVariant::Record(json_schema::RecordType { attributes, @@ -589,11 +588,11 @@ namespace Baz {action "Foo" appliesTo { type_name: "User".parse().unwrap(), }, loc: None }, ); - }); - let org = github + });}); + assert_matches!(github .entity_types .get(&"Org".parse().unwrap()) - .expect("No `Org`"); + .expect("No `Org`"), EntityType { kind: EntityTypeKind::Standard(org), .. } => { assert!(org.member_of_types.is_empty()); let groups = ["members", "owners", "memberOfTypes"]; for group in groups { @@ -607,7 +606,7 @@ namespace Baz {action "Foo" appliesTo { let attribute = attributes.get(group).expect("No attribute `{group}`"); assert_has_type(attribute, &expected); }); - } + }}); } #[track_caller] @@ -620,7 +619,7 @@ namespace Baz {action "Foo" appliesTo { } #[track_caller] - fn assert_empty_record(etyp: &json_schema::EntityType) { + fn assert_empty_record(etyp: &json_schema::StandardEntityType) { assert!(etyp.shape.is_empty_record()); } @@ -655,10 +654,10 @@ namespace Baz {action "Foo" appliesTo { .0 .get(&Some("DocCloud".parse().unwrap())) .expect("No `DocCloud` namespace"); - let user = doccloud + assert_matches!(doccloud .entity_types .get(&"User".parse().unwrap()) - .expect("No `User`"); + .expect("No `User`"), EntityType {kind: EntityTypeKind::Standard(user), ..} => { assert_eq!(&user.member_of_types, &vec!["Group".parse().unwrap()]); assert_matches!(&user.shape, json_schema::AttributesOrContext(json_schema::Type::Type { ty: json_schema::TypeVariant::Record(json_schema::RecordType { attributes, @@ -678,11 +677,11 @@ namespace Baz {action "Foo" appliesTo { }, loc: None }), // we do expect a `loc`, but `assert_has_type()` will ignore the mismatch in presence of `loc`. We have separate tests for the correctness of `loc`s coming from the Cedar schema syntax in a test module called `preserves_source_locations`. }, loc: None }, ); - }); - let group = doccloud + });}); + assert_matches!(doccloud .entity_types .get(&"Group".parse().unwrap()) - .expect("No `Group`"); + .expect("No `Group`"), EntityType { kind: EntityTypeKind::Standard(group), .. } => { assert_eq!( &group.member_of_types, &vec!["DocumentShare".parse().unwrap()] @@ -697,11 +696,11 @@ namespace Baz {action "Foo" appliesTo { type_name: "User".parse().unwrap(), }, loc: None }, ); - }); - let document = doccloud + });}); + assert_matches!(doccloud .entity_types .get(&"Document".parse().unwrap()) - .expect("No `Group`"); + .expect("No `Group`"), EntityType { kind: EntityTypeKind::Standard(document), ..} => { assert!(document.member_of_types.is_empty()); assert_matches!(&document.shape, json_schema::AttributesOrContext(json_schema::Type::Type { ty: json_schema::TypeVariant::Record(json_schema::RecordType { attributes, @@ -743,30 +742,33 @@ namespace Baz {action "Foo" appliesTo { type_name: "DocumentShare".parse().unwrap(), }, loc: None }, ); - }); - let document_share = doccloud + });}); + assert_matches!(doccloud .entity_types .get(&"DocumentShare".parse().unwrap()) - .expect("No `DocumentShare`"); + .expect("No `DocumentShare`"), EntityType { kind: EntityTypeKind::Standard(document_share), ..} => { assert!(document_share.member_of_types.is_empty()); assert_empty_record(document_share); + }); - let public = doccloud - .entity_types - .get(&"Public".parse().unwrap()) - .expect("No `Public`"); - assert_eq!( - &public.member_of_types, - &vec!["DocumentShare".parse().unwrap()] - ); - assert_empty_record(public); + assert_matches!(doccloud + .entity_types + .get(&"Public".parse().unwrap()) + .expect("No `Public`"), EntityType { kind: EntityTypeKind::Standard(public), ..} => { + assert_eq!( + &public.member_of_types, + &vec!["DocumentShare".parse().unwrap()] + ); + assert_empty_record(public); + }); - let drive = doccloud + assert_matches!(doccloud .entity_types .get(&"Drive".parse().unwrap()) - .expect("No `Drive`"); + .expect("No `Drive`"), EntityType { kind: EntityTypeKind::Standard(drive), ..} => { assert!(drive.member_of_types.is_empty()); assert_empty_record(drive); + }); } #[test] @@ -871,21 +873,21 @@ namespace Baz {action "Foo" appliesTo { json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); assert_eq!(warnings.collect::>(), vec![]); let service = fragment.0.get(&Some("Service".parse().unwrap())).unwrap(); - let resource = &service + assert_matches!(service .entity_types .get(&"Resource".parse().unwrap()) - .unwrap(); + .unwrap(), EntityType { kind: EntityTypeKind::Standard(resource), ..} => { assert_matches!(&resource.shape, json_schema::AttributesOrContext(json_schema::Type::Type { ty: json_schema::TypeVariant::Record(json_schema::RecordType { attributes, additional_attributes: false, - }), loc: Some(_) }) => { + }), ..}) => { assert_matches!(attributes.get("tag"), Some(json_schema::TypeOfAttribute { ty, required: true, .. }) => { assert_matches!(&ty, json_schema::Type::Type { ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(loc) } => { assert_eq!(type_name, &"AWS::Tag".parse().unwrap()); assert_matches!(loc.snippet(), Some("AWS::Tag")); }); }); - }); + });}); } #[test] @@ -918,7 +920,10 @@ namespace Baz {action "Foo" appliesTo { } mod parser_tests { - use crate::cedar_schema::parser::parse_schema; + use crate::cedar_schema::{ + ast::{Annotated, Declaration, EntityDecl, EnumEntityDecl, Namespace}, + parser::parse_schema, + }; use cool_asserts::assert_matches; #[test] @@ -1141,6 +1146,65 @@ mod parser_tests { ); assert_matches!(res, Ok(_)); } + + #[test] + fn enumerated_entity_types() { + let res = parse_schema( + r#" + entity Application enum [ "TinyTodo" ]; + entity User in [ Application ]; + "#, + ); + assert_matches!(res, Ok(ns) => { + assert_matches!(&ns, [Annotated {data: Namespace { decls, ..}, ..}, ..] => { + assert_matches!(decls, [Annotated { data, .. }] => { + assert_matches!(&data.node, Declaration::Entity(EntityDecl::Enum(EnumEntityDecl { choices, ..})) => { + assert_eq!(choices.clone().map(|n| n.node), nonempty::NonEmpty::singleton("TinyTodo".into())); + }); + }); + }); + }); + let res = parse_schema( + r#" + entity Application enum [ "TinyTodo", "GitHub", "DocumentCloud" ]; + entity User in [ Application ]; + "#, + ); + assert_matches!(res, Ok(ns) => { + assert_matches!(&ns, [Annotated {data: Namespace { decls, ..}, ..}, ..] => { + assert_matches!(decls, [Annotated { data, .. }] => { + assert_matches!(&data.node, Declaration::Entity(EntityDecl::Enum(EnumEntityDecl { choices, ..})) => { + assert_eq!(choices.clone().map(|n| n.node), nonempty::nonempty!["TinyTodo".into(), "GitHub".into(), "DocumentCloud".into()]); + }); + }); + }); + }); + let res = parse_schema( + r#" + entity enum enum [ "enum" ]; + "#, + ); + assert_matches!(res, Ok(ns) => { + assert_matches!(&ns, [Annotated {data: Namespace { decls, ..}, ..}] => { + assert_matches!(decls, [Annotated { data, .. }] => { + assert_matches!(&data.node, Declaration::Entity(EntityDecl::Enum(EnumEntityDecl { choices, ..})) => { + assert_eq!(choices.clone().map(|n| n.node), nonempty::NonEmpty::singleton("enum".into())); + }); + }); + }); + }); + + let res = parse_schema( + r#" + entity Application enum [ ]; + entity User in [ Application ]; + "#, + ); + // Maybe we want a better error message here + assert_matches!(res, Err(errs) => { + assert_eq!(errs.to_string(), "unexpected token `]`"); + }); + } } mod translator_tests { @@ -1150,6 +1214,7 @@ mod translator_tests { use cedar_policy_core::FromNormalizedStr; use cool_asserts::assert_matches; + use crate::json_schema::{EntityType, EntityTypeKind}; use crate::{ cedar_schema::{ err::ToJsonSchemaError, parser::parse_schema, @@ -1407,11 +1472,11 @@ mod translator_tests { let (frag, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let demo = frag.0.get(&Some("Demo".parse().unwrap())).unwrap(); - let user = &demo.entity_types.get(&"User".parse().unwrap()).unwrap(); - assert_matches!(&user.shape, json_schema::AttributesOrContext(json_schema::Type::Type { ty: json_schema::TypeVariant::Record(json_schema::RecordType { + assert_matches!(demo.entity_types.get(&"User".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(user), ..} => { + assert_matches!(&user.shape, json_schema::AttributesOrContext(json_schema::Type::Type{ ty: json_schema::TypeVariant::Record(json_schema::RecordType { attributes, additional_attributes: false, - }), loc: Some(_) }) => { + }), ..}) => { assert_matches!(attributes.get("name"), Some(json_schema::TypeOfAttribute { ty, required: true, .. }) => { assert_matches!(ty, json_schema::Type::Type { ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(_) } => { assert_eq!(&type_name.to_string(), "id"); @@ -1422,7 +1487,7 @@ mod translator_tests { assert_eq!(&type_name.to_string(), "email_address"); }); }); - }); + });}); assert_matches!(ValidatorSchema::try_from(frag), Err(e) => { expect_err( src, @@ -1486,8 +1551,8 @@ mod translator_tests { validator_schema .get_entity_type(&"A::B".parse().unwrap()) .unwrap() - .attributes - .attrs["foo"] + .attr("foo") + .unwrap() .attr_type, Type::EntityOrRecord(EntityRecordKind::Entity(EntityLUB::single_entity( "X::Z".parse().unwrap() @@ -1505,8 +1570,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { assert_eq!(foo.member_of_types, vec!["namespace".parse().unwrap()]); + }); } #[test] @@ -1536,8 +1602,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { assert_eq!(foo.member_of_types, vec!["Set".parse().unwrap()]); + }); } #[test] @@ -1550,8 +1617,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { assert_eq!(foo.member_of_types, vec!["appliesTo".parse().unwrap()]); + }); } #[test] @@ -1564,8 +1632,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { assert_eq!(foo.member_of_types, vec!["principal".parse().unwrap()]); + }); } #[test] @@ -1578,8 +1647,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { assert_eq!(foo.member_of_types, vec!["resource".parse().unwrap()]); + }); } #[test] @@ -1592,8 +1662,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { assert_eq!(foo.member_of_types, vec!["action".parse().unwrap()]); + }); } #[test] @@ -1606,8 +1677,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); - assert_eq!(foo.member_of_types, vec!["context".parse().unwrap()]); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { + assert_eq!(foo.member_of_types, vec!["context".parse().unwrap()]); + }); } #[test] @@ -1620,8 +1692,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); - assert_eq!(foo.member_of_types, vec!["attributes".parse().unwrap()]); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { + assert_eq!(foo.member_of_types, vec!["attributes".parse().unwrap()]); + }); } #[test] @@ -1634,8 +1707,9 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); - assert_eq!(foo.member_of_types, vec!["Bool".parse().unwrap()]); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { + assert_eq!(foo.member_of_types, vec!["Bool".parse().unwrap()]); + }); } #[test] @@ -1648,8 +1722,8 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); - assert_eq!(foo.member_of_types, vec!["Long".parse().unwrap()]); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { assert_eq!(foo.member_of_types, vec!["Long".parse().unwrap()]); + }); } #[test] @@ -1662,8 +1736,8 @@ mod translator_tests { let (schema, _) = json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); let ns = schema.0.get(&None).unwrap(); - let foo = ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(); - assert_eq!(foo.member_of_types, vec!["String".parse().unwrap()]); + assert_matches!(ns.entity_types.get(&"Foo".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(foo), ..} => { assert_eq!(foo.member_of_types, vec!["String".parse().unwrap()]); + }); } #[test] @@ -2250,6 +2324,31 @@ mod translator_tests { }), ); } + + #[test] + fn enumerated_entity_types() { + let src = r#" + entity Fruits enum ["🍍", "🥭", "🥝"]; + "#; + + let (schema, _) = + json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); + let ns = schema.0.get(&None).unwrap(); + assert_matches!(ns.entity_types.get(&"Fruits".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Enum { choices }, ..} => { + assert_eq!(Vec::from(choices.clone()), ["🍍", "🥭", "🥝"]); + }); + + let src = r#" + entity enum enum ["enum"]; + "#; + + let (schema, _) = + json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available()).unwrap(); + let ns = schema.0.get(&None).unwrap(); + assert_matches!(ns.entity_types.get(&"enum".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Enum { choices }, ..} => { + assert_eq!(Vec::from(choices.clone()), ["enum"]); + }); + } } mod common_type_references { @@ -2565,7 +2664,7 @@ mod common_type_references { /// Tests involving entity tags (RFC 82) mod entity_tags { - use crate::json_schema; + use crate::json_schema::{self, EntityType, EntityTypeKind}; use crate::schema::test::utils::collect_warnings; use cedar_policy_core::extensions::Extensions; use cool_asserts::assert_matches; @@ -2575,38 +2674,39 @@ mod entity_tags { let src = "entity E;"; assert_matches!(collect_warnings(json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available())), Ok((frag, warnings)) => { assert!(warnings.is_empty()); - let entity_type = frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(); + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(entity_type), ..} => { assert_matches!(&entity_type.tags, None); + }); }); let src = "entity E tags String;"; assert_matches!(collect_warnings(json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available())), Ok((frag, warnings)) => { assert!(warnings.is_empty()); - let entity_type = frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(); - assert_matches!(&entity_type.tags, Some(json_schema::Type::Type { ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(loc) }) => { + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(entity_type), ..} => { + assert_matches!(&entity_type.tags, Some(json_schema::Type::Type{ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(loc)}) => { assert_eq!(&format!("{type_name}"), "String"); assert_matches!(loc.snippet(), Some("String")); }); - }); + });}); let src = "entity E tags Set;"; assert_matches!(collect_warnings(json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available())), Ok((frag, warnings)) => { assert!(warnings.is_empty()); - let entity_type = frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(); - assert_matches!(&entity_type.tags, Some(json_schema::Type::Type { ty: json_schema::TypeVariant::Set { element }, loc: Some(set_loc) }) => { - assert_matches!(&**element, json_schema::Type::Type { ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(elt_loc) } => { + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(entity_type), ..} => { + assert_matches!(&entity_type.tags, Some(json_schema::Type::Type{ ty: json_schema::TypeVariant::Set { element }, loc: Some(set_loc)}) => { + assert_matches!(&**element, json_schema::Type::Type{ ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(elt_loc)} => { assert_eq!(&format!("{type_name}"), "String"); assert_matches!(set_loc.snippet(), Some("Set")); assert_matches!(elt_loc.snippet(), Some("String")); }); }); - }); + });}); let src = "entity E { foo: String } tags { foo: String };"; assert_matches!(collect_warnings(json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available())), Ok((frag, warnings)) => { assert!(warnings.is_empty()); - let entity_type = frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(); - assert_matches!(&entity_type.tags, Some(json_schema::Type::Type { ty: json_schema::TypeVariant::Record(rty), loc: Some(rec_loc) }) => { + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(entity_type), ..} => { + assert_matches!(&entity_type.tags, Some(json_schema::Type::Type{ ty: json_schema::TypeVariant::Record(rty), loc: Some(rec_loc)}) => { assert_matches!(rty.attributes.get("foo"), Some(json_schema::TypeOfAttribute { ty, required, .. }) => { assert_matches!(ty, json_schema::Type::Type { ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(foo_loc) } => { assert_eq!(&format!("{type_name}"), "String"); @@ -2616,27 +2716,27 @@ mod entity_tags { assert!(*required); }); }); - }); + });}); let src = "type T = String; entity E tags T;"; assert_matches!(collect_warnings(json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available())), Ok((frag, warnings)) => { assert!(warnings.is_empty()); - let entity_type = frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(); - assert_matches!(&entity_type.tags, Some(json_schema::Type::Type { ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(loc) }) => { + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(entity_type), ..} => { + assert_matches!(&entity_type.tags, Some(json_schema::Type::Type{ ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(loc)}) => { assert_eq!(&format!("{type_name}"), "T"); assert_matches!(loc.snippet(), Some("T")); }); - }); + });}); let src = "entity E tags E;"; assert_matches!(collect_warnings(json_schema::Fragment::from_cedarschema_str(src, Extensions::all_available())), Ok((frag, warnings)) => { assert!(warnings.is_empty()); - let entity_type = frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(); - assert_matches!(&entity_type.tags, Some(json_schema::Type::Type { ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(loc) }) => { + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"E".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(entity_type), ..} => { + assert_matches!(&entity_type.tags, Some(json_schema::Type::Type{ ty: json_schema::TypeVariant::EntityOrCommon { type_name }, loc: Some(loc)}) => { assert_eq!(&format!("{type_name}"), "E"); assert_matches!(loc.snippet(), Some("E")); }); - }); + });}); } } diff --git a/cedar-policy-validator/src/cedar_schema/to_json_schema.rs b/cedar-policy-validator/src/cedar_schema/to_json_schema.rs index ac40274fd..858cabf05 100644 --- a/cedar-policy-validator/src/cedar_schema/to_json_schema.rs +++ b/cedar-policy-validator/src/cedar_schema/to_json_schema.rs @@ -362,25 +362,33 @@ fn convert_entity_decl( impl Iterator)>, ToJsonSchemaErrors, > { - // First build up the defined entity type + let names: Vec> = e.data.node.names().cloned().collect(); let etype = json_schema::EntityType { - member_of_types: e - .data - .node - .member_of_types - .into_iter() - .map(RawName::from) - .collect(), - shape: convert_attr_decls(e.data.node.attrs), - tags: e.data.node.tags.map(cedar_type_to_json_type), + kind: match e.data.node { + EntityDecl::Enum(d) => json_schema::EntityTypeKind::Enum { + choices: d.choices.map(|n| n.node), + }, + EntityDecl::Standard(d) => { + // First build up the defined entity type + json_schema::EntityTypeKind::Standard(json_schema::StandardEntityType { + member_of_types: d.member_of_types.into_iter().map(RawName::from).collect(), + shape: convert_attr_decls(d.attrs), + tags: d.tags.map(cedar_type_to_json_type), + }) + } + }, annotations: e.annotations.into(), - loc: Some(e.data.loc), + loc: Some(e.data.loc.clone()), }; // Then map over all of the bound names - collect_all_errors(e.data.node.names.into_iter().map( - move |name| -> Result<_, ToJsonSchemaErrors> { Ok((convert_id(name)?, etype.clone())) }, - )) + collect_all_errors( + names + .into_iter() + .map(move |name| -> Result<_, ToJsonSchemaErrors> { + Ok((convert_id(name)?, etype.clone())) + }), + ) } /// Create a [`json_schema::AttributesOrContext`] from a series of `AttrDecl`s @@ -480,7 +488,7 @@ impl NamespaceRecord { let entities = collect_decls( entities .into_iter() - .flat_map(|decl| decl.names.clone()) + .flat_map(|decl| decl.names().cloned()) .map(extract_name), )?; // Ensure no duplicate actions @@ -673,6 +681,7 @@ fn into_partition_decls( mod preserves_source_locations { use super::*; use cool_asserts::assert_matches; + use json_schema::{EntityType, EntityTypeKind}; #[test] fn entity_action_and_common_type_decls() { @@ -792,10 +801,10 @@ mod preserves_source_locations { .get(&Some(Name::parse_unqualified_name("NS").unwrap())) .expect("couldn't find namespace NS"); - let entityC = ns + assert_matches!(ns .entity_types .get(&"C".parse().unwrap()) - .expect("couldn't find entity C"); + .expect("couldn't find entity C"), EntityType { kind: EntityTypeKind::Standard(entityC), ..} => { assert_matches!(entityC.member_of_types.first().unwrap().loc(), Some(loc) => { assert_matches!(loc.snippet(), Some("A")); }); @@ -830,7 +839,7 @@ mod preserves_source_locations { assert_matches!(loc.snippet(), Some("B")); }); }); - }); + });}); let ctypeAA = ns .common_types diff --git a/cedar-policy-validator/src/coreschema.rs b/cedar-policy-validator/src/coreschema.rs index 4c7fc342b..e5c47250e 100644 --- a/cedar-policy-validator/src/coreschema.rs +++ b/cedar-policy-validator/src/coreschema.rs @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -use crate::{ValidatorActionId, ValidatorEntityType, ValidatorSchema}; -use cedar_policy_core::ast::{EntityType, EntityUID}; +use crate::{ValidatorActionId, ValidatorEntityType, ValidatorEntityTypeKind, ValidatorSchema}; +use cedar_policy_core::ast::{Eid, EntityType, EntityUID}; +use cedar_policy_core::entities::conformance::err::InvalidEnumEntityError; +use cedar_policy_core::entities::conformance::{ + is_valid_enumerated_entity, validate_euids_in_partial_value, +}; use cedar_policy_core::extensions::{ExtensionFunctionLookupError, Extensions}; use cedar_policy_core::{ast, entities}; use miette::Diagnostic; +use nonempty::NonEmpty; use smol_str::SmolStr; use std::collections::hash_map::Values; use std::collections::HashSet; @@ -106,6 +111,13 @@ impl EntityTypeDescription { } impl entities::EntityTypeDescription for EntityTypeDescription { + fn enum_entity_eids(&self) -> Option> { + match &self.validator_type.kind { + ValidatorEntityTypeKind::Enum(choices) => Some(choices.clone().map(|s| Eid::new(s))), + _ => None, + } + } + fn entity_type(&self) -> ast::EntityType { self.core_type.clone() } @@ -143,10 +155,10 @@ impl entities::EntityTypeDescription for EntityTypeDescription { fn required_attrs<'s>(&'s self) -> Box + 's> { Box::new( self.validator_type - .attributes - .iter() + .attributes() + .into_iter() .filter(|(_, ty)| ty.is_required) - .map(|(attr, _)| attr.clone()), + .map(|(attr, _)| attr), ) } @@ -155,7 +167,7 @@ impl entities::EntityTypeDescription for EntityTypeDescription { } fn open_attributes(&self) -> bool { - self.validator_type.open_attributes.is_open() + self.validator_type.open_attributes().is_open() } } @@ -170,7 +182,18 @@ impl ast::RequestSchema for ValidatorSchema { // first check that principal and resource are of types that exist in // the schema, we can do this check even if action is unknown. if let Some(principal_type) = request.principal().get_type() { - if self.get_entity_type(principal_type).is_none() { + if let Some(et) = self.get_entity_type(principal_type) { + if let Some(euid) = request.principal().uid() { + if let ValidatorEntityType { + kind: ValidatorEntityTypeKind::Enum(choices), + .. + } = et + { + is_valid_enumerated_entity(&Vec::from(choices.clone().map(Eid::new)), euid) + .map_err(Self::Error::from)?; + } + } + } else { return Err(request_validation_errors::UndeclaredPrincipalTypeError { principal_ty: principal_type.clone(), } @@ -178,7 +201,18 @@ impl ast::RequestSchema for ValidatorSchema { } } if let Some(resource_type) = request.resource().get_type() { - if self.get_entity_type(resource_type).is_none() { + if let Some(et) = self.get_entity_type(resource_type) { + if let Some(euid) = request.resource().uid() { + if let ValidatorEntityType { + kind: ValidatorEntityTypeKind::Enum(choices), + .. + } = et + { + is_valid_enumerated_entity(&Vec::from(choices.clone().map(Eid::new)), euid) + .map_err(Self::Error::from)?; + } + } + } else { return Err(request_validation_errors::UndeclaredResourceTypeError { resource_ty: resource_type.clone(), } @@ -201,6 +235,11 @@ impl ast::RequestSchema for ValidatorSchema { validator_action_id.check_resource_type(principal_type, action)?; } if let Some(context) = request.context() { + validate_euids_in_partial_value( + &CoreSchema::new(&self), + &context.clone().into(), + ) + .map_err(|err| RequestValidationError::InvalidEnumEntity(err))?; let expected_context_ty = validator_action_id.context_type(); if !expected_context_ty .typecheck_partial_value(&context.clone().into(), extensions) @@ -308,6 +347,11 @@ pub enum RequestValidationError { #[error("context is not valid: {0}")] #[diagnostic(transparent)] TypeOfContext(ExtensionFunctionLookupError), + /// Error when a principal or resource entity is of an enumerated entity + /// type but has an invalid EID + #[error(transparent)] + #[diagnostic(transparent)] + InvalidEnumEntity(#[from] InvalidEnumEntityError), } /// Errors related to validation @@ -564,10 +608,29 @@ pub fn context_schema_for_action( #[cfg(test)] mod test { use super::*; + use ast::{Context, Value}; use cedar_policy_core::test_utils::{expect_err, ExpectedErrorMessageBuilder}; use cool_asserts::assert_matches; use serde_json::json; + #[track_caller] + fn schema_with_enums() -> ValidatorSchema { + let src = r#" + entity Fruit enum ["🍉", "🍓", "🍒"]; + entity People; + action "eat" appliesTo { + principal: [People], + resource: [Fruit], + context: { + fruit?: Fruit, + } + }; + "#; + ValidatorSchema::from_cedarschema_str(src, Extensions::none()) + .expect("should be a valid schema") + .0 + } + fn schema() -> ValidatorSchema { let src = json!( { "": { @@ -1040,4 +1103,64 @@ mod test { } ); } + + #[test] + fn enumerated_entity_type() { + assert_matches!( + ast::Request::new( + ( + ast::EntityUID::with_eid_and_type("People", "😋").unwrap(), + None + ), + ( + ast::EntityUID::with_eid_and_type("Action", "eat").unwrap(), + None + ), + ( + ast::EntityUID::with_eid_and_type("Fruit", "🍉").unwrap(), + None + ), + Context::empty(), + Some(&schema_with_enums()), + Extensions::none(), + ), + Ok(_) + ); + assert_matches!( + ast::Request::new( + (ast::EntityUID::with_eid_and_type("People", "🤔").unwrap(), None), + (ast::EntityUID::with_eid_and_type("Action", "eat").unwrap(), None), + (ast::EntityUID::with_eid_and_type("Fruit", "🥝").unwrap(), None), + Context::empty(), + Some(&schema_with_enums()), + Extensions::none(), + ), + Err(e) => { + expect_err( + "", + &miette::Report::new(e), + &ExpectedErrorMessageBuilder::error(r#"entity `Fruit::"🥝"` is of an enumerated entity type, but `"🥝"` is not declared as a valid eid"#).help(r#"valid entity eids: "🍉", "🍓", "🍒""#) + .build(), + ); + } + ); + assert_matches!( + ast::Request::new( + (ast::EntityUID::with_eid_and_type("People", "🤔").unwrap(), None), + (ast::EntityUID::with_eid_and_type("Action", "eat").unwrap(), None), + (ast::EntityUID::with_eid_and_type("Fruit", "🍉").unwrap(), None), + Context::from_pairs(std::iter::once(("fruit".into(), (Value::from(ast::EntityUID::with_eid_and_type("Fruit", "🥭").unwrap())).into())), Extensions::none()).expect("should be a valid context"), + Some(&schema_with_enums()), + Extensions::none(), + ), + Err(e) => { + expect_err( + "", + &miette::Report::new(e), + &ExpectedErrorMessageBuilder::error(r#"entity `Fruit::"🥭"` is of an enumerated entity type, but `"🥭"` is not declared as a valid eid"#).help(r#"valid entity eids: "🍉", "🍓", "🍒""#) + .build(), + ); + } + ); + } } diff --git a/cedar-policy-validator/src/diagnostics.rs b/cedar-policy-validator/src/diagnostics.rs index 7d11f3ebe..21e8b09b3 100644 --- a/cedar-policy-validator/src/diagnostics.rs +++ b/cedar-policy-validator/src/diagnostics.rs @@ -17,6 +17,7 @@ //! This module contains the diagnostics (i.e., errors and warnings) that are //! returned by the validator. +use cedar_policy_core::entities::conformance::err::InvalidEnumEntityError; use miette::Diagnostic; use thiserror::Error; use validation_errors::UnrecognizedActionIdHelp; @@ -166,6 +167,11 @@ pub enum ValidationError { #[error(transparent)] #[diagnostic(transparent)] InternalInvariantViolation(#[from] validation_errors::InternalInvariantViolation), + /// Returned when an entity literal is of an enumerated entity type but has + /// undeclared UID + #[error(transparent)] + #[diagnostic(transparent)] + InvalidEnumEntity(#[from] validation_errors::InvalidEnumEntity), #[cfg(feature = "level-validate")] /// If a entity dereference level was provided, the policies cannot deref /// more than `level` hops away from PARX @@ -398,6 +404,19 @@ impl ValidationError { } .into() } + + pub(crate) fn invalid_enum_entity( + source_loc: Option, + policy_id: PolicyID, + err: InvalidEnumEntityError, + ) -> Self { + validation_errors::InvalidEnumEntity { + source_loc, + policy_id, + err, + } + .into() + } } /// Represents the different kinds of validation warnings and information diff --git a/cedar-policy-validator/src/diagnostics/validation_errors.rs b/cedar-policy-validator/src/diagnostics/validation_errors.rs index e696c9ae6..d82b79a38 100644 --- a/cedar-policy-validator/src/diagnostics/validation_errors.rs +++ b/cedar-policy-validator/src/diagnostics/validation_errors.rs @@ -16,6 +16,7 @@ //! Defines errors returned by the validator. +use cedar_policy_core::entities::conformance::err::InvalidEnumEntityError; use miette::Diagnostic; use thiserror::Error; @@ -717,6 +718,27 @@ impl Display for AttributeAccess { } } +/// Returned when an entity literal is of an enumerated entity type but has +/// undeclared UID +#[derive(Debug, Clone, Error, Hash, Eq, PartialEq)] +#[error("for policy `{policy_id}`: {err}")] +pub struct InvalidEnumEntity { + /// Source location + pub source_loc: Option, + /// Policy ID where the error occurred + pub policy_id: PolicyID, + /// The error + pub err: InvalidEnumEntityError, +} + +impl Diagnostic for InvalidEnumEntity { + impl_diagnostic_from_source_loc_opt_field!(source_loc); + + fn help<'a>(&'a self) -> Option> { + self.err.help() + } +} + // These tests all assume that the typechecker found an error while checking the // outermost `GetAttr` in the expressions. If the attribute didn't exist at all, // only the primary message would included in the final error. If it was an diff --git a/cedar-policy-validator/src/entity_manifest/type_annotations.rs b/cedar-policy-validator/src/entity_manifest/type_annotations.rs index 0bf4b071e..a08e8a569 100644 --- a/cedar-policy-validator/src/entity_manifest/type_annotations.rs +++ b/cedar-policy-validator/src/entity_manifest/type_annotations.rs @@ -145,7 +145,12 @@ impl AccessTrie { .ok_or(MismatchedNotStrictSchemaError {})?, ) .ok_or(MismatchedNotStrictSchemaError {})?; - &entity_ty.attributes + &Attributes::with_required_attributes( + entity_ty + .attributes() + .into_iter() + .map(|(attr, ty)| (attr, ty.attr_type)), + ) } EntityRecordKind::ActionEntity { name: _, attrs } => attrs, }; diff --git a/cedar-policy-validator/src/json_schema.rs b/cedar-policy-validator/src/json_schema.rs index 25b7f536a..6969f5457 100644 --- a/cedar-policy-validator/src/json_schema.rs +++ b/cedar-policy-validator/src/json_schema.rs @@ -25,7 +25,8 @@ use cedar_policy_core::{ FromNormalizedStr, }; use educe::Educe; -use nonempty::nonempty; +use itertools::Itertools; +use nonempty::{nonempty, NonEmpty}; use serde::{ de::{MapAccess, Visitor}, ser::SerializeMap, @@ -431,6 +432,25 @@ impl NamespaceDefinition { } } +/// The kind of entity type. There are currently two kinds: The standard entity +/// type specified by [`StandardEntityType`] and the enumerated entity type +/// proposed by RFC 53 +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +#[serde(untagged)] +#[cfg_attr(feature = "wasm", derive(tsify::Tsify))] +#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))] +pub enum EntityTypeKind { + /// The standard entity type specified by [`StandardEntityType`] + Standard(StandardEntityType), + /// The enumerated entity type: An entity type that can only have a + /// nonempty set of possible EIDs + Enum { + #[serde(rename = "enum")] + /// The nonempty set of possible EIDs + choices: NonEmpty, + }, +} + /// Represents the full definition of an entity type in the schema. /// Entity types describe the relationships in the entity store, including what /// entities can be members of groups of what types, and what attributes @@ -439,27 +459,13 @@ impl NamespaceDefinition { /// The parameter `N` is the type of entity type names and common type names in /// this [`EntityType`], including recursively. /// See notes on [`Fragment`]. -#[derive(Educe, Debug, Clone, Serialize, Deserialize)] +#[derive(Educe, Debug, Clone, Serialize)] #[educe(PartialEq, Eq)] #[serde(bound(deserialize = "N: Deserialize<'de> + From"))] -#[serde(deny_unknown_fields)] -#[serde(rename_all = "camelCase")] -#[cfg_attr(feature = "wasm", derive(tsify::Tsify))] -#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))] pub struct EntityType { - /// Entities of this [`EntityType`] are allowed to be members of entities of - /// these types. - #[serde(default)] - #[serde(skip_serializing_if = "Vec::is_empty")] - pub member_of_types: Vec, - /// Description of the attributes for entities of this [`EntityType`]. - #[serde(default)] - #[serde(skip_serializing_if = "AttributesOrContext::is_empty_record")] - pub shape: AttributesOrContext, - /// Tag type for entities of this [`EntityType`]; `None` means entities of this [`EntityType`] do not have tags. - #[serde(default)] - #[serde(skip_serializing_if = "Option::is_none")] - pub tags: Option>, + /// The referred type + #[serde(flatten)] + pub kind: EntityTypeKind, /// Annotations #[serde(default)] #[serde(skip_serializing_if = "Annotations::is_empty")] @@ -474,24 +480,166 @@ pub struct EntityType { pub loc: Option, } +impl<'de, N: Deserialize<'de> + From> Deserialize<'de> for EntityType { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + // A "real" option that does not accept `null` during deserialization + enum RealOption { + Some(T), + None, + } + impl<'de, T: Deserialize<'de>> Deserialize<'de> for RealOption { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + T::deserialize(deserializer).map(Self::Some) + } + } + impl Default for RealOption { + fn default() -> Self { + Self::None + } + } + + impl From> for Option { + fn from(value: RealOption) -> Self { + match value { + RealOption::Some(v) => Self::Some(v), + RealOption::None => None, + } + } + } + + // A struct that contains all possible fields of entity type + // I tried to apply the same idea to `EntityTypeKind` but serde allows + // unknown fields + #[derive(Deserialize)] + #[serde(bound(deserialize = "N: Deserialize<'de> + From"))] + #[serde(deny_unknown_fields)] + #[serde(rename_all = "camelCase")] + struct Everything { + #[serde(default)] + member_of_types: RealOption>, + #[serde(default)] + shape: RealOption>, + #[serde(default)] + tags: RealOption>, + #[serde(default)] + #[serde(rename = "enum")] + choices: RealOption>, + #[serde(default)] + annotations: Annotations, + } + + let value: Everything = Everything::deserialize(deserializer)?; + // We favor the "enum" key here. That is, when we observe this key, we + // assume the entity type is an enumerated one and hence reports fields + // of standard entity types as invalid. + if let Some(choices) = value.choices.into() { + let mut unexpected_fields: Vec<&str> = vec![]; + if Option::>::from(value.member_of_types).is_some() { + unexpected_fields.push("memberOfTypes"); + } + if Option::>::from(value.shape).is_some() { + unexpected_fields.push("shape"); + } + if Option::>::from(value.tags).is_some() { + unexpected_fields.push("tags"); + } + if !unexpected_fields.is_empty() { + return Err(serde::de::Error::custom(format!( + "unexpected field: {}", + unexpected_fields.into_iter().join(", ") + ))); + } + Ok(EntityType { + kind: EntityTypeKind::Enum { choices }, + annotations: value.annotations, + loc: None, + }) + } else { + Ok(EntityType { + kind: EntityTypeKind::Standard(StandardEntityType { + member_of_types: Option::from(value.member_of_types).unwrap_or_default(), + shape: Option::from(value.shape).unwrap_or_default(), + tags: Option::from(value.tags), + }), + annotations: value.annotations, + loc: None, + }) + } + } +} + +/// The "standard" entity type. That is, an entity type defined by parent +/// entity types, shape, and tags. +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Deserialize)] +#[serde(bound(deserialize = "N: Deserialize<'de> + From"))] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "wasm", derive(tsify::Tsify))] +#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))] +pub struct StandardEntityType { + /// Entities of this [`StandardEntityType`] are allowed to be members of entities of + /// these types. + #[serde(skip_serializing_if = "Vec::is_empty")] + #[serde(default)] + pub member_of_types: Vec, + /// Description of the attributes for entities of this [`StandardEntityType`]. + #[serde(skip_serializing_if = "AttributesOrContext::is_empty_record")] + #[serde(default)] + pub shape: AttributesOrContext, + /// Tag type for entities of this [`StandardEntityType`]; `None` means entities of this [`StandardEntityType`] do not have tags. + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub tags: Option>, +} + +#[cfg(test)] +impl From> for EntityType { + fn from(value: StandardEntityType) -> Self { + Self { + kind: EntityTypeKind::Standard(value), + annotations: Annotations::new(), + loc: None, + } + } +} + impl EntityType { /// (Conditionally) prefix unqualified entity and common type references with the namespace they are in pub fn conditionally_qualify_type_references( self, ns: Option<&InternalName>, ) -> EntityType { - EntityType { - member_of_types: self - .member_of_types - .into_iter() - .map(|rname| rname.conditionally_qualify_with(ns, ReferenceType::Entity)) // Only entity, not common, here for now; see #1064 - .collect(), - shape: self.shape.conditionally_qualify_type_references(ns), - tags: self - .tags - .map(|ty| ty.conditionally_qualify_type_references(ns)), - annotations: self.annotations, - loc: self.loc, + let Self { + kind, + annotations, + loc, + } = self; + match kind { + EntityTypeKind::Enum { choices } => EntityType { + kind: EntityTypeKind::Enum { choices }, + annotations, + loc, + }, + EntityTypeKind::Standard(ty) => EntityType { + kind: EntityTypeKind::Standard(StandardEntityType { + member_of_types: ty + .member_of_types + .into_iter() + .map(|rname| rname.conditionally_qualify_with(ns, ReferenceType::Entity)) // Only entity, not common, here for now; see #1064 + .collect(), + shape: ty.shape.conditionally_qualify_type_references(ns), + tags: ty + .tags + .map(|ty| ty.conditionally_qualify_type_references(ns)), + }), + annotations, + loc, + }, } } } @@ -507,19 +655,33 @@ impl EntityType { self, all_defs: &AllDefs, ) -> std::result::Result, TypeNotDefinedError> { - Ok(EntityType { - member_of_types: self - .member_of_types - .into_iter() - .map(|cname| cname.resolve(all_defs)) - .collect::>()?, - shape: self.shape.fully_qualify_type_references(all_defs)?, - tags: self - .tags - .map(|ty| ty.fully_qualify_type_references(all_defs)) - .transpose()?, - annotations: self.annotations, - loc: self.loc, + let Self { + kind, + annotations, + loc, + } = self; + Ok(match kind { + EntityTypeKind::Enum { choices } => EntityType { + kind: EntityTypeKind::Enum { choices }, + annotations, + loc, + }, + EntityTypeKind::Standard(ty) => EntityType { + kind: EntityTypeKind::Standard(StandardEntityType { + member_of_types: ty + .member_of_types + .into_iter() + .map(|cname| cname.resolve(all_defs)) + .collect::>()?, + shape: ty.shape.fully_qualify_type_references(all_defs)?, + tags: ty + .tags + .map(|ty| ty.fully_qualify_type_references(all_defs)) + .transpose()?, + }), + annotations, + loc, + }, }) } } @@ -1993,7 +2155,7 @@ mod test { "memberOfTypes" : ["UserGroup"] } "#; - let et = serde_json::from_str::>(user).expect("Parse Error"); + assert_matches!(serde_json::from_str::>(user), Ok(EntityType { kind: EntityTypeKind::Standard(et), .. }) => { assert_eq!(et.member_of_types, vec!["UserGroup".parse().unwrap()]); assert_eq!( et.shape, @@ -2004,7 +2166,7 @@ mod test { }), loc: None }), - ); + );}); } #[test] @@ -2012,7 +2174,7 @@ mod test { let src = r#" { } "#; - let et = serde_json::from_str::>(src).expect("Parse Error"); + assert_matches!(serde_json::from_str::>(src), Ok(EntityType { kind: EntityTypeKind::Standard(et), .. }) => { assert_eq!(et.member_of_types.len(), 0); assert_eq!( et.shape, @@ -2023,7 +2185,7 @@ mod test { }), loc: None }), - ); + );}); } #[test] @@ -2819,15 +2981,15 @@ mod entity_tags { fn basic() { let json = example_json_schema(); assert_matches!(Fragment::from_json_value(json), Ok(frag) => { - let user = &frag.0.get(&None).unwrap().entity_types.get(&"User".parse().unwrap()).unwrap(); - assert_matches!(&user.tags, Some(Type::Type { ty: TypeVariant::Set { element }, loc: None }) => { - assert_matches!(&**element, Type::Type { ty: TypeVariant::String, loc: None }); // TODO: why is this `TypeVariant::String` in this case but `EntityOrCommon { "String" }` in all the other cases in this test? Do we accept common types as the element type for sets? + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"User".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(user), ..} => { + assert_matches!(&user.tags, Some(Type::Type { ty: TypeVariant::Set { element }, ..}) => { + assert_matches!(&**element, Type::Type{ ty: TypeVariant::String, ..}); // TODO: why is this `TypeVariant::String` in this case but `EntityOrCommon { "String" }` in all the other cases in this test? Do we accept common types as the element type for sets? + });}); + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"Document".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(doc), ..} => { + assert_matches!(&doc.tags, Some(Type::Type { ty: TypeVariant::Set { element }, ..}) => { + assert_matches!(&**element, Type::Type{ ty: TypeVariant::String, ..}); // TODO: why is this `TypeVariant::String` in this case but `EntityOrCommon { "String" }` in all the other cases in this test? Do we accept common types as the element type for sets? }); - let doc = &frag.0.get(&None).unwrap().entity_types.get(&"Document".parse().unwrap()).unwrap(); - assert_matches!(&doc.tags, Some(Type::Type { ty: TypeVariant::Set { element }, loc: None }) => { - assert_matches!(&**element, Type::Type { ty: TypeVariant::String, loc: None }); // TODO: why is this `TypeVariant::String` in this case but `EntityOrCommon { "String" }` in all the other cases in this test? Do we accept common types as the element type for sets? - }); - }) + })}) } /// In this schema, the tag type is a common type @@ -2853,11 +3015,11 @@ mod entity_tags { "actions": {} }}); assert_matches!(Fragment::from_json_value(json), Ok(frag) => { - let user = &frag.0.get(&None).unwrap().entity_types.get(&"User".parse().unwrap()).unwrap(); - assert_matches!(&user.tags, Some(Type::CommonTypeRef { type_name, loc: None }) => { + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"User".parse().unwrap()).unwrap(), EntityType {kind: EntityTypeKind::Standard(user), ..} => { + assert_matches!(&user.tags, Some(Type::CommonTypeRef { type_name, .. }) => { assert_eq!(&format!("{type_name}"), "T"); }); - }) + })}); } /// In this schema, the tag type is an entity type @@ -2880,11 +3042,11 @@ mod entity_tags { "actions": {} }}); assert_matches!(Fragment::from_json_value(json), Ok(frag) => { - let user = &frag.0.get(&None).unwrap().entity_types.get(&"User".parse().unwrap()).unwrap(); - assert_matches!(&user.tags, Some(Type::Type { ty: TypeVariant::Entity { name }, loc: None }) => { + assert_matches!(frag.0.get(&None).unwrap().entity_types.get(&"User".parse().unwrap()).unwrap(), EntityType { kind: EntityTypeKind::Standard(user), ..} => { + assert_matches!(&user.tags, Some(Type::Type{ ty: TypeVariant::Entity{ name }, ..}) => { assert_eq!(&format!("{name}"), "User"); }); - }) + })}); } /// This schema has `tags` inside `shape` instead of parallel to it @@ -2952,15 +3114,17 @@ mod test_json_roundtrip { [( "a".parse().unwrap(), EntityType { - member_of_types: vec!["a".parse().unwrap()], - shape: AttributesOrContext(Type::Type { - ty: TypeVariant::Record(RecordType { - attributes: BTreeMap::new(), - additional_attributes: false, + kind: EntityTypeKind::Standard(StandardEntityType { + member_of_types: vec!["a".parse().unwrap()], + shape: AttributesOrContext(Type::Type { + ty: TypeVariant::Record(RecordType { + attributes: BTreeMap::new(), + additional_attributes: false, + }), + loc: None, }), - loc: None, + tags: None, }), - tags: None, annotations: Annotations::new(), loc: None, }, @@ -2999,15 +3163,17 @@ mod test_json_roundtrip { [( "a".parse().unwrap(), EntityType { - member_of_types: vec!["a".parse().unwrap()], - shape: AttributesOrContext(Type::Type { - ty: TypeVariant::Record(RecordType { - attributes: BTreeMap::new(), - additional_attributes: false, + kind: EntityTypeKind::Standard(StandardEntityType { + member_of_types: vec!["a".parse().unwrap()], + shape: AttributesOrContext(Type::Type { + ty: TypeVariant::Record(RecordType { + attributes: BTreeMap::new(), + additional_attributes: false, + }), + loc: None, }), - loc: None, + tags: None, }), - tags: None, annotations: Annotations::new(), loc: None, }, @@ -3396,7 +3562,8 @@ mod annotations { }); } - const ENTITY_TYPE_EXPECTED_ATTRIBUTES: &str = "`memberOfTypes`, `shape`, `tags`, `annotations`"; + const ENTITY_TYPE_EXPECTED_ATTRIBUTES: &str = + "`memberOfTypes`, `shape`, `tags`, `enum`, `annotations`"; const NAMESPACE_EXPECTED_ATTRIBUTES: &str = "`commonTypes`, `entityTypes`, `actions`, `annotations`"; const ATTRIBUTE_TYPE_EXPECTED_ATTRIBUTES: &str = @@ -3567,7 +3734,7 @@ mod annotations { "annotations": { "foo": "" }, - "bar": 1 + "bar": 1, } } } @@ -3620,3 +3787,89 @@ mod ord { }); } } + +#[cfg(test)] +mod enumerated_entity_types { + use cool_asserts::assert_matches; + + use crate::{ + json_schema::{EntityType, EntityTypeKind, Fragment}, + RawName, + }; + + #[test] + fn basic() { + let src = serde_json::json!({ + "": { + "entityTypes": { + "Foo": { + "enum": ["foo", "bar"], + "annotations": { + "a": "b", + } + }, + }, + "actions": {}, + } + }); + let schema: Result, _> = serde_json::from_value(src); + assert_matches!(schema, Ok(frag) => { + assert_matches!(&frag.0[&None].entity_types[&"Foo".parse().unwrap()], EntityType { + kind: EntityTypeKind::Enum {choices}, + .. + } => { + assert_eq!(Vec::from(choices.clone()), ["foo", "bar"]); + }); + }); + + let src = serde_json::json!({ + "": { + "entityTypes": { + "Foo": { + "enum": [], + "annotations": { + "a": "b", + } + }, + }, + "actions": {}, + } + }); + let schema: Result, _> = serde_json::from_value(src); + assert_matches!(schema, Err(errs) => { + // TODO: write our own error messages if it's deemed to be too ugly. + assert_eq!(errs.to_string(), "the vector provided was empty, NonEmpty needs at least one element"); + }); + + let src = serde_json::json!({ + "": { + "entityTypes": { + "Foo": { + "enum": null, + }, + }, + "actions": {}, + } + }); + let schema: Result, _> = serde_json::from_value(src); + assert_matches!(schema, Err(errs) => { + assert_eq!(errs.to_string(), "invalid type: null, expected a sequence"); + }); + + let src = serde_json::json!({ + "": { + "entityTypes": { + "Foo": { + "enum": ["foo"], + "memberOfTypes": ["bar"], + }, + }, + "actions": {}, + } + }); + let schema: Result, _> = serde_json::from_value(src); + assert_matches!(schema, Err(errs) => { + assert_eq!(errs.to_string(), "unexpected field: memberOfTypes"); + }); + } +} diff --git a/cedar-policy-validator/src/lib.rs b/cedar-policy-validator/src/lib.rs index 711851106..eae1a0a15 100644 --- a/cedar-policy-validator/src/lib.rs +++ b/cedar-policy-validator/src/lib.rs @@ -202,6 +202,7 @@ impl Validator { } else { Some( self.validate_entity_types(p) + .chain(self.validate_enum_entity(p)) .chain(self.validate_action_ids(p)) // We could usefully update this pass to apply to partial // schema if it only failed when there is a known action @@ -289,23 +290,21 @@ mod test { [ ( foo_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), ), ( bar_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), ), ], [( @@ -576,3 +575,222 @@ mod test { ); } } + +#[cfg(test)] +mod enumerated_entity_types { + use cedar_policy_core::{ + ast::{Eid, EntityUID, ExprBuilder, PolicyID}, + expr_builder::ExprBuilder as _, + extensions::Extensions, + parser::parse_policy_or_template, + }; + use cool_asserts::assert_matches; + use itertools::Itertools; + + use crate::{ + typecheck::test::test_utils::get_loc, + types::{EntityLUB, Type}, + validation_errors::AttributeAccess, + ValidationError, Validator, ValidatorSchema, + }; + + #[track_caller] + fn schema() -> ValidatorSchema { + ValidatorSchema::from_json_value( + serde_json::json!( + { + "": { "entityTypes": { + "Foo": { + "enum": [ "foo" ], + }, + "Bar": { + "memberOfTypes": ["Foo"], + } + }, + "actions": { + "a": { + "appliesTo": { + "principalTypes": ["Foo"], + "resourceTypes": ["Bar"], + } + } + } + } + } + ), + Extensions::none(), + ) + .unwrap() + } + + #[test] + fn basic() { + let schema = schema(); + let template = parse_policy_or_template(None, r#"permit(principal, action == Action::"a", resource) when { principal == Foo::"foo" };"#).unwrap(); + let validator = Validator::new(schema); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert!(errors.collect_vec().is_empty()); + } + + #[test] + fn basic_invalid() { + let schema = schema(); + let template = parse_policy_or_template(None, r#"permit(principal, action == Action::"a", resource) when { principal == Foo::"fo" };"#).unwrap(); + let validator = Validator::new(schema.clone()); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_matches!(&errors.collect_vec(), [ValidationError::InvalidEnumEntity(err)] => { + assert_eq!(err.err.choices, vec![Eid::new("foo")]); + assert_eq!(err.err.uid, EntityUID::with_eid_and_type("Foo", "fo").unwrap()); + }); + + let template = parse_policy_or_template( + None, + r#"permit(principal == Foo::"🏈", action == Action::"a", resource);"#, + ) + .unwrap(); + let validator = Validator::new(schema.clone()); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_matches!(&errors.collect_vec(), [ValidationError::InvalidEnumEntity(err)] => { + assert_eq!(err.err.choices, vec![Eid::new("foo")]); + assert_eq!(err.err.uid, EntityUID::with_eid_and_type("Foo", "🏈").unwrap()); + }); + + let template = parse_policy_or_template( + None, + r#"permit(principal in Foo::"🏈", action == Action::"a", resource);"#, + ) + .unwrap(); + let validator = Validator::new(schema.clone()); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_matches!(&errors.collect_vec(), [ValidationError::InvalidEnumEntity(err)] => { + assert_eq!(err.err.choices, vec![Eid::new("foo")]); + assert_eq!(err.err.uid, EntityUID::with_eid_and_type("Foo", "🏈").unwrap()); + }); + + let template = parse_policy_or_template( + None, + r#"permit(principal, action == Action::"a", resource) + when { {"🏈": Foo::"🏈"} has "🏈" }; + "#, + ) + .unwrap(); + let validator = Validator::new(schema.clone()); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_matches!(&errors.collect_vec(), [ValidationError::InvalidEnumEntity(err)] => { + assert_eq!(err.err.choices, vec![Eid::new("foo")]); + assert_eq!(err.err.uid, EntityUID::with_eid_and_type("Foo", "🏈").unwrap()); + }); + + let template = parse_policy_or_template( + None, + r#"permit(principal, action == Action::"a", resource) + when { [Foo::"🏈"].isEmpty() }; + "#, + ) + .unwrap(); + let validator = Validator::new(schema.clone()); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_matches!(&errors.collect_vec(), [ValidationError::InvalidEnumEntity(err)] => { + assert_eq!(err.err.choices, vec![Eid::new("foo")]); + assert_eq!(err.err.uid, EntityUID::with_eid_and_type("Foo", "🏈").unwrap()); + }); + + let template = parse_policy_or_template( + None, + r#"permit(principal, action == Action::"a", resource) + when { [{"🏈": Foo::"🏈"}].isEmpty() }; + "#, + ) + .unwrap(); + let validator = Validator::new(schema.clone()); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_matches!(&errors.collect_vec(), [ValidationError::InvalidEnumEntity(err)] => { + assert_eq!(err.err.choices, vec![Eid::new("foo")]); + assert_eq!(err.err.uid, EntityUID::with_eid_and_type("Foo", "🏈").unwrap()); + }); + } + + #[test] + fn no_attrs_allowed() { + let schema = schema(); + let src = r#"permit(principal, action == Action::"a", resource) when { principal.foo == "foo" };"#; + let template = parse_policy_or_template(None, src).unwrap(); + let validator = Validator::new(schema); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_eq!( + errors.collect_vec(), + [ValidationError::unsafe_attribute_access( + get_loc(src, "principal.foo"), + PolicyID::from_string("policy0"), + AttributeAccess::EntityLUB( + EntityLUB::single_entity("Foo".parse().unwrap()), + vec!["foo".into()], + ), + None, + false, + )] + ); + } + + #[test] + fn no_ancestors() { + let schema = schema(); + let src = + r#"permit(principal, action == Action::"a", resource) when { principal in resource };"#; + let template = parse_policy_or_template(None, src).unwrap(); + let validator = Validator::new(schema); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_eq!( + errors.collect_vec(), + [ValidationError::hierarchy_not_respected( + get_loc(src, "principal in resource"), + PolicyID::from_string("policy0"), + Some("Foo".parse().unwrap()), + Some("Bar".parse().unwrap()), + )] + ); + } + + #[test] + fn no_tags_allowed() { + let schema = schema(); + let src = r#"permit(principal, action == Action::"a", resource) when { principal.getTag("foo") == "foo" };"#; + let template = parse_policy_or_template(None, src).unwrap(); + let validator = Validator::new(schema); + let (errors, warnings) = + validator.validate_policy(&template, crate::ValidationMode::Strict); + assert!(warnings.collect_vec().is_empty()); + assert_eq!( + errors.collect_vec(), + [ValidationError::unsafe_tag_access( + get_loc(src, r#"principal.getTag("foo")"#), + PolicyID::from_string("policy0"), + Some(EntityLUB::single_entity("Foo".parse().unwrap()),), + { + let builder = ExprBuilder::new(); + let mut expr = builder.val("foo"); + expr.set_data(Some(Type::primitive_string())); + expr + }, + )] + ); + } +} diff --git a/cedar-policy-validator/src/rbac.rs b/cedar-policy-validator/src/rbac.rs index 6e1be4ad7..4e66b5318 100644 --- a/cedar-policy-validator/src/rbac.rs +++ b/cedar-policy-validator/src/rbac.rs @@ -18,9 +18,10 @@ use cedar_policy_core::{ ast::{ - self, ActionConstraint, EntityReference, EntityUID, Policy, PolicyID, PrincipalConstraint, - PrincipalOrResourceConstraint, ResourceConstraint, SlotEnv, Template, + self, ActionConstraint, Eid, EntityReference, EntityUID, Policy, PolicyID, + PrincipalConstraint, PrincipalOrResourceConstraint, ResourceConstraint, SlotEnv, Template, }, + entities::conformance::is_valid_enumerated_entity, fuzzy_match::fuzzy_search, parser::Loc, }; @@ -36,6 +37,34 @@ use crate::{ use super::{schema::*, Validator}; impl Validator { + /// Validate if a [`Template`] contains entities of enumerated entity types + /// but with invalid UIDs + pub(crate) fn validate_enum_entity<'a>( + &'a self, + template: &'a Template, + ) -> impl Iterator + 'a { + policy_entity_uids(template) + .filter(|e| !e.is_action()) + .filter_map(|e: &EntityUID| { + if let Some(ValidatorEntityType { + kind: ValidatorEntityTypeKind::Enum(choices), + .. + }) = self.schema.get_entity_type(e.entity_type()) + { + match is_valid_enumerated_entity(&Vec::from(choices.clone().map(Eid::new)), e) { + Ok(_) => {} + Err(err) => { + return Some(ValidationError::invalid_enum_entity( + e.loc().cloned(), + template.id().clone(), + err, + )); + } + }; + } + None + }) + } /// Generate `UnrecognizedEntityType` error for every entity type in the /// expression that could not also be found in the schema. pub(crate) fn validate_entity_types<'a>( @@ -482,13 +511,12 @@ mod test { let schema_file = json_schema::NamespaceDefinition::new( [( foo_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), )], [], ); @@ -519,13 +547,12 @@ mod test { let schema_file = json_schema::NamespaceDefinition::new( [( "foo_type".parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), )], [], ); @@ -609,13 +636,12 @@ mod test { let schema_file = json_schema::NamespaceDefinition::new( [( p_name.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), )], [], ); @@ -636,13 +662,12 @@ mod test { let schema_file = json_schema::NamespaceDefinition::new( [( p_name.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), )], [], ); @@ -663,13 +688,12 @@ mod test { let schema_file = json_schema::NamespaceDefinition::new( [( p_name.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), )], [], ); @@ -1031,13 +1055,12 @@ mod test { let schema_file = json_schema::NamespaceDefinition::new( [( foo_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), )], [], ); @@ -1067,23 +1090,21 @@ mod test { [ ( principal_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), ), ( resource_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), ), ], [( @@ -1464,43 +1485,39 @@ mod test { [ ( principal_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), ), ( resource_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![resource_parent_type.parse().unwrap()], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), ), ( resource_parent_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![resource_grandparent_type.parse().unwrap()], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), ), ( resource_grandparent_type.parse().unwrap(), - json_schema::EntityType { + json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }, + } + .into(), ), ], [ diff --git a/cedar-policy-validator/src/schema.rs b/cedar-policy-validator/src/schema.rs index 2f48f621e..5c595a00f 100644 --- a/cedar-policy-validator/src/schema.rs +++ b/cedar-policy-validator/src/schema.rs @@ -27,7 +27,9 @@ use cedar_policy_core::{ parser::Loc, transitive_closure::compute_tc, }; +use entity_type::StandardValidatorEntityType; use itertools::Itertools; +use namespace_def::EntityTypeFragment; use nonempty::NonEmpty; use serde::{Deserialize, Serialize}; use serde_with::serde_as; @@ -52,7 +54,7 @@ mod action; pub use action::ValidatorActionId; pub(crate) use action::ValidatorApplySpec; mod entity_type; -pub use entity_type::ValidatorEntityType; +pub use entity_type::{ValidatorEntityType, ValidatorEntityTypeKind}; mod namespace_def; pub(crate) use namespace_def::try_jsonschema_type_into_validator_type; pub use namespace_def::ValidatorNamespaceDef; @@ -493,7 +495,7 @@ impl ValidatorSchema { // to get a `children` relation. let mut entity_children: HashMap> = HashMap::new(); for (name, entity_type) in entity_type_fragments.iter() { - for parent in entity_type.parents.iter() { + for parent in entity_type.parents() { entity_children .entry(internal_name_to_entity_type(parent.clone())?) .or_default() @@ -514,34 +516,53 @@ impl ValidatorSchema { // error for any other undeclared entity types by // `check_for_undeclared`. let descendants = entity_children.remove(&name).unwrap_or_default(); - let (attributes, open_attributes) = { - let unresolved = try_jsonschema_type_into_validator_type( - entity_type.attributes.0, - extensions, - )?; - Self::record_attributes_or_none( - unresolved.resolve_common_type_refs(&common_types)?, - ) - .ok_or_else(|| ContextOrShapeNotRecordError { - ctx_or_shape: ContextOrShape::EntityTypeShape(name.clone()), - })? - }; - let tags = entity_type - .tags - .map(|tags| try_jsonschema_type_into_validator_type(tags, extensions)) - .transpose()? - .map(|unresolved| unresolved.resolve_common_type_refs(&common_types)) - .transpose()?; - Ok(( - name.clone(), - ValidatorEntityType { - name, - descendants, + match entity_type { + EntityTypeFragment::Enum(choices) => Ok(( + name.clone(), + ValidatorEntityType { + name, + descendants, + kind: ValidatorEntityTypeKind::Enum(choices), + }, + )), + EntityTypeFragment::Standard { attributes, - open_attributes, + parents: _, tags, - }, - )) + } => { + let (attributes, open_attributes) = { + let unresolved = + try_jsonschema_type_into_validator_type(attributes.0, extensions)?; + Self::record_attributes_or_none( + unresolved.resolve_common_type_refs(&common_types)?, + ) + .ok_or_else(|| { + ContextOrShapeNotRecordError { + ctx_or_shape: ContextOrShape::EntityTypeShape(name.clone()), + } + })? + }; + let tags = tags + .map(|tags| try_jsonschema_type_into_validator_type(tags, extensions)) + .transpose()? + .map(|unresolved| unresolved.resolve_common_type_refs(&common_types)) + .transpose()?; + Ok(( + name.clone(), + ValidatorEntityType { + name, + descendants, + kind: ValidatorEntityTypeKind::Standard( + StandardValidatorEntityType { + attributes, + open_attributes, + tags, + }, + ), + }, + )) + } + } }) .collect::>>()?; @@ -2354,7 +2375,7 @@ pub(crate) mod test { .unwrap(); let schema: ValidatorSchema = fragment.try_into().unwrap(); assert_eq!( - schema.entity_types.iter().next().unwrap().1.attributes, + schema.entity_types.iter().next().unwrap().1.attributes(), Attributes::with_required_attributes([("a".into(), Type::primitive_long())]) ); } @@ -2380,7 +2401,7 @@ pub(crate) mod test { .unwrap(); let schema: ValidatorSchema = fragment.try_into().unwrap(); assert_eq!( - schema.entity_types.iter().next().unwrap().1.attributes, + schema.entity_types.iter().next().unwrap().1.attributes(), Attributes::with_required_attributes([("a".into(), Type::primitive_long())]) ); } @@ -2412,7 +2433,7 @@ pub(crate) mod test { .unwrap(); let schema: ValidatorSchema = fragment.try_into().unwrap(); assert_eq!( - schema.entity_types.iter().next().unwrap().1.attributes, + schema.entity_types.iter().next().unwrap().1.attributes(), Attributes::with_required_attributes([("a".into(), Type::primitive_long())]) ); } @@ -2458,7 +2479,7 @@ pub(crate) mod test { .unwrap(); assert_eq!( - schema.entity_types.iter().next().unwrap().1.attributes, + schema.entity_types.iter().next().unwrap().1.attributes(), Attributes::with_required_attributes([("a".into(), Type::primitive_long())]) ); } @@ -2842,7 +2863,9 @@ pub(crate) mod test { } ); let schema = ValidatorSchema::from_json_value(src, Extensions::all_available()).unwrap(); - let mut attributes = assert_entity_type_exists(&schema, "Demo::User").attributes(); + let mut attributes = assert_entity_type_exists(&schema, "Demo::User") + .attributes() + .into_iter(); let (attr_name, attr_ty) = attributes.next().unwrap(); assert_eq!(attr_name, "id"); assert_eq!(&attr_ty.attr_type, &Type::primitive_string()); @@ -4215,29 +4238,29 @@ mod test_rfc70 { "; let schema = assert_valid_cedar_schema(src); let e = assert_entity_type_exists(&schema, "E"); - assert_matches!(e.attributes.get_attr("a"), Some(atype) => { + assert_matches!(e.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); // using the common type definition }); - assert_matches!(e.attributes.get_attr("b"), Some(atype) => { + assert_matches!(e.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_string()); }); - assert_matches!(e.attributes.get_attr("c"), Some(atype) => { + assert_matches!(e.attr("c"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(e.attributes.get_attr("d"), Some(atype) => { + assert_matches!(e.attr("d"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); let f = assert_entity_type_exists(&schema, "NS::F"); - assert_matches!(f.attributes.get_attr("a"), Some(atype) => { + assert_matches!(f.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); // using the common type definition }); - assert_matches!(f.attributes.get_attr("b"), Some(atype) => { + assert_matches!(f.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_boolean()); }); - assert_matches!(f.attributes.get_attr("c"), Some(atype) => { + assert_matches!(f.attr("c"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(f.attributes.get_attr("d"), Some(atype) => { + assert_matches!(f.attr("d"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); @@ -4324,29 +4347,29 @@ mod test_rfc70 { }); let schema = assert_valid_json_schema(src_json); let e = assert_entity_type_exists(&schema, "E"); - assert_matches!(e.attributes.get_attr("a"), Some(atype) => { + assert_matches!(e.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(e.attributes.get_attr("b"), Some(atype) => { + assert_matches!(e.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_string()); }); - assert_matches!(e.attributes.get_attr("c"), Some(atype) => { + assert_matches!(e.attr("c"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(e.attributes.get_attr("d"), Some(atype) => { + assert_matches!(e.attr("d"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); let f = assert_entity_type_exists(&schema, "NS::F"); - assert_matches!(f.attributes.get_attr("a"), Some(atype) => { + assert_matches!(f.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); // using the common type definition }); - assert_matches!(f.attributes.get_attr("b"), Some(atype) => { + assert_matches!(f.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_boolean()); }); - assert_matches!(f.attributes.get_attr("c"), Some(atype) => { + assert_matches!(f.attr("c"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(f.attributes.get_attr("d"), Some(atype) => { + assert_matches!(f.attr("d"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); } @@ -4375,29 +4398,29 @@ mod test_rfc70 { "; let schema = assert_valid_cedar_schema(src); let e = assert_entity_type_exists(&schema, "E"); - assert_matches!(e.attributes.get_attr("a"), Some(atype) => { + assert_matches!(e.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); // using the common type definition }); - assert_matches!(e.attributes.get_attr("b"), Some(atype) => { + assert_matches!(e.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::extension("ipaddr".parse().unwrap())); }); - assert_matches!(e.attributes.get_attr("c"), Some(atype) => { + assert_matches!(e.attr("c"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(e.attributes.get_attr("d"), Some(atype) => { + assert_matches!(e.attr("d"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); let f = assert_entity_type_exists(&schema, "NS::F"); - assert_matches!(f.attributes.get_attr("a"), Some(atype) => { + assert_matches!(f.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); // using the common type definition }); - assert_matches!(f.attributes.get_attr("b"), Some(atype) => { + assert_matches!(f.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::extension("decimal".parse().unwrap())); }); - assert_matches!(f.attributes.get_attr("c"), Some(atype) => { + assert_matches!(f.attr("c"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(f.attributes.get_attr("d"), Some(atype) => { + assert_matches!(f.attr("d"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); @@ -4443,29 +4466,29 @@ mod test_rfc70 { }); let schema = assert_valid_json_schema(src_json); let e = assert_entity_type_exists(&schema, "E"); - assert_matches!(e.attributes.get_attr("a"), Some(atype) => { + assert_matches!(e.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); // using the common type definition }); - assert_matches!(e.attributes.get_attr("b"), Some(atype) => { + assert_matches!(e.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::extension("ipaddr".parse().unwrap())); }); - assert_matches!(e.attributes.get_attr("c"), Some(atype) => { + assert_matches!(e.attr("c"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(e.attributes.get_attr("d"), Some(atype) => { + assert_matches!(e.attr("d"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); let f = assert_entity_type_exists(&schema, "NS::F"); - assert_matches!(f.attributes.get_attr("a"), Some(atype) => { + assert_matches!(f.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); // using the common type definition }); - assert_matches!(f.attributes.get_attr("b"), Some(atype) => { + assert_matches!(f.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::extension("decimal".parse().unwrap())); }); - assert_matches!(f.attributes.get_attr("c"), Some(atype) => { + assert_matches!(f.attr("c"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); - assert_matches!(f.attributes.get_attr("d"), Some(atype) => { + assert_matches!(f.attr("d"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_long()); }); } @@ -4490,17 +4513,17 @@ mod test_rfc70 { "; let schema = assert_valid_cedar_schema(src); let e = assert_entity_type_exists(&schema, "E"); - assert_matches!(e.attributes.get_attr("a"), Some(atype) => { + assert_matches!(e.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::named_entity_reference_from_str("String")); }); - assert_matches!(e.attributes.get_attr("b"), Some(atype) => { + assert_matches!(e.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_string()); }); let f = assert_entity_type_exists(&schema, "NS::F"); - assert_matches!(f.attributes.get_attr("a"), Some(atype) => { + assert_matches!(f.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::named_entity_reference_from_str("NS::Bool")); // using the common type definition }); - assert_matches!(f.attributes.get_attr("b"), Some(atype) => { + assert_matches!(f.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_boolean()); }); @@ -4538,17 +4561,17 @@ mod test_rfc70 { }); let schema = assert_valid_json_schema(src_json); let e = assert_entity_type_exists(&schema, "E"); - assert_matches!(e.attributes.get_attr("a"), Some(atype) => { + assert_matches!(e.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::named_entity_reference_from_str("String")); }); - assert_matches!(e.attributes.get_attr("b"), Some(atype) => { + assert_matches!(e.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_string()); }); let f = assert_entity_type_exists(&schema, "NS::F"); - assert_matches!(f.attributes.get_attr("a"), Some(atype) => { + assert_matches!(f.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::named_entity_reference_from_str("NS::Bool")); }); - assert_matches!(f.attributes.get_attr("b"), Some(atype) => { + assert_matches!(f.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::primitive_boolean()); }); } @@ -4573,17 +4596,17 @@ mod test_rfc70 { "; let schema = assert_valid_cedar_schema(src); let e = assert_entity_type_exists(&schema, "E"); - assert_matches!(e.attributes.get_attr("a"), Some(atype) => { + assert_matches!(e.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::named_entity_reference_from_str("ipaddr")); }); - assert_matches!(e.attributes.get_attr("b"), Some(atype) => { + assert_matches!(e.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::extension("ipaddr".parse().unwrap())); }); let f = assert_entity_type_exists(&schema, "NS::F"); - assert_matches!(f.attributes.get_attr("a"), Some(atype) => { + assert_matches!(f.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::named_entity_reference_from_str("NS::decimal")); }); - assert_matches!(f.attributes.get_attr("b"), Some(atype) => { + assert_matches!(f.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::extension("decimal".parse().unwrap())); }); @@ -4621,17 +4644,17 @@ mod test_rfc70 { }); let schema = assert_valid_json_schema(src_json); let e = assert_entity_type_exists(&schema, "E"); - assert_matches!(e.attributes.get_attr("a"), Some(atype) => { + assert_matches!(e.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::named_entity_reference_from_str("ipaddr")); }); - assert_matches!(e.attributes.get_attr("b"), Some(atype) => { + assert_matches!(e.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::extension("ipaddr".parse().unwrap())); }); let f = assert_entity_type_exists(&schema, "NS::F"); - assert_matches!(f.attributes.get_attr("a"), Some(atype) => { + assert_matches!(f.attr("a"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::named_entity_reference_from_str("NS::decimal")); }); - assert_matches!(f.attributes.get_attr("b"), Some(atype) => { + assert_matches!(f.attr("b"), Some(atype) => { assert_eq!(&atype.attr_type, &Type::extension("decimal".parse().unwrap())); }); } diff --git a/cedar-policy-validator/src/schema/entity_type.rs b/cedar-policy-validator/src/schema/entity_type.rs index 5c15da2df..ef98751cd 100644 --- a/cedar-policy-validator/src/schema/entity_type.rs +++ b/cedar-policy-validator/src/schema/entity_type.rs @@ -16,9 +16,10 @@ //! This module contains the definition of `ValidatorEntityType` +use nonempty::NonEmpty; use serde::Serialize; use smol_str::SmolStr; -use std::collections::HashSet; +use std::collections::{BTreeMap, HashSet}; use cedar_policy_core::{ast::EntityType, transitive_closure::TCNode}; @@ -44,6 +45,12 @@ pub struct ValidatorEntityType { /// descendants before it is used in any validation. pub descendants: HashSet, + /// The kind of entity type: enumerated and standard + pub kind: ValidatorEntityTypeKind, +} + +#[derive(Clone, Debug, Serialize)] +pub struct StandardValidatorEntityType { /// The attributes associated with this entity. pub(crate) attributes: Attributes, @@ -60,23 +67,69 @@ pub struct ValidatorEntityType { pub(crate) tags: Option, } +/// The kind of validator entity types +/// It can either be a standard (non-enum) entity type, or +/// an enumerated entity type +#[derive(Clone, Debug, Serialize)] +pub enum ValidatorEntityTypeKind { + /// Standard, aka non-enum + Standard(StandardValidatorEntityType), + /// Enumerated + Enum(NonEmpty), +} + impl ValidatorEntityType { + /// Return `true` if this entity type has an [`EntityType`] declared as a + /// possible descendant in the schema. + pub fn has_descendant_entity_type(&self, ety: &EntityType) -> bool { + self.descendants.contains(ety) + } + + /// An iterator over the attributes of this entity + pub fn attributes(&self) -> Attributes { + match &self.kind { + ValidatorEntityTypeKind::Enum(_) => Attributes { + attrs: BTreeMap::new(), + }, + ValidatorEntityTypeKind::Standard(ty) => Attributes::with_attributes( + ty.attributes() + .map(|(key, value)| (key.clone(), value.clone())), + ), + } + } + /// Get the type of the attribute with the given name, if it exists pub fn attr(&self, attr: &str) -> Option<&AttributeType> { - self.attributes.get_attr(attr) + match &self.kind { + ValidatorEntityTypeKind::Enum(_) => None, + ValidatorEntityTypeKind::Standard(ty) => ty.attributes.get_attr(attr), + } + } + + /// Get the open attributes + pub fn open_attributes(&self) -> OpenTag { + match &self.kind { + ValidatorEntityTypeKind::Enum(_) => OpenTag::ClosedAttributes, + ValidatorEntityTypeKind::Standard(ty) => ty.open_attributes, + } + } + + /// Get the type of tags on this entity. `None` indicates that entities of + /// this type are not allowed to have tags. + pub fn tag_type(&self) -> Option<&Type> { + match &self.kind { + ValidatorEntityTypeKind::Enum(_) => None, + ValidatorEntityTypeKind::Standard(ty) => ty.tag_type(), + } } +} +impl StandardValidatorEntityType { /// An iterator over the attributes of this entity pub fn attributes(&self) -> impl Iterator { self.attributes.iter() } - /// Return `true` if this entity type has an [`EntityType`] declared as a - /// possible descendant in the schema. - pub fn has_descendant_entity_type(&self, ety: &EntityType) -> bool { - self.descendants.contains(ety) - } - /// Get the type of tags on this entity. `None` indicates that entities of /// this type are not allowed to have tags. pub fn tag_type(&self) -> Option<&Type> { @@ -105,9 +158,15 @@ impl TCNode for ValidatorEntityType { #[cfg(feature = "protobufs")] impl From<&ValidatorEntityType> for proto::ValidatorEntityType { fn from(v: &ValidatorEntityType) -> Self { - let tags = v.tags.as_ref().map(|tags| proto::Tag { + let tags = v.tag_type().map(|tags| proto::Tag { optional_type: Some(proto::Type::from(tags)), }); + let enums: Vec = match &v.kind { + ValidatorEntityTypeKind::Enum(choices) => { + choices.into_iter().map(|s| s.to_string()).collect() + } + ValidatorEntityTypeKind::Standard(_) => vec![], + }; Self { name: Some(ast::proto::EntityType::from(&v.name)), descendants: v @@ -115,9 +174,16 @@ impl From<&ValidatorEntityType> for proto::ValidatorEntityType { .iter() .map(ast::proto::EntityType::from) .collect(), - attributes: Some(proto::Attributes::from(&v.attributes)), - open_attributes: proto::OpenTag::from(&v.open_attributes).into(), + attributes: Some(proto::Attributes::from( + &Attributes::with_required_attributes( + v.attributes() + .into_iter() + .map(|(attr, ty)| (attr, ty.attr_type)), + ), + )), + open_attributes: proto::OpenTag::from(&v.open_attributes()).into(), tags, + enums, } } } @@ -138,15 +204,17 @@ impl From<&proto::ValidatorEntityType> for ValidatorEntityType { .expect("`as_ref()` for field that should exist"), ), descendants: v.descendants.iter().map(ast::EntityType::from).collect(), - attributes: Attributes::from( - v.attributes - .as_ref() - .expect("`as_ref()` for field that should exist"), - ), - open_attributes: OpenTag::from( - &proto::OpenTag::try_from(v.open_attributes).expect("decode should succeed"), - ), - tags, + kind: ValidatorEntityTypeKind::Standard(StandardValidatorEntityType { + attributes: Attributes::from( + v.attributes + .as_ref() + .expect("`as_ref()` for field that should exist"), + ), + open_attributes: OpenTag::from( + &proto::OpenTag::try_from(v.open_attributes).expect("decode should succeed"), + ), + tags, + }), } } } diff --git a/cedar-policy-validator/src/schema/namespace_def.rs b/cedar-policy-validator/src/schema/namespace_def.rs index d413f6a9a..e1c3334e4 100644 --- a/cedar-policy-validator/src/schema/namespace_def.rs +++ b/cedar-policy-validator/src/schema/namespace_def.rs @@ -36,7 +36,7 @@ use smol_str::{SmolStr, ToSmolStr}; use super::{internal_name_to_entity_type, AllDefs, ValidatorApplySpec}; use crate::{ err::{schema_errors::*, SchemaError}, - json_schema::{self, CommonTypeId}, + json_schema::{self, CommonTypeId, EntityTypeKind}, types::{AttributeType, Attributes, OpenTag, Type}, ActionBehavior, ConditionalName, RawName, ReferenceType, }; @@ -449,29 +449,41 @@ impl EntityTypesDef { /// references in `parents`, `attributes`, and `tags` may or may not be fully /// qualified yet, depending on `N`. #[derive(Debug, Clone)] -pub struct EntityTypeFragment { - /// Description of the attribute types for this entity type. - /// - /// This may contain references to common types which have not yet been - /// resolved/inlined (e.g., because they are not defined in this schema - /// fragment). - /// In the extreme case, this may itself be just a common type pointing to a - /// `Record` type defined in another fragment. - pub(super) attributes: json_schema::AttributesOrContext, - /// Direct parent entity types for this entity type. - /// These entity types may be declared in a different namespace or schema - /// fragment. - /// - /// We will check for undeclared parent types when combining fragments into - /// a [`crate::ValidatorSchema`]. - pub(super) parents: HashSet, - /// Tag type for this entity type. `None` means no tags are allowed on this - /// entity type. - /// - /// This may contain references to common types which have not yet been - /// resolved/inlined (e.g., because they are not defined in this schema - /// fragment). - pub(super) tags: Option>, +pub enum EntityTypeFragment { + Standard { + /// Description of the attribute types for this entity type. + /// + /// This may contain references to common types which have not yet been + /// resolved/inlined (e.g., because they are not defined in this schema + /// fragment). + /// In the extreme case, this may itself be just a common type pointing to a + /// `Record` type defined in another fragment. + attributes: json_schema::AttributesOrContext, + /// Direct parent entity types for this entity type. + /// These entity types may be declared in a different namespace or schema + /// fragment. + /// + /// We will check for undeclared parent types when combining fragments into + /// a [`crate::ValidatorSchema`]. + parents: HashSet, + /// Tag type for this entity type. `None` means no tags are allowed on this + /// entity type. + /// + /// This may contain references to common types which have not yet been + /// resolved/inlined (e.g., because they are not defined in this schema + /// fragment). + tags: Option>, + }, + Enum(NonEmpty), +} + +impl EntityTypeFragment { + pub(crate) fn parents(&self) -> Box + '_> { + match self { + Self::Standard { parents, .. } => Box::new(parents.iter()), + Self::Enum(_) => Box::new(std::iter::empty()), + } + } } impl EntityTypeFragment { @@ -482,21 +494,24 @@ impl EntityTypeFragment { schema_file_type: json_schema::EntityType, schema_namespace: Option<&InternalName>, ) -> Self { - Self { - attributes: schema_file_type - .shape - .conditionally_qualify_type_references(schema_namespace), - parents: schema_file_type - .member_of_types - .into_iter() - .map(|raw_name| { - // Only entity, not common, here for now; see #1064 - raw_name.conditionally_qualify_with(schema_namespace, ReferenceType::Entity) - }) - .collect(), - tags: schema_file_type - .tags - .map(|tags| tags.conditionally_qualify_type_references(schema_namespace)), + match schema_file_type.kind { + EntityTypeKind::Enum { choices } => Self::Enum(choices), + EntityTypeKind::Standard(ty) => Self::Standard { + attributes: ty + .shape + .conditionally_qualify_type_references(schema_namespace), + parents: ty + .member_of_types + .into_iter() + .map(|raw_name| { + // Only entity, not common, here for now; see #1064 + raw_name.conditionally_qualify_with(schema_namespace, ReferenceType::Entity) + }) + .collect(), + tags: ty + .tags + .map(|tags| tags.conditionally_qualify_type_references(schema_namespace)), + }, } } @@ -510,51 +525,63 @@ impl EntityTypeFragment { self, all_defs: &AllDefs, ) -> Result, TypeNotDefinedError> { - // Fully qualify typenames appearing in `attributes` - let fully_qual_attributes = self.attributes.fully_qualify_type_references(all_defs); - // Fully qualify typenames appearing in `parents` - let parents: HashSet = self - .parents - .into_iter() - .map(|parent| parent.resolve(all_defs)) - .collect::>()?; - // Fully qualify typenames appearing in `tags` - let fully_qual_tags = self - .tags - .map(|tags| tags.fully_qualify_type_references(all_defs)) - .transpose(); - // Now is the time to check whether any parents are dangling, i.e., - // refer to entity types that are not declared in any fragment (since we - // now have the set of typenames that are declared in all fragments). - let undeclared_parents: Option> = NonEmpty::collect( - parents - .iter() - .filter(|ety| !all_defs.is_defined_as_entity(ety)) - .map(|ety| ConditionalName::unconditional(ety.clone(), ReferenceType::Entity)), - ); - match (fully_qual_attributes, fully_qual_tags, undeclared_parents) { - (Ok(attributes), Ok(tags), None) => Ok(EntityTypeFragment { + match self { + Self::Enum(choices) => Ok(EntityTypeFragment::Enum(choices)), + Self::Standard { attributes, parents, tags, - }), - (Ok(_), Ok(_), Some(undeclared_parents)) => Err(TypeNotDefinedError { - undefined_types: undeclared_parents, - }), - (Err(e), Ok(_), None) | (Ok(_), Err(e), None) => Err(e), - (Err(e1), Err(e2), None) => Err(TypeNotDefinedError::join_nonempty(nonempty![e1, e2])), - (Err(e), Ok(_), Some(mut undeclared)) | (Ok(_), Err(e), Some(mut undeclared)) => { - undeclared.extend(e.undefined_types); - Err(TypeNotDefinedError { - undefined_types: undeclared, - }) - } - (Err(e1), Err(e2), Some(mut undeclared)) => { - undeclared.extend(e1.undefined_types); - undeclared.extend(e2.undefined_types); - Err(TypeNotDefinedError { - undefined_types: undeclared, - }) + } => { + // Fully qualify typenames appearing in `attributes` + let fully_qual_attributes = attributes.fully_qualify_type_references(all_defs); + // Fully qualify typenames appearing in `parents` + let parents: HashSet = parents + .into_iter() + .map(|parent| parent.resolve(all_defs)) + .collect::>()?; + // Fully qualify typenames appearing in `tags` + let fully_qual_tags = tags + .map(|tags| tags.fully_qualify_type_references(all_defs)) + .transpose(); + // Now is the time to check whether any parents are dangling, i.e., + // refer to entity types that are not declared in any fragment (since we + // now have the set of typenames that are declared in all fragments). + let undeclared_parents: Option> = NonEmpty::collect( + parents + .iter() + .filter(|ety| !all_defs.is_defined_as_entity(ety)) + .map(|ety| { + ConditionalName::unconditional(ety.clone(), ReferenceType::Entity) + }), + ); + match (fully_qual_attributes, fully_qual_tags, undeclared_parents) { + (Ok(attributes), Ok(tags), None) => Ok(EntityTypeFragment::Standard { + attributes, + parents, + tags, + }), + (Ok(_), Ok(_), Some(undeclared_parents)) => Err(TypeNotDefinedError { + undefined_types: undeclared_parents, + }), + (Err(e), Ok(_), None) | (Ok(_), Err(e), None) => Err(e), + (Err(e1), Err(e2), None) => { + Err(TypeNotDefinedError::join_nonempty(nonempty![e1, e2])) + } + (Err(e), Ok(_), Some(mut undeclared)) + | (Ok(_), Err(e), Some(mut undeclared)) => { + undeclared.extend(e.undefined_types); + Err(TypeNotDefinedError { + undefined_types: undeclared, + }) + } + (Err(e1), Err(e2), Some(mut undeclared)) => { + undeclared.extend(e1.undefined_types); + undeclared.extend(e2.undefined_types); + Err(TypeNotDefinedError { + undefined_types: undeclared, + }) + } + } } } } diff --git a/cedar-policy-validator/src/typecheck/test/expr.rs b/cedar-policy-validator/src/typecheck/test/expr.rs index 4ae01e0cf..fa88f457a 100644 --- a/cedar-policy-validator/src/typecheck/test/expr.rs +++ b/cedar-policy-validator/src/typecheck/test/expr.rs @@ -22,7 +22,6 @@ use std::{str::FromStr, vec}; use cedar_policy_core::{ ast::{BinaryOp, EntityUID, Expr, Pattern, PatternElem, SlotId, Var}, - est::Annotations, extensions::Extensions, }; use itertools::Itertools; @@ -67,13 +66,12 @@ fn slot_typechecks() { #[test] fn slot_in_typechecks() { - let etype = json_schema::EntityType { + let etype = json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }; + } + .into(); let schema = json_schema::NamespaceDefinition::new([("typename".parse().unwrap(), etype)], []); assert_typechecks_for_mode( schema.clone(), @@ -99,13 +97,12 @@ fn slot_in_typechecks() { #[test] fn slot_equals_typechecks() { - let etype = json_schema::EntityType { + let etype = json_schema::StandardEntityType { member_of_types: vec![], shape: json_schema::AttributesOrContext::default(), tags: None, - annotations: Annotations::new(), - loc: None, - }; + } + .into(); // These don't typecheck in strict mode because the test_util expression // typechecker doesn't have access to a schema, so it can't link // the template slots with appropriate types. Similar policies that pass diff --git a/cedar-policy-validator/src/types.rs b/cedar-policy-validator/src/types.rs index 60fc3a065..095f0c986 100644 --- a/cedar-policy-validator/src/types.rs +++ b/cedar-policy-validator/src/types.rs @@ -949,7 +949,7 @@ impl EntityLUB { let mut lub_element_attributes = self.lub_elements.iter().map(|name| { schema .get_entity_type(name) - .map(|entity_type| entity_type.attributes.clone()) + .map(|entity_type| entity_type.attributes()) .unwrap_or_else(|| Attributes::with_attributes(None)) }); @@ -1295,7 +1295,7 @@ impl EntityRecordKind { EntityRecordKind::Entity(lub) => lub.iter().any(|e_name| { schema .get_entity_type(e_name) - .map(|e_type| e_type.open_attributes) + .map(|e_type| e_type.open_attributes()) // The entity type was not found in the schema, so we know // nothing about it and must assume that it may have // additional attributes. diff --git a/cedar-policy/CHANGELOG.md b/cedar-policy/CHANGELOG.md index 3744ebf74..04e93de23 100644 --- a/cedar-policy/CHANGELOG.md +++ b/cedar-policy/CHANGELOG.md @@ -16,6 +16,7 @@ Cedar Language Version: TBD ### Added - Added ability to remove `Entity`s from an `Entities` struct (resolving #701) +- Implemented [RFC 53 (enumerated entity types)](https://github.com/cedar-policy/rfcs/blob/main/text/0053-enum-entities.md) (#1377) ### Fixed diff --git a/cedar-policy/src/api/err.rs b/cedar-policy/src/api/err.rs index b619680bf..417d43c39 100644 --- a/cedar-policy/src/api/err.rs +++ b/cedar-policy/src/api/err.rs @@ -436,6 +436,10 @@ pub enum ValidationError { #[error(transparent)] #[diagnostic(transparent)] EntityDerefLevelViolation(#[from] validation_errors::EntityDerefLevelViolation), + /// Returned when an entity is of an enumerated entity type but has invalid EID + #[error(transparent)] + #[diagnostic(transparent)] + InvalidEnumEntity(#[from] validation_errors::InvalidEnumEntity), } impl ValidationError { @@ -459,6 +463,7 @@ impl ValidationError { Self::HierarchyNotRespected(e) => e.policy_id(), Self::InternalInvariantViolation(e) => e.policy_id(), Self::EntityDerefLevelViolation(e) => e.policy_id(), + Self::InvalidEnumEntity(e) => e.policy_id(), } } } @@ -515,6 +520,9 @@ impl From for ValidationError { cedar_policy_validator::ValidationError::InternalInvariantViolation(e) => { Self::InternalInvariantViolation(e.into()) } + cedar_policy_validator::ValidationError::InvalidEnumEntity(e) => { + Self::InvalidEnumEntity(e.into()) + } #[cfg(feature = "level-validate")] cedar_policy_validator::ValidationError::EntityDerefLevelViolation(e) => { Self::EntityDerefLevelViolation(e.into()) @@ -1077,6 +1085,11 @@ pub enum RequestValidationError { #[error(transparent)] #[diagnostic(transparent)] TypeOfContext(#[from] request_validation_errors::TypeOfContextError), + /// Error when a principal or resource entity is of an enumerated entity + /// type but has an invalid EID + #[error(transparent)] + #[diagnostic(transparent)] + InvalidEnumEntity(#[from] request_validation_errors::InvalidEnumEntityError), } #[doc(hidden)] @@ -1104,6 +1117,9 @@ impl From for RequestValidationE cedar_policy_validator::RequestValidationError::TypeOfContext(e) => { Self::TypeOfContext(e.into()) } + cedar_policy_validator::RequestValidationError::InvalidEnumEntity(e) => { + Self::InvalidEnumEntity(e.into()) + } } } } @@ -1229,6 +1245,15 @@ pub mod request_validation_errors { #[error(transparent)] #[diagnostic(transparent)] pub struct TypeOfContextError(#[from] ExtensionFunctionLookupError); + + /// Error when a principal or resource entity is of an enumerated entity + /// type but has an invalid EID + #[derive(Debug, Diagnostic, Error)] + #[error(transparent)] + #[diagnostic(transparent)] + pub struct InvalidEnumEntityError( + #[from] cedar_policy_core::entities::conformance::err::InvalidEnumEntityError, + ); } /// An error generated by entity slicing. diff --git a/cedar-policy/src/api/err/validation_errors.rs b/cedar-policy/src/api/err/validation_errors.rs index ca14df220..0faec02af 100644 --- a/cedar-policy/src/api/err/validation_errors.rs +++ b/cedar-policy/src/api/err/validation_errors.rs @@ -73,3 +73,4 @@ wrap_core_error!(EntityDerefLevelViolation); wrap_core_error!(EmptySetForbidden); wrap_core_error!(NonLitExtConstructor); wrap_core_error!(InternalInvariantViolation); +wrap_core_error!(InvalidEnumEntity); diff --git a/cedar-policy/src/tests.rs b/cedar-policy/src/tests.rs index c6c958586..6e1c6e325 100644 --- a/cedar-policy/src/tests.rs +++ b/cedar-policy/src/tests.rs @@ -3825,6 +3825,255 @@ mod schema_based_parsing_tests { Err(EntitiesError::TransitiveClosureError(_)) )); } + + #[test] + fn enumerated_entity_types() { + let schema = Schema::from_str( + r#" + entity Fruit enum ["🍉", "🍓", "🍒"]; + entity People { + fruit?: Fruit, + fruit_rec?: {name: Fruit}, + }; + entity DeliciousFruit in Fruit tags Fruit; + action "eat" appliesTo { + principal: [People], + resource: [Fruit], + }; + "#, + ) + .expect("should be a valid schema"); + // invalid eid + let json = serde_json::json!([ + { + "uid" : { + "type" : "Fruit", + "id" : "🥝" + }, + "attrs" : {}, + "parents": [] + }, + { + "uid" : { + "type" : "People", + "id" : "😋" + }, + "attrs" : {}, + "parents": [] + } + ]); + assert_matches!(Entities::from_json_value(json.clone(), Some(&schema)), Err(EntitiesError::InvalidEntity(err)) => { + expect_err( + &json, + &Report::new(err), + &ExpectedErrorMessageBuilder::error( + r#"entity `Fruit::"🥝"` is of an enumerated entity type, but `"🥝"` is not declared as a valid eid"#, + ) + .help(r#"valid entity eids: "🍉", "🍓", "🍒""#) + .build(), + ); + }); + // no attributes are allowed + let json = serde_json::json!([ + { + "uid" : { + "type" : "Fruit", + "id" : "🍉" + }, + "attrs" : { + "sweetness": "high", + }, + "parents": [] + }, + { + "uid" : { + "type" : "People", + "id" : "😋" + }, + "attrs" : {}, + "parents": [] + } + ]); + assert_matches!(Entities::from_json_value(json.clone(), Some(&schema)), Err(EntitiesError::Deserialization(err)) => { + expect_err( + &json, + &Report::new(err), + &ExpectedErrorMessageBuilder::error( + r#"attribute `sweetness` on `Fruit::"🍉"` should not exist according to the schema"#, + ) + .build(), + ); + }); + // no parents are allowed + let json = serde_json::json!([ + { + "uid" : { + "type" : "Fruit", + "id" : "🍉" + }, + "attrs" : { + }, + "parents": [{"type": "Fruit", "id": "🍓"}] + }, + { + "uid" : { + "type" : "People", + "id" : "😋" + }, + "attrs" : {}, + "parents": [] + } + ]); + assert_matches!(Entities::from_json_value(json.clone(), Some(&schema)), Err(EntitiesError::InvalidEntity(err)) => { + expect_err( + &json, + &Report::new(err), + &ExpectedErrorMessageBuilder::error( + r#"`Fruit::"🍉"` is not allowed to have an ancestor of type `Fruit` according to the schema"#, + ) + .build(), + ); + }); + + // Reference to invalid eid in the `parents` field + let json = serde_json::json!([ + { + "uid" : { + "type" : "DeliciousFruit", + "id" : "🍉" + }, + "attrs" : { + }, + "parents": [{"type": "Fruit", "id": "🥝"}] + }, + { + "uid" : { + "type" : "People", + "id" : "😋" + }, + "attrs" : {}, + "parents": [] + } + ]); + assert_matches!( + Entities::from_json_value(json.clone(), Some(&schema)), + Err(EntitiesError::InvalidEntity(err)) => { + expect_err( + &json, + &Report::new(err), + &ExpectedErrorMessageBuilder::error( + r#"entity `Fruit::"🥝"` is of an enumerated entity type, but `"🥝"` is not declared as a valid eid"#, + ).help(r#"valid entity eids: "🍉", "🍓", "🍒""#) + .build(), + );} + ); + + // Reference to invalid eid in the `attrs` field + let json = serde_json::json!([ + { + "uid" : { + "type" : "DeliciousFruit", + "id" : "🍍" + }, + "attrs" : { + }, + "parents": [{"type": "Fruit", "id": "🍉"}] + }, + { + "uid" : { + "type" : "People", + "id" : "😋" + }, + "attrs" : { + "fruit": {"type": "Fruit", "id": "🍍"}, + }, + "parents": [] + } + ]); + assert_matches!( + Entities::from_json_value(json.clone(), Some(&schema)), + Err(EntitiesError::InvalidEntity(err)) => { + expect_err( + &json, + &Report::new(err), + &ExpectedErrorMessageBuilder::error( + r#"entity `Fruit::"🍍"` is of an enumerated entity type, but `"🍍"` is not declared as a valid eid"#, + ).help(r#"valid entity eids: "🍉", "🍓", "🍒""#) + .build(), + );} + ); + // Reference to invalid eid in the `attrs` field + let json = serde_json::json!([ + { + "uid" : { + "type" : "DeliciousFruit", + "id" : "🍍" + }, + "attrs" : { + }, + "parents": [{"type": "Fruit", "id": "🍉"}] + }, + { + "uid" : { + "type" : "People", + "id" : "😋" + }, + "attrs" : { + "fruit_rec": {"name": {"type": "Fruit", "id": "🥭"}}, + }, + "parents": [] + } + ]); + assert_matches!( + Entities::from_json_value(json.clone(), Some(&schema)), + Err(EntitiesError::InvalidEntity(err)) => { + expect_err( + &json, + &Report::new(err), + &ExpectedErrorMessageBuilder::error( + r#"entity `Fruit::"🥭"` is of an enumerated entity type, but `"🥭"` is not declared as a valid eid"#, + ).help(r#"valid entity eids: "🍉", "🍓", "🍒""#) + .build(), + );} + ); + // Reference to invalid eid in the `tags` field + let json = serde_json::json!([ + { + "uid" : { + "type" : "DeliciousFruit", + "id" : "🍍" + }, + "attrs" : { + }, + "parents": [{"type": "Fruit", "id": "🍉"}], + "tags": { + "mango": {"type": "Fruit", "id": "🥭"}, + } + }, + { + "uid" : { + "type" : "People", + "id" : "😋" + }, + "attrs" : { + "fruit_rec": {"name": {"type": "Fruit", "id": "🍉"}}, + }, + "parents": [] + } + ]); + assert_matches!( + Entities::from_json_value(json.clone(), Some(&schema)), + Err(EntitiesError::InvalidEntity(err)) => { + expect_err( + &json, + &Report::new(err), + &ExpectedErrorMessageBuilder::error( + r#"entity `Fruit::"🥭"` is of an enumerated entity type, but `"🥭"` is not declared as a valid eid"#, + ).help(r#"valid entity eids: "🍉", "🍓", "🍒""#) + .build(), + );} + ); + } } #[cfg(not(feature = "partial-validate"))] diff --git a/cedar-wasm/build-wasm.sh b/cedar-wasm/build-wasm.sh index dde967091..022e40f24 100755 --- a/cedar-wasm/build-wasm.sh +++ b/cedar-wasm/build-wasm.sh @@ -97,6 +97,8 @@ process_types_file() { echo "type SmolStr = string;" >> "$types_file" echo "export type TypeOfAttribute = Type & { required?: boolean };" >> "$types_file" echo "export type CommonType = Type & { annotations?: Annotations };" >> "$types_file" + echo "export type EntityType = EntityTypeKind & { annotations?: Annotations; };" >> "$types_file" + echo "export type NonEmpty = Array;" >> "$types_file" } check_types_file() {