diff --git a/lit_nlp/api/layout.py b/lit_nlp/api/layout.py index a8fc3984..8def69f0 100644 --- a/lit_nlp/api/layout.py +++ b/lit_nlp/api/layout.py @@ -209,7 +209,8 @@ def to_json(self) -> dtypes.JsonDict: description=( 'The default LIT layout, which includes the data table and data point ' 'editor, the performance and metrics, predictions, explanations, and ' - 'counterfactuals.'), + 'counterfactuals.' + ), ) DEFAULT_LAYOUTS = { diff --git a/lit_nlp/client/lib/utils.ts b/lit_nlp/client/lib/utils.ts index ef3c3aeb..99ca1cef 100644 --- a/lit_nlp/client/lib/utils.ts +++ b/lit_nlp/client/lib/utils.ts @@ -382,9 +382,10 @@ export function isBinaryClassification(litType: LitType) { return false; } -/** Returns if a LitType has a parent field. */ -export function hasParent(litType: LitType) { - return (litType as LitTypeWithParent).parent != null; +/** Returns if a LitType has a parent field which is found in the (data) spec */ +export function hasValidParent(litType: LitType, spec: Spec) { + const parent = (litType as LitTypeWithParent).parent; + return parent != null && spec[parent] != null; } /** diff --git a/lit_nlp/client/modules/curves_module.ts b/lit_nlp/client/modules/curves_module.ts index 8f6afedc..57ece176 100644 --- a/lit_nlp/client/modules/curves_module.ts +++ b/lit_nlp/client/modules/curves_module.ts @@ -18,19 +18,22 @@ // tslint:disable:no-new-decorators import '../elements/expansion_panel'; import '../elements/line_chart'; -import {customElement} from 'lit/decorators.js'; + import {html, TemplateResult} from 'lit'; +import {customElement} from 'lit/decorators.js'; import {action, computed, observable} from 'mobx'; -import {FacetsChange} from '../core/faceting_control'; + import {app} from '../core/app'; +import {FacetsChange} from '../core/faceting_control'; import {LitModule} from '../core/lit_module'; import {MulticlassPreds} from '../lib/lit_types'; +import {styles as sharedStyles} from '../lib/shared_styles.css'; import {GroupedExamples, IndexedInput, ModelInfoMap, SCROLL_SYNC_CSS_CLASS, Spec} from '../lib/types'; -import {doesOutputSpecContain, findSpecKeys, hasParent} from '../lib/utils'; -import {GroupService} from '../services/services'; +import {findSpecKeys, hasValidParent} from '../lib/utils'; import {NumericFeatureBins} from '../services/group_service'; +import {GroupService} from '../services/services'; + import {styles} from './curves_module.css'; -import {styles as sharedStyles} from '../lib/shared_styles.css'; // Response from backend curves interpreter. interface CurvesResponse { @@ -367,7 +370,16 @@ export class CurvesModule extends LitModule { static override shouldDisplayModule( modelSpecs: ModelInfoMap, datasetSpec: Spec) { - return doesOutputSpecContain(modelSpecs, MulticlassPreds, hasParent); + // We need a MulticlassPreds field, where parent is in the dataset spec. + for (const modelInfo of Object.values(modelSpecs)) { + const outputSpec = modelInfo.spec.output; + for (const outputFieldName of findSpecKeys(outputSpec, MulticlassPreds)) { + if (hasValidParent(outputSpec[outputFieldName], datasetSpec)) { + return true; + } + } + } + return false; } } diff --git a/lit_nlp/components/metrics.py b/lit_nlp/components/metrics.py index 5c04944c..d4d49aac 100644 --- a/lit_nlp/components/metrics.py +++ b/lit_nlp/components/metrics.py @@ -56,7 +56,15 @@ def map_pred_keys( logging.info("Skipping '%s': No parent provided.", pred_key) continue - parent_spec: Optional[LitType] = data_spec.get(parent_key) + if parent_key not in data_spec: + logging.info( + "Skipping '%s': parent field '%s' not found in dataset.", + pred_key, + parent_key, + ) + continue + + parent_spec: LitType = data_spec[parent_key] if predicate(pred_spec, parent_spec): ret[pred_key] = parent_key else: