From 8084a47c9b2f064994072d3953f395a78179ea85 Mon Sep 17 00:00:00 2001 From: Nicholas Pun <182540099+nicholaspun-wandb@users.noreply.github.com> Date: Sun, 16 Feb 2025 17:10:51 -0800 Subject: [PATCH] fix(app): obtain mask-specific class labels from the run config --- weave-js/src/components/Panel2/PanelImage.tsx | 36 +++++++++- .../src/components/Panel2/controlsImage.ts | 72 ++++++++++++++----- 2 files changed, 88 insertions(+), 20 deletions(-) diff --git a/weave-js/src/components/Panel2/PanelImage.tsx b/weave-js/src/components/Panel2/PanelImage.tsx index 361632869931..1e1c2d4d6f4a 100644 --- a/weave-js/src/components/Panel2/PanelImage.tsx +++ b/weave-js/src/components/Panel2/PanelImage.tsx @@ -1,7 +1,10 @@ import {BoundingBoxSliderControl} from '@wandb/weave/common/components/MediaCard'; import {BoundingBox2D, LayoutType} from '@wandb/weave/common/types/media'; import { + Node, opAssetArtifactVersion, + opGetRunTag, + opRunConfig, replaceInputVariables, WBImage, } from '@wandb/weave/core'; @@ -32,14 +35,37 @@ type PanelImageProps = Panel2.PanelProps< PanelImageConfigType >; +const useClassLabels = (input: Node) => { + const {loading: runConfigLoading} = CGReact.useNodeValue( + opRunConfig({run: opGetRunTag({obj: input})}) + ); + + const {loading: runTagLoading, result: runTag} = CGReact.useNodeValue( + opGetRunTag({obj: input}) + ); + + return useMemo(() => { + if (!(runConfigLoading || runTagLoading) && runTag != null) { + const configSubset: {[key: string]: any} = JSON.parse( + runTag.configSubset + ); + return _.get(configSubset, '_wandb.value.mask/class_labels', {}); + } + return {}; + }, [runTag, runTagLoading, runConfigLoading]); +}; + const PanelImageConfig: FC = ({ config, updateConfig, input, }) => { + const classLabels = useClassLabels(input); + const {classSets, controls} = Controls.useImageControls( input.type, - config?.overlayControls + config?.overlayControls, + classLabels ); const updatedConfig = useMemo(() => { if (controls === config?.overlayControls) { @@ -108,11 +134,17 @@ const PanelImage: FC = ({config, input}) => { const image: WBImage = nodeValueQuery.result; + const classLabels = useClassLabels(inputNode); + const { maskControls: mergedMaskControls, boxControls: mergedBoxControls, classSets, - } = Controls.useImageControls(inputNode.type, config?.overlayControls); + } = Controls.useImageControls( + inputNode.type, + config?.overlayControls, + classLabels + ); const {imageBoxes, imageMasks, boxControls, maskControls} = useMemo(() => { const knownBoxKeys = image?.boxes != null ? _.keys(image.boxes) : []; diff --git a/weave-js/src/components/Panel2/controlsImage.ts b/weave-js/src/components/Panel2/controlsImage.ts index 760f36b552f4..9e98eb64aa40 100644 --- a/weave-js/src/components/Panel2/controlsImage.ts +++ b/weave-js/src/components/Panel2/controlsImage.ts @@ -100,44 +100,80 @@ export function createBoxControls( const defaultClassSetID = 'default'; +export type MaskClassLabels = { + key: string; + type: string; + value: {[key: string]: string}; +}; + +const toClassValue = (className: string, classKey: string) => { + const keyNOrNan = parseInt(classKey, 10); + const color = isNaN(keyNOrNan) + ? colorFromName(classKey) + : colorN(keyNOrNan, ROBIN16); + return {color, name: className}; +}; + export const useImageControls = ( inputType: Type, - currentControls?: OverlayControls + currentControls?: OverlayControls, + maskClassLabels?: {[key: string]: MaskClassLabels} ) => { const usableType = useMemo(() => { return nullableTaggableStrip(inputType) as ImageType; }, [inputType]); - // Images now only have a single class set (the default one) as the - // classes from all layers have been merged in the type system const classSets = useMemo(() => { - const classSet = _.mapValues(usableType.classMap ?? {}, (value, key) => { - const keyNOrNan = parseInt(key, 10); - const color = isNaN(keyNOrNan) - ? colorFromName(key) - : colorN(keyNOrNan, ROBIN16); - return {color, name: value}; - }) as ClassSetState['classes']; + const defaultClassSet = _.mapValues( + usableType.classMap ?? {}, + (className, classKey) => toClassValue(className, classKey) + ) as ClassSetState['classes']; + + const classSetsFromLabels = Object.entries(maskClassLabels ?? {}).reduce( + (acc, [maskKey, mask]) => { + const controlId = `mask-${maskKey.replace( + 'image_wandb_delimeter_', + '' + )}`; + acc[controlId] = { + classes: _.mapValues(mask.value, (labelName, labelKey) => + toClassValue(labelName, labelKey) + ), + }; + return acc; + }, + {} as ClassSetControls + ); + return { - [defaultClassSetID]: {classes: classSet}, + [defaultClassSetID]: {classes: defaultClassSet}, + ...classSetsFromLabels, } as ClassSetControls; - }, [usableType]); + }, [usableType, maskClassLabels]); const maskControls: {[key: string]: MaskControlState} = useMemo(() => { const maskLayers = usableType.maskLayers ?? {}; return _.fromPairs( _.keys(maskLayers).map(maskId => { const prefixedId = 'mask-' + maskId; - if (currentControls?.[prefixedId] != null) { + if ( + currentControls && + _.findKey( + currentControls, + control => control.classSetID === prefixedId + ) + ) { return [prefixedId, currentControls[prefixedId] as MaskControlState]; } - const classSubset = _.pick( - classSets[defaultClassSetID].classes, - ...maskLayers[maskId] - ); + let classSetId = defaultClassSetID; + if (prefixedId in classSets) { + classSetId = prefixedId; + } + const classSet = classSets[classSetId]; + const classSubset = _.pick(classSet.classes, ...maskLayers[maskId]); const newControl: MaskControlState = createMaskControls( prefixedId, - defaultClassSetID, + classSetId, {classes: classSubset} ); return [prefixedId, newControl];