Skip to content

Commit

Permalink
Update README
Browse files Browse the repository at this point in the history
  • Loading branch information
hugoycj committed Mar 6, 2024
1 parent c2b7936 commit 647d64e
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,36 @@ python -m pip install glob2

Then, download the model weights from <a href="https://drive.google.com/drive/folders/1t3LMJIIrSnCGwOEf53Cyg0lkSXd3M4Hm?usp=drive_link" target="_blank">this link</a> and save it under `./checkpoints/`.

## Use torch hub to predcit normal
```
import torch
import cv2
import numpy as np
# Load the normal predictor model from torch hub
normal_predictor = torch.hub.load("hugoycj/DSINE-hub", "DSINE", trust_repo=True)
# Load the input image using OpenCV
image = cv2.imread(args.input, cv2.IMREAD_COLOR)
h, w = image.shape[:2]
# Use the model to infer the normal map from the input image
with torch.inference_mode():
normal = normal_predictor.infer_cv2(image)[0] # Output shape: (H, W, 3)
normal = (normal + 1) / 2 # Convert values to the range [0, 1]
# Convert the normal map to a displayable format
normal = (normal * 255).cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)
# Save the output normal map to a file
cv2.imwrite(args.output, normal)
```
If the network is unavailable to retrieve weights, you can use local weights for torch hub as shown below:
```
normal_predictor = torch.hub.load("hugoycj/DSINE-hub", "DSINE", local_file_path='./checkpoints/dsine.pt', trust_repo=True)
```

## Test on images

* Run `python test.py` to generate predictions for the images under `./samples/img/`. The result will be saved under `./samples/output/`.
Expand Down

0 comments on commit 647d64e

Please sign in to comment.