Skip to content

Commit

Permalink
[Cherry-pick] Fixes for validation and uint8 (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin authored Oct 26, 2022
1 parent 67a661a commit 70a1c5a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
8 changes: 8 additions & 0 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
num_export_samples=num_export_samples,
save_dir= os.path.dirname(file),
image_size=imgsz,
save_inputs_as_uint8=f[2] and _graph_has_uint8_inputs(f[2])
)

# TensorFlow Exports
Expand Down Expand Up @@ -772,6 +773,13 @@ def parse_opt(known = False, skip_parse = False):
return opt


def _graph_has_uint8_inputs(onnx_path):
import onnx
onnx_model = onnx.load(onnx_path)
# check if first model input has elem type 2 (uint8)
return onnx_model.graph.input[0].type.tensor_type.elem_type == 2


def main(opt):
for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
run(**vars(opt))
Expand Down
3 changes: 3 additions & 0 deletions utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def save_sample_inputs_outputs(
num_export_samples=100,
save_dir: Optional[str] = None,
image_size: int=640,
save_inputs_as_uint8: bool = False,
):
save_dir = save_dir or ""
if not dataloader:
Expand Down Expand Up @@ -308,6 +309,8 @@ def save_sample_inputs_outputs(
file_idx = f"{exported_samples}".zfill(4)

sample_input_filename = os.path.join(f"{sample_in_dir}", f"inp-{file_idx}.npz")
if save_inputs_as_uint8:
sample_in = (255 * sample_in).to(dtype=torch.uint8)
numpy.savez(sample_input_filename, sample_in)

sample_output_filename = os.path.join(f"{sample_out_dir}", f"out-{file_idx}.npz")
Expand Down
5 changes: 4 additions & 1 deletion val_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,10 @@ def run(

# Inference
out = yolo_pipeline(
images=[im.numpy()], iou_thres=iou_thres, conf_thres=conf_thres
images=[im.numpy()],
iou_thres=iou_thres,
conf_thres=conf_thres,
multi_label=True
)

# inference, loss outputs
Expand Down

0 comments on commit 70a1c5a

Please sign in to comment.