diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 6a9d05806..9e07c0a25 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -99,9 +99,9 @@ optional arguments: -e [EXPORT_PATH], --export_path [EXPORT_PATH] Path to output directory where the frozen model will be exported to. Defaults to a folder named 'exported_model'. - -u, --unrag UNRAG - Convert ragged tensors into regular tensors with NaN padding. - Defaults to True. + -r, --ragged RAGGED + Keep tensors ragged if present. If ommited, convert + ragged tensors into regular tensors with NaN padding. -n, --max_instances MAX_INSTANCES Limit maximum number of instances in multi-instance models. Not available for ID models. Defaults to None. diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 6d7d24f8c..0cabc91bb 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4939,7 +4939,7 @@ def export_cli(args: Optional[list] = None): export_model( args.models, args.export_path, - unrag_outputs=args.unrag, + unrag_outputs=(not args.ragged), max_instances=args.max_instances, ) @@ -4971,13 +4971,13 @@ def _make_export_cli_parser() -> argparse.ArgumentParser: ), ) parser.add_argument( - "-u", - "--unrag", + "-r", + "--ragged", action="store_true", - default=True, + default=False, help=( - "Convert ragged tensors into regular tensors with NaN padding. " - "Defaults to True." + "Keep tensors ragged if present. If ommited, convert ragged tensors" + " into regular tensors with NaN padding." ), ) parser.add_argument( diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index fe848bb1c..dedf0d324 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -50,6 +50,7 @@ _make_tracker_from_cli, main as sleap_track, export_cli as sleap_export, + _make_export_cli_parser, ) from sleap.nn.tracking import ( MatchedFrameInstance, @@ -925,7 +926,7 @@ def test_load_model(resize_input_shape, model_fixture_name, request): predictor = load_model(model_path, resize_input_layer=resize_input_shape) # Determine predictor type - for (fname, mname, ptype, ishape) in fname_mname_ptype_ishape: + for fname, mname, ptype, ishape in fname_mname_ptype_ishape: if fname in model_fixture_name: expected_model_name = mname expected_predictor_type = ptype @@ -966,7 +967,6 @@ def test_topdown_multi_size_inference( def test_ensure_numpy( min_centroid_model_path, min_centered_instance_model_path, min_labels_slp ): - model = load_model([min_centroid_model_path, min_centered_instance_model_path]) # each frame has same number of instances @@ -1037,7 +1037,6 @@ def test_ensure_numpy( def test_centroid_inference(): - xv, yv = make_grid_vectors(image_height=12, image_width=12, output_stride=1) points = tf.cast([[[1.75, 2.75]], [[3.75, 4.75]], [[5.75, 6.75]]], tf.float32) cms = tf.expand_dims(make_multi_confmaps(points, xv, yv, sigma=1.5), axis=0) @@ -1093,7 +1092,6 @@ def test_centroid_inference(): def export_frozen_graph(model, preds, output_path): - tensors = {} for key, val in preds.items(): @@ -1120,7 +1118,6 @@ def export_frozen_graph(model, preds, output_path): info = json.load(json_file) for tensor_info in info["frozen_model_inputs"] + info["frozen_model_outputs"]: - saved_name = ( tensor_info.split("Tensor(")[1].split(", shape")[0].replace('"', "") ) @@ -1137,7 +1134,6 @@ def export_frozen_graph(model, preds, output_path): def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): - single_instance_model = tf.keras.models.load_model( min_single_instance_robot_model_path + "/best_model.h5", compile=False ) @@ -1152,7 +1148,6 @@ def test_single_instance_save(min_single_instance_robot_model_path, tmp_path): def test_centroid_save(min_centroid_model_path, tmp_path): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1171,7 +1166,6 @@ def test_centroid_save(min_centroid_model_path, tmp_path): def test_topdown_save( min_centroid_model_path, min_centered_instance_model_path, min_labels_slp, tmp_path ): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1195,7 +1189,6 @@ def test_topdown_save( def test_topdown_id_save( min_centroid_model_path, min_topdown_multiclass_model_path, min_labels_slp, tmp_path ): - centroid_model = tf.keras.models.load_model( min_centroid_model_path + "/best_model.h5", compile=False ) @@ -1217,7 +1210,6 @@ def test_topdown_id_save( def test_single_instance_predictor_save(min_single_instance_robot_model_path, tmp_path): - # directly initialize predictor predictor = SingleInstancePredictor.from_trained_models( min_single_instance_robot_model_path, resize_input_layer=False @@ -1254,10 +1246,33 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm ) +def test_make_export_cli(): + models_path = r"psuedo/models/path" + export_path = r"psuedo/test/path" + max_instances = 5 + + parser = _make_export_cli_parser() + + # Test default values + args = None + args, _ = parser.parse_known_args(args=args) + assert args.models is None + assert args.export_path == "exported_model" + assert not args.ragged + assert args.max_instances is None + + # Test all arguments + cmd = f"-m {models_path} -e {export_path} -r -n {max_instances}" + args, _ = parser.parse_known_args(args=cmd.split()) + assert args.models == [models_path] + assert args.export_path == export_path + assert args.ragged + assert args.max_instances == max_instances + + def test_topdown_predictor_save( min_centroid_model_path, min_centered_instance_model_path, tmp_path ): - # directly initialize predictor predictor = TopDownPredictor.from_trained_models( centroid_model_path=min_centroid_model_path, @@ -1300,7 +1315,6 @@ def test_topdown_predictor_save( def test_topdown_id_predictor_save( min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path ): - # directly initialize predictor predictor = TopDownMultiClassPredictor.from_trained_models( centroid_model_path=min_centroid_model_path, @@ -1478,7 +1492,6 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): # Run tracking on subset of frames using psuedo-implementation of # sleap.nn.tracking.run_tracker for lf in frames[:20]: - # Clear the tracks for inst in lf.instances: inst.track = None @@ -1522,7 +1535,6 @@ def test_max_tracks_matching_queue( frames = sorted(labels.labeled_frames, key=lambda lf: lf.frame_idx) for lf in frames[:20]: - # Clear the tracks for inst in lf.instances: inst.track = None