Skip to content

Commit

Permalink
Merge pull request #2 from hugoycj/main
Browse files Browse the repository at this point in the history
Implement torch hub loader for DSINE
  • Loading branch information
baegwangbin authored Apr 9, 2024
2 parents 15c708a + 647d64e commit b84433a
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,36 @@ python -m pip install glob2

Then, download the model weights from <a href="https://drive.google.com/drive/folders/1t3LMJIIrSnCGwOEf53Cyg0lkSXd3M4Hm?usp=drive_link" target="_blank">this link</a> and save it under `./checkpoints/`.

## Use torch hub to predict normals
```
import torch
import cv2
import numpy as np
# Load the normal predictor model from torch hub
normal_predictor = torch.hub.load("hugoycj/DSINE-hub", "DSINE", trust_repo=True)
# Load the input image using OpenCV
image = cv2.imread("input.png", cv2.IMREAD_COLOR)
h, w = image.shape[:2]
# Use the model to infer the normal map from the input image
with torch.inference_mode():
    normal = normal_predictor.infer_cv2(image)[0] # Output shape: (3, H, W)
normal = (normal + 1) / 2 # Convert values to the range [0, 1]
# Convert the normal map to a displayable format
normal = (normal * 255).cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)
# Save the output normal map to a file
cv2.imwrite("output_normal.png", normal)
```
If the network is unavailable to retrieve weights, you can use local weights for torch hub as shown below:
```
normal_predictor = torch.hub.load("hugoycj/DSINE-hub", "DSINE", local_file_path='./checkpoints/dsine.pt', trust_repo=True)
```

## Test on images

* Run `python test.py` to generate predictions for the images under `./samples/img/`. The result will be saved under `./samples/output/`.
Expand Down
117 changes: 117 additions & 0 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import torch
import os
from typing import Optional
from torchvision import transforms
import numpy as np
import torch.nn.functional as F

dependencies = ["torch", "numpy", "geffnet"]
def _load_state_dict(local_file_path: Optional[str] = None):
if local_file_path is not None and os.path.exists(local_file_path):
# Load state_dict from local file
state_dict = torch.load(local_file_path, map_location=torch.device("cpu"))
else:
# Load state_dict from the default URL
file_name = "dsine.pt"
url = f"https://huggingface.co/camenduru/DSINE/resolve/main/dsine.pt"
state_dict = torch.hub.load_state_dict_from_url(url, file_name=file_name, map_location=torch.device("cpu"))

return state_dict['model']

class Predictor:
    """Thin inference wrapper around a DSINE surface-normal model.

    Accepts OpenCV (BGR) or PIL/RGB images and returns the predicted
    normal map as a tensor of shape (1, 3, H, W) with components in [-1, 1].
    """

    def __init__(self, model) -> None:
        # NOTE(review): inference is hard-coded to CUDA; the model passed in
        # is expected to already live on this device (see DSINE()).
        self.device = torch.device('cuda')
        self.model = model
        # ImageNet normalization, applied after scaling pixels to [0, 1].
        self.transform = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    def infer_cv2(self, image):
        """Predict normals for an OpenCV image (BGR channel order)."""
        import cv2
        # OpenCV loads BGR; the model expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.infer_pil(image)

    def infer_pil(self, img, intrins=None):
        """Predict normals for an RGB image (PIL image or HxWx3 uint8 array).

        Args:
            img: RGB image; pixel values are scaled to [0, 1] internally.
            intrins: Optional (1, 3, 3) camera intrinsics tensor. When
                omitted, intrinsics are derived from an assumed 60-degree FoV.
                A caller-supplied tensor is copied, not mutated.

        Returns:
            Tensor of shape (1, 3, orig_H, orig_W): the last (finest)
            prediction of the model, cropped back to the input size.
        """
        import utils.utils as utils
        img = np.array(img).astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(self.device)
        _, _, orig_H, orig_W = img.shape

        # Zero-pad so both width and height are multiples of 32.
        l, r, t, b = utils.pad_input(orig_H, orig_W)
        img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
        img = self.transform(img)

        if intrins is None:
            intrins = utils.get_intrins_from_fov(
                new_fov=60.0, H=orig_H, W=orig_W, device=self.device
            ).unsqueeze(0)
        else:
            # Work on a copy so the caller's intrinsics tensor is not
            # mutated by the principal-point shift below.
            intrins = intrins.clone()

        # Shift the principal point to account for the left/top padding.
        intrins[:, 0, 2] += l
        intrins[:, 1, 2] += t

        with torch.no_grad():
            pred_norm = self.model(img, intrins=intrins)[-1]
            # Crop the padding off to restore the original resolution.
            pred_norm = pred_norm[:, :, t:t + orig_H, l:l + orig_W]

        return pred_norm

def DSINE(local_file_path: Optional[str] = None):
    """Torch-hub entry point: build a DSINE predictor with pretrained weights.

    Args:
        local_file_path: Optional path to a local checkpoint; when absent,
            the published weights are downloaded.

    Returns:
        A ``Predictor`` wrapping the CUDA-resident, eval-mode DSINE model.
    """
    from models import dsine

    device = torch.device("cuda")
    weights = _load_state_dict(local_file_path)

    model = dsine.DSINE()
    model.load_state_dict(weights, strict=True)
    model.eval()
    model = model.to(device)
    # pixel_coords is moved explicitly, alongside the model parameters.
    model.pixel_coords = model.pixel_coords.to(device)

    return Predictor(model)


def _test_run():
    """CLI smoke test: predict a normal map for one image and save it.

    Usage: python hubconf.py -i input.png -o normal.png [--remote] [--pil]
    """
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input", "-i", type=str, required=True, help="input image file")
    parser.add_argument("--output", "-o", type=str, required=True, help="output image file")
    parser.add_argument("--remote", action="store_true", help="use remote repo")
    # NOTE(review): --reload is accepted but currently unused; torch.hub.load
    # would need force_reload=args.reload for it to take effect.
    parser.add_argument("--reload", action="store_true", help="reload remote repo")
    parser.add_argument("--pil", action="store_true", help="use PIL instead of OpenCV")
    args = parser.parse_args()

    if not args.remote:
        # Use the checkpoint on disk to avoid a network download.
        predictor = torch.hub.load(".", "DSINE", local_file_path='./checkpoints/dsine.pt',
                                   source="local", trust_repo=True)
    else:
        predictor = torch.hub.load(".", "DSINE",
                                   source="local", trust_repo=True)

    if args.pil:
        import PIL
        import torchvision.transforms.functional as TF

        image = PIL.Image.open(args.input).convert("RGB")
        with torch.inference_mode():
            normal = predictor.infer_pil(image)[0]  # (3, H, W)
            normal = (normal + 1) / 2  # [-1, 1] -> [0, 1]

        normal = TF.to_pil_image(normal.cpu())
        normal.save(args.output)

    else:
        import cv2

        image = cv2.imread(args.input, cv2.IMREAD_COLOR)
        with torch.inference_mode():
            normal = predictor.infer_cv2(image)[0]  # (3, H, W)
            normal = (normal + 1) / 2  # [-1, 1] -> [0, 1]

        # CHW tensor -> HWC uint8 image, then RGB -> BGR for OpenCV.
        normal = (normal * 255).cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
        normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)
        cv2.imwrite(args.output, normal)


if __name__ == "__main__":
    _test_run()

0 comments on commit b84433a

Please sign in to comment.