diff --git a/README.md b/README.md
index 2fac4f9..5ab71e0 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,36 @@ python -m pip install glob2
 
 Then, download the model weights from this link and save it under `./checkpoints/`.
 
+## Use torch hub to predict normals
+```
+import torch
+import cv2
+import numpy as np
+
+# Load the normal predictor model from torch hub
+normal_predictor = torch.hub.load("hugoycj/DSINE-hub", "DSINE", trust_repo=True)
+
+# Load the input image using OpenCV
+image = cv2.imread(args.input, cv2.IMREAD_COLOR)
+h, w = image.shape[:2]
+
+# Use the model to infer the normal map from the input image
+with torch.inference_mode():
+    normal = normal_predictor.infer_cv2(image)[0]  # Output shape: (3, H, W)
+    normal = (normal + 1) / 2  # Convert values to the range [0, 1]
+
+# Convert the normal map to a displayable (H, W, 3) uint8 format
+normal = (normal * 255).cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
+normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)
+
+# Save the output normal map to a file
+cv2.imwrite(args.output, normal)
+```
+If the network is unavailable for downloading the weights, you can point torch hub at a local checkpoint instead:
+```
+normal_predictor = torch.hub.load("hugoycj/DSINE-hub", "DSINE", local_file_path='./checkpoints/dsine.pt', trust_repo=True)
+```
+
 ## Test on images
 * Run `python test.py` to generate predictions for the images under `./samples/img/`. The result will be saved under `./samples/output/`.
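
The snippet added in this diff reads `args.input` and `args.output` without defining them. One way to supply those attributes is a small argparse wrapper; the sketch below is an assumption (the flag names and sample paths are hypothetical, not part of the diff), shown with an explicit argument list for clarity:

```python
import argparse

# Hypothetical CLI wrapper for the torch hub snippet; the --input/--output
# flag names and the sample paths are assumptions, not from the README.
parser = argparse.ArgumentParser(
    description="Predict surface normals with DSINE via torch hub"
)
parser.add_argument("--input", required=True, help="path to the input image")
parser.add_argument("--output", required=True, help="path for the saved normal map")

# In a real script you would call parser.parse_args() with no arguments
# to read sys.argv; an explicit list is used here for illustration.
args = parser.parse_args(
    ["--input", "samples/img/demo.png", "--output", "samples/output/demo_normal.png"]
)
```

With `args` defined this way, the inference snippet above can be run as-is.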