Skip to content

Commit

Permalink
Fix input scaling in centered-instance model
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Dec 16, 2024
1 parent fff8761 commit a003f29
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def linkcode_resolve(domain, info):
# These paths are either relative to html_static_path
# or fully qualified paths (eg. https://...)
html_css_files = [
'css/tabs.css',
"css/tabs.css",
]

# Custom sidebar templates, must be a dictionary that maps document names
Expand Down
1 change: 1 addition & 0 deletions sleap/nn/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@ def make_viz_pipeline(self, data_provider: Provider) -> Pipeline:
provider=data_provider,
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
pipeline += Resizer.from_config(self.data_config.preprocessing)
pipeline += InstanceCentroidFinder.from_config(
self.data_config.instance_cropping,
skeletons=self.data_config.labels.skeletons,
Expand Down
41 changes: 35 additions & 6 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,11 +727,18 @@ class CentroidCropGroundTruth(tf.keras.layers.Layer):
Attributes:
crop_size: The length of the square box to extract around each centroid.
input_scale: Float indicating if the images should be resized before being
passed to the model.
"""

def __init__(self, crop_size: int):
def __init__(
self,
crop_size: int,
input_scale: float = 1.0,
):
super().__init__()
self.crop_size = crop_size
self.input_scale = input_scale

def call(self, example_gt: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
"""Return the ground truth instance crops.
Expand All @@ -758,6 +765,8 @@ def call(self, example_gt: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
"""
# Pull out data from example.
full_imgs = example_gt["image"]
if self.input_scale != 1.0:
full_imgs = sleap.nn.data.resizing.resize_image(full_imgs, self.input_scale)
crop_sample_inds = example_gt["centroids"].value_rowids() # (n_peaks,)
n_peaks = tf.shape(crop_sample_inds)[0] # total number of peaks in the batch
centroid_points = example_gt["centroids"].flat_values # (n_peaks, 2)
Expand Down Expand Up @@ -927,11 +936,12 @@ def __init__(
self.ensure_grayscale = ensure_grayscale
self.ensure_float = ensure_float

def preprocess(self, imgs: tf.Tensor) -> tf.Tensor:
def preprocess(self, imgs: tf.Tensor, resize_img: bool = True) -> tf.Tensor:
"""Apply all preprocessing operations configured for this layer.
Args:
imgs: A batch of images as a tensor.
resize_img: Bool to indicate if the images should be resized.
Returns:
The input tensor after applying preprocessing operations. The tensor will
Expand All @@ -947,7 +957,7 @@ def preprocess(self, imgs: tf.Tensor) -> tf.Tensor:
if self.ensure_float:
imgs = sleap.nn.data.normalization.ensure_float(imgs)

if self.input_scale != 1.0:
if resize_img and self.input_scale != 1.0:
imgs = sleap.nn.data.resizing.resize_image(imgs, self.input_scale)

if self.pad_to_stride > 1:
Expand Down Expand Up @@ -1954,6 +1964,11 @@ class FindInstancePeaks(InferenceLayer):
centered instance confidence maps.
input_scale: Float indicating if the images should be resized before being
passed to the model.
resize_input_image: Bool indicating if the crops should be resized. If
`CentroidCropGroundTruth` is used along with `FindInstancePeaks`, then the
images are resized in the `CentroidCropGroundTruth` and this is set to `False`.
However, the output keypoints are adjusted to the actual scale with the
`input_scaling` argument.
output_stride: Output stride of the model, denoting the scale of the output
confidence maps relative to the images (after input scaling). This is used
for adjusting the peak coordinates to the image grid. This will be inferred
Expand Down Expand Up @@ -1984,6 +1999,7 @@ def __init__(
self,
keras_model: tf.keras.Model,
input_scale: float = 1.0,
resize_input_image: bool = True,
output_stride: Optional[int] = None,
peak_threshold: float = 0.2,
refinement: Optional[str] = "local",
Expand All @@ -1996,6 +2012,7 @@ def __init__(
super().__init__(
keras_model=keras_model, input_scale=input_scale, pad_to_stride=1, **kwargs
)
self.resize_input_image = resize_input_image
self.peak_threshold = peak_threshold
self.refinement = refinement
self.integral_patch_size = integral_patch_size
Expand Down Expand Up @@ -2093,7 +2110,7 @@ def call(
crop_sample_inds = tf.range(samples, dtype=tf.int32)

# Preprocess inputs (scaling, padding, colorspace, int to float).
crops = self.preprocess(crops)
crops = self.preprocess(crops, resize_img=self.resize_input_image)

# Network forward pass.
out = self.keras_model(crops)
Expand Down Expand Up @@ -2343,7 +2360,7 @@ def _initialize_inference_model(self):

if use_gt_centroid:
centroid_crop_layer = CentroidCropGroundTruth(
crop_size=self.confmap_config.data.instance_cropping.crop_size
crop_size=self.confmap_config.data.instance_cropping.crop_size,
)
else:
if use_gt_confmap:
Expand Down Expand Up @@ -2375,7 +2392,10 @@ def _initialize_inference_model(self):
refinement="integral" if self.integral_refinement else "local",
integral_patch_size=self.integral_patch_size,
return_confmaps=False,
resize_input_image=False,
)
if use_gt_centroid:
centroid_crop_layer.input_scale = cfg.data.preprocessing.input_scaling

self.inference_model = TopDownInferenceModel(
centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer
Expand Down Expand Up @@ -3831,6 +3851,11 @@ class TopDownMultiClassFindPeaks(InferenceLayer):
centered instance confidence maps and classification.
input_scale: Float indicating if the images should be resized before being
passed to the model.
resize_input_image: Bool indicating if the crops should be resized. If
`CentroidCropGroundTruth` is used along with `FindInstancePeaks`, then the
images are resized in the `CentroidCropGroundTruth` and this is set to `False`.
However, the output keypoints are adjusted to the actual scale with the
`input_scaling` argument.
output_stride: Output stride of the model, denoting the scale of the output
confidence maps relative to the images (after input scaling). This is used
for adjusting the peak coordinates to the image grid. This will be inferred
Expand Down Expand Up @@ -3872,6 +3897,7 @@ def __init__(
self,
keras_model: tf.keras.Model,
input_scale: float = 1.0,
resize_input_image: bool = True,
output_stride: Optional[int] = None,
peak_threshold: float = 0.2,
refinement: Optional[str] = "local",
Expand All @@ -3887,6 +3913,7 @@ def __init__(
super().__init__(
keras_model=keras_model, input_scale=input_scale, pad_to_stride=1, **kwargs
)
self.resize_input_image = resize_input_image
self.peak_threshold = peak_threshold
self.refinement = refinement
self.integral_patch_size = integral_patch_size
Expand Down Expand Up @@ -4004,7 +4031,7 @@ def call(
crop_sample_inds = tf.range(samples, dtype=tf.int32)

# Preprocess inputs (scaling, padding, colorspace, int to float).
crops = self.preprocess(crops)
crops = self.preprocess(crops, resize_img=self.resize_input_image)

# Network forward pass.
out = self.keras_model(crops)
Expand Down Expand Up @@ -4254,6 +4281,8 @@ def _initialize_inference_model(self):
integral_patch_size=self.integral_patch_size,
return_confmaps=False,
)
if use_gt_centroid:
centroid_crop_layer.input_scale = cfg.data.preprocessing.input_scaling

self.inference_model = TopDownMultiClassInferenceModel(
centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer
Expand Down

0 comments on commit a003f29

Please sign in to comment.