From ca029ed1693b44c8653b5e3b47013df9cb1f75e2 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Mon, 9 Sep 2024 12:58:58 -0400 Subject: [PATCH 1/3] Improve safety to handle cases when data is missing, still bug --- backend/api/models/data.py | 8 +-- backend/api/models/safety.py | 7 ++- frontend/src/app/context/model.tsx | 21 ++------ frontend/src/app/model/[id]/tabs/safety.tsx | 57 +++++++++++++++------ frontend/src/app/types/model.ts | 22 ++++++++ frontend/src/app/types/safety.ts | 2 +- 6 files changed, 79 insertions(+), 38 deletions(-) create mode 100644 frontend/src/app/types/model.ts diff --git a/backend/api/models/data.py b/backend/api/models/data.py index bdaf697..d65eb8e 100644 --- a/backend/api/models/data.py +++ b/backend/api/models/data.py @@ -486,8 +486,8 @@ class ModelSafety(BaseModel): ---------- metrics : List[Metric] A list of individual metrics and their current status. - last_evaluated : datetime - A timestamp of when the model was last evaluated. + last_evaluated : Optional[str] + A timestamp of when the model was evaluated in ISO 8601 format. is_recently_evaluated : bool Whether the model was recently evaluated. overall_status : str @@ -495,6 +495,8 @@ class ModelSafety(BaseModel): """ metrics: List[Metric] - last_evaluated: str = Field(..., description="ISO 8601 formatted date string") + last_evaluated: Optional[str] = Field( + ..., description="ISO 8601 formatted date string" + ) is_recently_evaluated: bool overall_status: str diff --git a/backend/api/models/safety.py b/backend/api/models/safety.py index 60b3bdb..42a994b 100644 --- a/backend/api/models/safety.py +++ b/backend/api/models/safety.py @@ -94,7 +94,12 @@ async def get_model_safety(model_id: str) -> ModelSafety: # noqa: PLR0912 "collection", [] ) if not collection: - raise ValueError("No metrics found in performance data") + return ModelSafety( + metrics=[], + last_evaluated=None, + overall_status="Not evaluated", + is_recently_evaluated=False, + ) current_date: datetime = datetime.now() last_evaluated = datetime.fromisoformat(collection[0]["timestamps"][-1]) diff --git a/frontend/src/app/context/model.tsx b/frontend/src/app/context/model.tsx index 1524163..5cb10fb 100644 --- a/frontend/src/app/context/model.tsx +++ b/frontend/src/app/context/model.tsx @@ -3,24 +3,10 @@ import React, { createContext, useState, useContext, ReactNode, useCallback, useMemo, useEffect } from 'react'; import { ModelFacts } from '../types/facts'; import { Criterion, EvaluationFrequency } from '../types/evaluation-criteria'; +import { ModelData } from '../types/model'; import { useAuth } from './auth'; import { debounce, DebouncedFunc } from 'lodash'; -interface ModelBasicInfo { - name: string; - version: string; -} - -interface ModelData { - id: string; - endpoints: string[]; - basic_info: ModelBasicInfo; - facts: ModelFacts | null; - evaluation_criteria: Criterion[]; - evaluation_frequency: EvaluationFrequency | null; - overall_status: string; -} - interface ModelContextType { models: ModelData[]; fetchModels: () => Promise; @@ -77,7 +63,7 @@ export const ModelProvider: React.FC<{ children: ReactNode }> = ({ children }) = id, ...modelInfo, overall_status: safetyData.overall_status - }; + } as ModelData; })); setModels(modelArray); } catch (error) { @@ -101,13 +87,12 @@ export const ModelProvider: React.FC<{ children: ReactNode }> = ({ children }) = } const [modelData, safetyData, factsData] = await Promise.all([ - apiRequest(`/api/models/${id}`), + apiRequest(`/api/models/${id}`), apiRequest<{ overall_status: string }>(`/api/model/${id}/safety`), apiRequest(`/api/models/${id}/facts`) ]); const newModel: ModelData = { - id, ...modelData, overall_status: safetyData.overall_status, facts: factsData, diff --git a/frontend/src/app/model/[id]/tabs/safety.tsx b/frontend/src/app/model/[id]/tabs/safety.tsx index d51cdee..2cb0cf6 100644 --- a/frontend/src/app/model/[id]/tabs/safety.tsx +++ b/frontend/src/app/model/[id]/tabs/safety.tsx @@ -18,6 +18,8 @@ import { HStack, Skeleton, SkeletonText, + Alert, + AlertIcon, } from '@chakra-ui/react'; import { CheckCircleIcon, WarningIcon, InfoIcon } from '@chakra-ui/icons'; import { formatDistanceToNow, parseISO } from 'date-fns'; @@ -59,11 +61,28 @@ const ModelSafetyTab: React.FC = ({ modelId }) => { fetchModelSafety(); }, [fetchModelSafety]); + if (error) { + return ( + + + {error} + + ); + } + + const isNotEvaluated = safetyData?.overall_status === 'Not evaluated'; + return ( Model Safety Dashboard + {isNotEvaluated && !isLoading && ( + + + This model has not been evaluated yet. + + )} = ({ overallStatus, cardBgColor, borderColor, textColor, isLoading }) => { const tooltipLabel = overallStatus === 'No warnings' ? "All safety criteria have been met and the model has been recently evaluated." + : overallStatus === 'Not evaluated' + ? "The model has not been evaluated yet." : "One or more safety criteria have not been met or the model needs re-evaluation. Check the Safety Evaluation Checklist for details."; - const statusColor = overallStatus === 'No warnings' ? 'green' : 'red'; - const StatusIcon = overallStatus === 'No warnings' ? CheckCircleIcon : WarningIcon; + const statusColor = overallStatus === 'No warnings' ? 'green' : overallStatus === 'Not evaluated' ? 'gray' : 'red'; + const StatusIcon = overallStatus === 'No warnings' ? CheckCircleIcon : overallStatus === 'Not evaluated' ? InfoIcon : WarningIcon; return ( @@ -138,6 +159,8 @@ interface LastEvaluatedCardProps extends CardProps { const LastEvaluatedCard: React.FC = ({ lastEvaluated, isRecentlyEvaluated, cardBgColor, borderColor, textColor, isLoading }) => { const tooltipLabel = isRecentlyEvaluated ? "The model has been evaluated within the specified evaluation frequency threshold." + : lastEvaluated === null + ? "The model has not been evaluated yet." : "The model has not been evaluated recently and may need re-evaluation."; return ( @@ -146,16 +169,16 @@ const LastEvaluatedCard: React.FC = ({ lastEvaluated, is Time since last evaluation - {lastEvaluated ? formatDistanceToNow(lastEvaluated) + ' ago' : 'N/A'} + {lastEvaluated ? formatDistanceToNow(lastEvaluated) + ' ago' : 'Not evaluated'} {isRecentlyEvaluated !== undefined && ( - - {isRecentlyEvaluated ? 'Recent' : 'Needs Re-evaluation'} + + {lastEvaluated === null ? 'Not Evaluated' : isRecentlyEvaluated ? 'Recent' : 'Needs Re-evaluation'} - {isRecentlyEvaluated ? : } + {lastEvaluated === null ? : isRecentlyEvaluated ? : } )} @@ -172,15 +195,19 @@ interface SafetyMetricsCardProps extends CardProps { const SafetyMetricsCard: React.FC = ({ metrics, cardBgColor, borderColor, textColor, isLoading }) => ( Evaluation Checklist - - {isLoading ? ( - Array.from({ length: 3 }).map((_, index) => ( + {isLoading ? ( + + {Array.from({ length: 3 }).map((_, index) => ( - )) - ) : ( - metrics.map((metric, index) => ( + ))} + + ) : metrics.length === 0 ? ( + No metrics available. The model has not been evaluated yet. + ) : ( + + {metrics.map((metric, index) => ( = ({ metrics, cardBgCo - )) - )} - + ))} + + )} ); diff --git a/frontend/src/app/types/model.ts b/frontend/src/app/types/model.ts new file mode 100644 index 0000000..6c6cd80 --- /dev/null +++ b/frontend/src/app/types/model.ts @@ -0,0 +1,22 @@ +import { z } from 'zod'; + +import { ModelFacts } from './facts'; +import { CriterionSchema, EvaluationFrequencySchema } from './evaluation-criteria'; + +export const ModelBasicInfoSchema = z.object({ + name: z.string(), + version: z.string(), +}); + +export const ModelDataSchema = z.object({ + id: z.string(), + endpoints: z.array(z.string()), + basic_info: ModelBasicInfoSchema, + facts: z.custom().nullable(), + evaluation_criteria: z.array(CriterionSchema), + evaluation_frequency: EvaluationFrequencySchema.nullable(), + overall_status: z.string(), +}); + +export type ModelBasicInfo = z.infer; +export type ModelData = z.infer; diff --git a/frontend/src/app/types/safety.ts b/frontend/src/app/types/safety.ts index dc070bf..f5e80ff 100644 --- a/frontend/src/app/types/safety.ts +++ b/frontend/src/app/types/safety.ts @@ -3,7 +3,7 @@ import { MetricSchema } from './performance-metrics'; export const ModelSafetySchema = z.object({ metrics: z.array(MetricSchema), - last_evaluated: z.string(), + last_evaluated: z.string().nullable(), is_recently_evaluated: z.boolean(), overall_status: z.string() }); From 1ef76c6184f34ede3f621b7984cdff92efbae39b Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Wed, 2 Oct 2024 11:28:48 -0400 Subject: [PATCH 2/3] Add default data to fix the context issue --- backend/api/models/data.py | 58 +++++++++++++++++++++++++++++++- backend/api/models/evaluate.py | 8 +++++ backend/api/models/safety.py | 1 - backend/tests/test_app.py | 2 +- frontend/src/app/types/safety.ts | 2 +- 5 files changed, 67 insertions(+), 4 deletions(-) diff --git a/backend/api/models/data.py b/backend/api/models/data.py index d65eb8e..c589f68 100644 --- a/backend/api/models/data.py +++ b/backend/api/models/data.py @@ -3,11 +3,12 @@ import uuid from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from pydantic import BaseModel, Field, validator from api.models.config import EndpointConfig +from api.models.constants import METRIC_DISPLAY_NAMES from api.models.utils import deep_convert_numpy @@ -500,3 +501,58 @@ class ModelSafety(BaseModel): ) is_recently_evaluated: bool overall_status: str + + +def _default_data( + metric_name: str, +) -> Tuple[ModelFacts, EvaluationCriterion, EvaluationFrequency]: + """Create default data for the model. + + Parameters + ---------- + metric_name : str + The name of the metric to create default data for. + + Returns + ------- + ModelFacts, EvaluationCriterion, EvaluationFrequency + The default data for the model. + + """ + model_facts = ModelFacts( + name="This is the name of the model", + version="This is the version of the model", + type="This is the type of the model", + intended_use="This is the intended use of the model", + target_population="This is the target population of the model", + input_data=["This is the input data of the model"], + output_data="This is the output data of the model", + summary="This is the summary of the model", + mechanism_of_action="This is the mechanism of action of the model", + validation_and_performance=ValidationAndPerformance( + internal_validation="This is the internal validation of the model", + external_validation="This is the external validation of the model", + performance_in_subgroups=[ + "This is the performance in subgroups of the model" + ], + ), + uses_and_directions=["This is the uses and directions of the model"], + warnings=["This is the warnings of the model"], + other_information=OtherInformation( + approval_date="This is the approval date of the model", + license="This is the license of the model", + contact_information="This is the contact information of the model", + publication_link="This is the publication link of the model", + ), + ) + evaluation_criterion = EvaluationCriterion( + metric_name=metric_name, + display_name=METRIC_DISPLAY_NAMES[metric_name], + operator=ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + threshold=0.5, + ) + evaluation_frequency = EvaluationFrequency( + value=7, + unit="days", + ) + return model_facts, evaluation_criterion, evaluation_frequency diff --git a/backend/api/models/evaluate.py b/backend/api/models/evaluate.py index 88eebd3..a2e9647 100644 --- a/backend/api/models/evaluate.py +++ b/backend/api/models/evaluate.py @@ -21,6 +21,7 @@ EvaluationResult, ModelBasicInfo, ModelData, + _default_data, ) from api.models.db import DATA_DIR, load_model_data, save_model_data from api.models.utils import deep_convert_numpy @@ -194,11 +195,18 @@ def add_model(self, model_info: ModelBasicInfo) -> str: The unique ID of the newly added model. """ model_id = str(uuid.uuid4()) + metric_name = f"{self.config.metrics[0].type}_{self.config.metrics[0].name}" + model_facts, evaluation_criterion, evaluation_frequency = _default_data( + metric_name + ) model_data = ModelData( id=model_id, endpoint_name=self.name, basic_info=model_info, endpoints=[self.name], + facts=model_facts, + evaluation_criterion=evaluation_criterion, + evaluation_frequency=evaluation_frequency, ) save_model_data(model_id, model_data) self.data.models.append(model_id) diff --git a/backend/api/models/safety.py b/backend/api/models/safety.py index 42a994b..c910b1b 100644 --- a/backend/api/models/safety.py +++ b/backend/api/models/safety.py @@ -136,7 +136,6 @@ async def get_model_safety(model_id: str) -> ModelSafety: # noqa: PLR0912 passed=status == "met", ) ) - all_criteria_met = all(metric.status == "met" for metric in metrics) evaluation_frequency = model_data.evaluation_frequency or EvaluationFrequency( value=30, unit="days" diff --git a/backend/tests/test_app.py b/backend/tests/test_app.py index 50b334f..fdf6b94 100644 --- a/backend/tests/test_app.py +++ b/backend/tests/test_app.py @@ -10,7 +10,7 @@ from api.models.data import ModelFacts -BASE_URL = "http://localhost:8000" # Adjust this to your API's base URL +BASE_URL = "http://localhost:8001" # Adjust this to your API's base URL def api_request(method: str, endpoint: str, data: Dict = None) -> Dict: diff --git a/frontend/src/app/types/safety.ts b/frontend/src/app/types/safety.ts index f5e80ff..b1a0af2 100644 --- a/frontend/src/app/types/safety.ts +++ b/frontend/src/app/types/safety.ts @@ -2,7 +2,7 @@ import { z } from 'zod'; import { MetricSchema } from './performance-metrics'; export const ModelSafetySchema = z.object({ - metrics: z.array(MetricSchema), + metrics: z.array(MetricSchema).optional().default([]), last_evaluated: z.string().nullable(), is_recently_evaluated: z.boolean(), overall_status: z.string() From 3812dda4138824f4a3a0d990c6330ad5307dd13a Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Wed, 2 Oct 2024 11:30:43 -0400 Subject: [PATCH 3/3] Small fix to docstring --- backend/api/models/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/api/models/data.py b/backend/api/models/data.py index c589f68..864c02f 100644 --- a/backend/api/models/data.py +++ b/backend/api/models/data.py @@ -515,7 +515,7 @@ def _default_data( Returns ------- - ModelFacts, EvaluationCriterion, EvaluationFrequency + Tuple[ModelFacts, EvaluationCriterion, EvaluationFrequency] The default data for the model. """