Skip to content

Commit

Permalink
fix(app): obtain mask-specific class labels from the run config
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholaspun-wandb committed Feb 21, 2025
1 parent cef0286 commit 8084a47
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 20 deletions.
36 changes: 34 additions & 2 deletions weave-js/src/components/Panel2/PanelImage.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import {BoundingBoxSliderControl} from '@wandb/weave/common/components/MediaCard';
import {BoundingBox2D, LayoutType} from '@wandb/weave/common/types/media';
import {
Node,
opAssetArtifactVersion,
opGetRunTag,
opRunConfig,
replaceInputVariables,
WBImage,
} from '@wandb/weave/core';
Expand Down Expand Up @@ -32,14 +35,37 @@ type PanelImageProps = Panel2.PanelProps<
PanelImageConfigType
>;

const useClassLabels = (input: Node<typeof inputType>) => {
const {loading: runConfigLoading} = CGReact.useNodeValue(
opRunConfig({run: opGetRunTag({obj: input})})
);

const {loading: runTagLoading, result: runTag} = CGReact.useNodeValue(
opGetRunTag({obj: input})
);

return useMemo(() => {
if (!(runConfigLoading || runTagLoading) && runTag != null) {
const configSubset: {[key: string]: any} = JSON.parse(
runTag.configSubset
);
return _.get(configSubset, '_wandb.value.mask/class_labels', {});
}
return {};
}, [runTag, runTagLoading, runConfigLoading]);
};

const PanelImageConfig: FC<PanelImageProps> = ({
config,
updateConfig,
input,
}) => {
const classLabels = useClassLabels(input);

const {classSets, controls} = Controls.useImageControls(
input.type,
config?.overlayControls
config?.overlayControls,
classLabels
);
const updatedConfig = useMemo(() => {
if (controls === config?.overlayControls) {
Expand Down Expand Up @@ -108,11 +134,17 @@ const PanelImage: FC<PanelImageProps> = ({config, input}) => {

const image: WBImage = nodeValueQuery.result;

const classLabels = useClassLabels(inputNode);

const {
maskControls: mergedMaskControls,
boxControls: mergedBoxControls,
classSets,
} = Controls.useImageControls(inputNode.type, config?.overlayControls);
} = Controls.useImageControls(
inputNode.type,
config?.overlayControls,
classLabels
);

const {imageBoxes, imageMasks, boxControls, maskControls} = useMemo(() => {
const knownBoxKeys = image?.boxes != null ? _.keys(image.boxes) : [];
Expand Down
72 changes: 54 additions & 18 deletions weave-js/src/components/Panel2/controlsImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,44 +100,80 @@ export function createBoxControls(

const defaultClassSetID = 'default';

export type MaskClassLabels = {
key: string;
type: string;
value: {[key: string]: string};
};

const toClassValue = (className: string, classKey: string) => {
const keyNOrNan = parseInt(classKey, 10);
const color = isNaN(keyNOrNan)
? colorFromName(classKey)
: colorN(keyNOrNan, ROBIN16);
return {color, name: className};
};

export const useImageControls = (
inputType: Type,
currentControls?: OverlayControls
currentControls?: OverlayControls,
maskClassLabels?: {[key: string]: MaskClassLabels}
) => {
const usableType = useMemo(() => {
return nullableTaggableStrip(inputType) as ImageType;
}, [inputType]);

// Images now only have a single class set (the default one) as the
// classes from all layers have been merged in the type system
const classSets = useMemo(() => {
const classSet = _.mapValues(usableType.classMap ?? {}, (value, key) => {
const keyNOrNan = parseInt(key, 10);
const color = isNaN(keyNOrNan)
? colorFromName(key)
: colorN(keyNOrNan, ROBIN16);
return {color, name: value};
}) as ClassSetState['classes'];
const defaultClassSet = _.mapValues(
usableType.classMap ?? {},
(className, classKey) => toClassValue(className, classKey)
) as ClassSetState['classes'];

const classSetsFromLabels = Object.entries(maskClassLabels ?? {}).reduce(
(acc, [maskKey, mask]) => {
const controlId = `mask-${maskKey.replace(
'image_wandb_delimeter_',
''
)}`;
acc[controlId] = {
classes: _.mapValues(mask.value, (labelName, labelKey) =>
toClassValue(labelName, labelKey)
),
};
return acc;
},
{} as ClassSetControls
);

return {
[defaultClassSetID]: {classes: classSet},
[defaultClassSetID]: {classes: defaultClassSet},
...classSetsFromLabels,
} as ClassSetControls;
}, [usableType]);
}, [usableType, maskClassLabels]);

const maskControls: {[key: string]: MaskControlState} = useMemo(() => {
const maskLayers = usableType.maskLayers ?? {};
return _.fromPairs(
_.keys(maskLayers).map(maskId => {
const prefixedId = 'mask-' + maskId;
if (currentControls?.[prefixedId] != null) {
if (
currentControls &&
_.findKey(
currentControls,
control => control.classSetID === prefixedId
)
) {
return [prefixedId, currentControls[prefixedId] as MaskControlState];
}
const classSubset = _.pick(
classSets[defaultClassSetID].classes,
...maskLayers[maskId]
);
let classSetId = defaultClassSetID;
if (prefixedId in classSets) {
classSetId = prefixedId;
}
const classSet = classSets[classSetId];
const classSubset = _.pick(classSet.classes, ...maskLayers[maskId]);
const newControl: MaskControlState = createMaskControls(
prefixedId,
defaultClassSetID,
classSetId,
{classes: classSubset}
);
return [prefixedId, newControl];
Expand Down

0 comments on commit 8084a47

Please sign in to comment.