diff --git a/scripts/process_image.py b/scripts/process_image.py index eb87a6d..03d82e1 100644 --- a/scripts/process_image.py +++ b/scripts/process_image.py @@ -27,7 +27,16 @@ img = skimage.io.imread(cfg.img_path) -img = xrv.datasets.normalize(img, 255, reshape=True) +img = xrv.datasets.normalize(img, 255) + +# Check that images are 2D arrays +if len(img.shape) > 2: + img = img[:, :, 0] +if len(img.shape) < 2: + print("error, dimension lower than 2 for image") + +# Add color channel +img = img[None, :, :] transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)])