Skip to content

Commit

Permalink
Merge pull request #13 from Synthesis-AI-Dev/segmentation_memory_opti…
Browse files Browse the repository at this point in the history
…mization

Optimized memory usage during segmentation.
  • Loading branch information
mlmcgoogan authored Aug 13, 2021
2 parents 40a360b + c39c85c commit 39363d3
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions exr_info/cryptomatte.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,10 @@ def get_masks_for_all_objs(self, crypto_def_name: str) -> OrderedDict:
# The objects in manifest are sorted alphabetically to maintain some order.
# Each obj is assigned an unique ID (per image) for the mask
obj_names = sorted(manifest.keys())
obj_masks = OrderedDict()
for obj_name in obj_names:
obj_hex_id = manifest[obj_name]
mask = self.get_mask_for_id(obj_hex_id, channels_arr, level)
obj_masks[obj_name] = mask

return obj_masks
yield obj_name, mask

def get_combined_mask(self, crypto_def_name: str) -> Tuple[np.ndarray, Dict[str, int]]:
"""
Expand All @@ -131,19 +128,26 @@ def get_combined_mask(self, crypto_def_name: str) -> Tuple[np.ndarray, Dict[str,
"""
obj_masks = self.get_masks_for_all_objs(crypto_def_name)

# Create a map of obj names to ids
best = None
total = None
mask_combined = None
name_to_mask_id_map = OrderedDict()
name_to_mask_id_map["background"] = 0 # Background is always class 0
obj_names = obj_masks.keys()
for idx, obj_name in enumerate(obj_names):
name_to_mask_id_map[obj_name] = idx + 1

# Combine all the masks into single mask without anti-aliasing for semantic segmentation
masks = np.stack(list(obj_masks.values()), axis=0) # Shape: [N, H, W]
background_mask = 255 - masks.sum(axis=0)
masks = np.concatenate((np.expand_dims(background_mask, 0), masks), axis=0)
mask_combined = masks.argmax(axis=0)
mask_combined = mask_combined.astype(np.uint16)
for (idx, (obj_name, obj_mask)) in enumerate(obj_masks):
name_to_mask_id_map[obj_name] = idx + 1
if mask_combined is None:
mask_combined = np.zeros_like(obj_mask, dtype=np.uint16)
if best is None:
best = np.zeros_like(obj_mask)
if total is None:
total = np.zeros_like(obj_mask)

total += obj_mask
mask_combined[obj_mask > best] = idx + 1
best = np.max([best, obj_mask], axis=0)

mask_combined[255 - total > best] = 0

return mask_combined, name_to_mask_id_map

Expand Down

0 comments on commit 39363d3

Please sign in to comment.