diff --git a/docs/conf.py b/docs/conf.py index a429c7928..796497f6b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -178,7 +178,7 @@ def linkcode_resolve(domain, info): # These paths are either relative to html_static_path # or fully qualified paths (eg. https://...) html_css_files = [ - 'css/tabs.css', + "css/tabs.css", ] # Custom sidebar templates, must be a dictionary that maps document names diff --git a/sleap/nn/data/pipelines.py b/sleap/nn/data/pipelines.py index 2e334456a..d15c4491f 100644 --- a/sleap/nn/data/pipelines.py +++ b/sleap/nn/data/pipelines.py @@ -775,6 +775,7 @@ def make_viz_pipeline(self, data_provider: Provider) -> Pipeline: provider=data_provider, ) pipeline += Normalizer.from_config(self.data_config.preprocessing) + pipeline += Resizer.from_config(self.data_config.preprocessing) pipeline += InstanceCentroidFinder.from_config( self.data_config.instance_cropping, skeletons=self.data_config.labels.skeletons, @@ -1250,6 +1251,7 @@ def make_viz_pipeline(self, data_provider: Provider) -> Pipeline: provider=data_provider, ) pipeline += Normalizer.from_config(self.data_config.preprocessing) + pipeline += Resizer.from_config(self.data_config.preprocessing) pipeline += InstanceCentroidFinder.from_config( self.data_config.instance_cropping, skeletons=self.data_config.labels.skeletons, diff --git a/sleap/nn/data/providers.py b/sleap/nn/data/providers.py index 9e93d0b18..a5c899ebe 100644 --- a/sleap/nn/data/providers.py +++ b/sleap/nn/data/providers.py @@ -1,10 +1,12 @@ """Data providers for pipeline I/O.""" import numpy as np +import sleap.instance import tensorflow as tf import attr from typing import Text, Optional, List, Sequence, Union, Tuple import sleap +from sleap.instance import Instance @attr.s(auto_attribs=True) @@ -197,6 +199,23 @@ def py_fetch_lf(ind): raw_image = lf.image raw_image_size = np.array(raw_image.shape).astype("int32") + height, width = raw_image_size + + instances = [] + for instance in lf.instances: + pts = instance.numpy() + # negative coords + pts[pts < 0] = np.NaN + + # coordinates outside img frame + pts[:, 0][pts[:, 0] > height - 1] = np.NaN + pts[:, 1][pts[:, 1] > width - 1] = np.NaN + + # remove all nans + pts = pts[~np.isnan(pts).any(axis=1), :] + + instances.append(Instance.from_numpy(pts, lf.skeleton, lf.track)) + lf.instances = instances if self.user_instances_only: insts = lf.user_instances diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 14e0d5c6f..5b65c825d 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -727,11 +727,18 @@ class CentroidCropGroundTruth(tf.keras.layers.Layer): Attributes: crop_size: The length of the square box to extract around each centroid. + input_scale: Float indicating if the images should be resized before being + passed to the model. """ - def __init__(self, crop_size: int): + def __init__( + self, + crop_size: int, + input_scale: float = 1.0, + ): super().__init__() self.crop_size = crop_size + self.input_scale = input_scale def call(self, example_gt: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: """Return the ground truth instance crops. @@ -758,6 +765,9 @@ def call(self, example_gt: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: """ # Pull out data from example. full_imgs = example_gt["image"] + if self.input_scale != 1.0: + full_imgs = sleap.nn.data.resizing.resize_image(full_imgs, self.input_scale) + example_gt["centroids"] *= self.input_scale crop_sample_inds = example_gt["centroids"].value_rowids() # (n_peaks,) n_peaks = tf.shape(crop_sample_inds)[0] # total number of peaks in the batch centroid_points = example_gt["centroids"].flat_values # (n_peaks, 2) @@ -927,11 +937,12 @@ def __init__( self.ensure_grayscale = ensure_grayscale self.ensure_float = ensure_float - def preprocess(self, imgs: tf.Tensor) -> tf.Tensor: + def preprocess(self, imgs: tf.Tensor, resize_img: bool = True) -> tf.Tensor: """Apply all preprocessing operations configured for this layer. Args: imgs: A batch of images as a tensor. + resize_img: Bool to indicate if the images should be resized. Returns: The input tensor after applying preprocessing operations. The tensor will @@ -947,7 +958,7 @@ def preprocess(self, imgs: tf.Tensor) -> tf.Tensor: if self.ensure_float: imgs = sleap.nn.data.normalization.ensure_float(imgs) - if self.input_scale != 1.0: + if resize_img and self.input_scale != 1.0: imgs = sleap.nn.data.resizing.resize_image(imgs, self.input_scale) if self.pad_to_stride > 1: @@ -1954,6 +1965,11 @@ class FindInstancePeaks(InferenceLayer): centered instance confidence maps. input_scale: Float indicating if the images should be resized before being passed to the model. + resize_input_image: Bool indicating if the crops should be resized. If + `CentroidCropGroundTruth` is used along with `FindInstancePeaks`, then the + images are resized in the `CentroidCropGroundTruth` and this is set to `False`. + However, the output keypoints are adjusted to the actual scale with the + `input_scaling` argument. output_stride: Output stride of the model, denoting the scale of the output confidence maps relative to the images (after input scaling). This is used for adjusting the peak coordinates to the image grid. This will be inferred @@ -1984,6 +2000,7 @@ def __init__( self, keras_model: tf.keras.Model, input_scale: float = 1.0, + resize_input_image: bool = True, output_stride: Optional[int] = None, peak_threshold: float = 0.2, refinement: Optional[str] = "local", @@ -1996,6 +2013,7 @@ def __init__( super().__init__( keras_model=keras_model, input_scale=input_scale, pad_to_stride=1, **kwargs ) + self.resize_input_image = resize_input_image self.peak_threshold = peak_threshold self.refinement = refinement self.integral_patch_size = integral_patch_size @@ -2093,7 +2111,7 @@ def call( crop_sample_inds = tf.range(samples, dtype=tf.int32) # Preprocess inputs (scaling, padding, colorspace, int to float). - crops = self.preprocess(crops) + crops = self.preprocess(crops, resize_img=self.resize_input_image) # Network forward pass. out = self.keras_model(crops) @@ -2343,7 +2361,7 @@ def _initialize_inference_model(self): if use_gt_centroid: centroid_crop_layer = CentroidCropGroundTruth( - crop_size=self.confmap_config.data.instance_cropping.crop_size + crop_size=self.confmap_config.data.instance_cropping.crop_size, ) else: if use_gt_confmap: @@ -2375,7 +2393,10 @@ def _initialize_inference_model(self): refinement="integral" if self.integral_refinement else "local", integral_patch_size=self.integral_patch_size, return_confmaps=False, + resize_input_image=False, ) + if use_gt_centroid: + centroid_crop_layer.input_scale = cfg.data.preprocessing.input_scaling self.inference_model = TopDownInferenceModel( centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer @@ -3831,6 +3852,11 @@ class TopDownMultiClassFindPeaks(InferenceLayer): centered instance confidence maps and classification. input_scale: Float indicating if the images should be resized before being passed to the model. + resize_input_image: Bool indicating if the crops should be resized. If + `CentroidCropGroundTruth` is used along with `FindInstancePeaks`, then the + images are resized in the `CentroidCropGroundTruth` and this is set to `False`. + However, the output keypoints are adjusted to the actual scale with the + `input_scaling` argument. output_stride: Output stride of the model, denoting the scale of the output confidence maps relative to the images (after input scaling). This is used for adjusting the peak coordinates to the image grid. This will be inferred @@ -3872,6 +3898,7 @@ def __init__( self, keras_model: tf.keras.Model, input_scale: float = 1.0, + resize_input_image: bool = True, output_stride: Optional[int] = None, peak_threshold: float = 0.2, refinement: Optional[str] = "local", @@ -3887,6 +3914,7 @@ def __init__( super().__init__( keras_model=keras_model, input_scale=input_scale, pad_to_stride=1, **kwargs ) + self.resize_input_image = resize_input_image self.peak_threshold = peak_threshold self.refinement = refinement self.integral_patch_size = integral_patch_size @@ -4004,7 +4032,7 @@ def call( crop_sample_inds = tf.range(samples, dtype=tf.int32) # Preprocess inputs (scaling, padding, colorspace, int to float). - crops = self.preprocess(crops) + crops = self.preprocess(crops, resize_img=self.resize_input_image) # Network forward pass. out = self.keras_model(crops) @@ -4253,7 +4281,10 @@ def _initialize_inference_model(self): refinement="integral" if self.integral_refinement else "local", integral_patch_size=self.integral_patch_size, return_confmaps=False, + resize_input_image=False, ) + if use_gt_centroid: + centroid_crop_layer.input_scale = cfg.data.preprocessing.input_scaling self.inference_model = TopDownMultiClassInferenceModel( centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer diff --git a/sleap/nn/training.py b/sleap/nn/training.py index c3692637c..f56d8cf46 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -1319,6 +1319,7 @@ def _setup_visualization(self): peak_threshold=0.2, refinement="local", return_confmaps=True, + resize_input_image=False, ) def visualize_example(example): @@ -1759,6 +1760,7 @@ def _setup_visualization(self): peak_threshold=0.2, refinement="local", return_confmaps=True, + resize_input_image=False, ) def visualize_example(example):