diff --git a/contentctl/objects/abstract_security_content_objects/detection_abstract.py b/contentctl/objects/abstract_security_content_objects/detection_abstract.py index 1231548c..fc0505ce 100644 --- a/contentctl/objects/abstract_security_content_objects/detection_abstract.py +++ b/contentctl/objects/abstract_security_content_objects/detection_abstract.py @@ -383,21 +383,17 @@ def providing_technologies(self) -> List[ProvidingTechnology]: @computed_field @property def risk(self) -> list[dict[str, Any]]: - risk_objects: list[dict[str, str | int]] = [] - - for entity in self.rba.risk_objects: - risk_object: dict[str, str | int] = dict() - risk_object["risk_object_type"] = entity.type - risk_object["risk_object_field"] = entity.field - risk_object["risk_score"] = entity.score - risk_objects.append(risk_object) - - for entity in self.rba.threat_objects: - threat_object: dict[str, str] = dict() - threat_object["threat_object_field"] = entity.field - threat_object["threat_object_type"] = entity.type - risk_objects.append(threat_object) - return risk_objects + if self.rba is None: + raise Exception( + f"Attempting to serialize rba section of [{self.name}], however RBA section is None" + ) + """ + action.risk.param._risk + of the conf file only contains a list of dicts. We do not eant to + include the message here, so we do not return it. + """ + rba_dict = self.rba.model_dump() + return rba_dict["risk_objects"] + rba_dict["threat_objects"] @computed_field @property diff --git a/contentctl/objects/rba.py b/contentctl/objects/rba.py index d33da47c..a63c043e 100644 --- a/contentctl/objects/rba.py +++ b/contentctl/objects/rba.py @@ -1,9 +1,12 @@ -from enum import Enum -from pydantic import BaseModel, computed_field, Field +from __future__ import annotations + from abc import ABC -from typing import Set, Annotated -from contentctl.objects.enums import RiskSeverity +from enum import Enum +from typing import Annotated, Set +from pydantic import BaseModel, Field, computed_field, model_serializer + +from contentctl.objects.enums import RiskSeverity RiskScoreValue_Type = Annotated[int, Field(ge=1, le=100)] @@ -51,6 +54,28 @@ class RiskObject(BaseModel): def __hash__(self): return hash((self.field, self.type, self.score)) + def __lt__(self, other: RiskObject) -> bool: + if ( + f"{self.field}{self.type}{self.score}" + < f"{other.field}{other.type}{other.score}" + ): + return True + return False + + @model_serializer + def serialize_risk_object(self) -> dict[str, str | int]: + """ + We define this explicitly for two reasons, even though the automatic + serialization works correctly. First we want to enforce a specific + field order for reasons of readability. Second, some of the fields + actually have different names than they do in the object. + """ + return { + "risk_object_field": self.field, + "risk_object_type": self.type, + "risk_score": self.score, + } + class ThreatObject(BaseModel): field: str @@ -59,6 +84,24 @@ class ThreatObject(BaseModel): def __hash__(self): return hash((self.field, self.type)) + def __lt__(self, other: ThreatObject) -> bool: + if f"{self.field}{self.type}" < f"{other.field}{other.type}": + return True + return False + + @model_serializer + def serialize_threat_object(self) -> dict[str, str]: + """ + We define this explicitly for two reasons, even though the automatic + serialization works correctly. First we want to enforce a specific + field order for reasons of readability. Second, some of the fields + actually have different names than they do in the object. + """ + return { + "threat_object_field": self.field, + "threat_object_type": self.type, + } + class RBAObject(BaseModel, ABC): message: str @@ -94,3 +137,11 @@ def severity(self) -> RiskSeverity: raise Exception( f"Error getting severity - risk_score must be between 0-100, but was actually {self.risk_score}" ) + + @model_serializer + def serialize_rba(self) -> dict[str, str | list[dict[str, str | int]]]: + return { + "message": self.message, + "risk_objects": [obj.model_dump() for obj in sorted(self.risk_objects)], + "threat_objects": [obj.model_dump() for obj in sorted(self.threat_objects)], + }