From 70a1c5a017917441f619f2cdcf18b01fb5ff8980 Mon Sep 17 00:00:00 2001 From: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com> Date: Wed, 26 Oct 2022 15:05:33 +0100 Subject: [PATCH] [Cherry-pick] Fixes for validation and uint8 (#119) --- export.py | 8 ++++++++ utils/sparse.py | 3 +++ val_onnx.py | 5 ++++- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/export.py b/export.py index 608a3edd1c14..bfc5c571c947 100644 --- a/export.py +++ b/export.py @@ -690,6 +690,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' num_export_samples=num_export_samples, save_dir= os.path.dirname(file), image_size=imgsz, + save_inputs_as_uint8=f[2] and _graph_has_uint8_inputs(f[2]) ) # TensorFlow Exports @@ -772,6 +773,13 @@ def parse_opt(known = False, skip_parse = False): return opt +def _graph_has_uint8_inputs(onnx_path): + import onnx + onnx_model = onnx.load(onnx_path) + # check if first model input has elem type 2 (uint8) + return onnx_model.graph.input[0].type.tensor_type.elem_type == 2 + + def main(opt): for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]): run(**vars(opt)) diff --git a/utils/sparse.py b/utils/sparse.py index 93fa6913a5f5..192248ec4515 100644 --- a/utils/sparse.py +++ b/utils/sparse.py @@ -266,6 +266,7 @@ def save_sample_inputs_outputs( num_export_samples=100, save_dir: Optional[str] = None, image_size: int=640, + save_inputs_as_uint8: bool = False, ): save_dir = save_dir or "" if not dataloader: @@ -308,6 +309,8 @@ def save_sample_inputs_outputs( file_idx = f"{exported_samples}".zfill(4) sample_input_filename = os.path.join(f"{sample_in_dir}", f"inp-{file_idx}.npz") + if save_inputs_as_uint8: + sample_in = (255 * sample_in).to(dtype=torch.uint8) numpy.savez(sample_input_filename, sample_in) sample_output_filename = os.path.join(f"{sample_out_dir}", f"out-{file_idx}.npz") diff --git a/val_onnx.py b/val_onnx.py index eea06f55c873..9bcbf46f3f71 100644 --- a/val_onnx.py +++ b/val_onnx.py @@ -310,7 +310,10 @@ def run( # Inference out = yolo_pipeline( - images=[im.numpy()], iou_thres=iou_thres, conf_thres=conf_thres + images=[im.numpy()], + iou_thres=iou_thres, + conf_thres=conf_thres, + multi_label=True ) # inference, loss outputs