Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export to onnx #15

Open
limewodemaya opened this issue Aug 30, 2024 · 0 comments
Open

Export to onnx #15

limewodemaya opened this issue Aug 30, 2024 · 0 comments

Comments

@limewodemaya
Copy link

limewodemaya commented Aug 30, 2024

Dear author, thank you for your excellent work. On my device, PyTorch inference takes only 0.08 seconds per image, but after converting the model to ONNX format it takes 0.8 seconds per image. My conversion code is as follows:

import os
import sys
import glob
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import onnx
import onnxruntime as ort
import time
import sys
sys.path.append('../../')
import utils.utils as utils
import projects.dsine.config as config
from utils.projection import intrins_from_fov, intrins_from_txt

if __name__ == '__main__':
    device = torch.device('cuda')
    args = config.get_args(test=True)
    assert os.path.exists(args.ckpt_path)

    if args.NNET_architecture == 'v00':
        from models.dsine.v00 import DSINE_v00 as DSINE
    elif args.NNET_architecture == 'v01':
        from models.dsine.v01 import DSINE_v01 as DSINE
    elif args.NNET_architecture == 'v02':
        from models.dsine.v02 import DSINE_v02 as DSINE
    elif args.NNET_architecture == 'v02_kappa':
        from models.dsine.v02_kappa import DSINE_v02_kappa as DSINE
    else:
        raise Exception('invalid arch')

    model = DSINE(args).to(device)
    model = utils.load_checkpoint(args.ckpt_path, model)

    model.eval()

    # Select a sample image for ONNX export
    sample_img_path = './samples/img/sample.png'  
    ext = os.path.splitext(sample_img_path)[1]
    img = Image.open(sample_img_path).convert('RGB')
    img = np.array(img).astype(np.float32) / 255.0
    img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

    # pad input
    _, _, orig_H, orig_W = img.shape
    lrtb = utils.get_padding(orig_H, orig_W)
    img = F.pad(img, lrtb, mode="constant", value=0.0)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    img = normalize(img)

    # get intrinsics
    intrins_path = sample_img_path.replace(ext, '.txt')
    if os.path.exists(intrins_path):
        intrins = intrins_from_txt(intrins_path, device=device).unsqueeze(0)
    else:
        intrins = intrins_from_fov(new_fov=60.0, H=orig_H, W=orig_W, device=device).unsqueeze(0)
    intrins[:, 0, 2] += lrtb[0]
    intrins[:, 1, 2] += lrtb[2]

    # Export the model to ONNX
    onnx_path = './model_op17.onnx'
    output_normals = model(img, intrins=intrins)[-1]
    torch.onnx.export(
        model,
        (img, intrins),
        onnx_path,
        export_params=True,
        opset_version=17,  
        do_constant_folding=True,
        input_names=['input_image', 'input_intrinsics'],  
        output_names=['output_normals'],  
        dynamic_axes={'input_image': {0: 'batch_size'}, 'output_normals': {0: 'batch_size'}}  
    )
    print(f"Model has been converted to ONNX and saved at {onnx_path}")

    img_paths = glob.glob('./samples/img/*.png') + glob.glob('./samples/img/*.jpg')
    img_paths.sort()
    os.makedirs('./samples/output/', exist_ok=True)

    onnx_path = './model_op17.onnx'
    ort_session = ort.InferenceSession(onnx_path, providers=['CUDAExecutionProvider'])

    with torch.no_grad():
        for img_path in img_paths:
            print(img_path)
            ext = os.path.splitext(img_path)[1]
            img = Image.open(img_path).convert('RGB')
            img = np.array(img).astype(np.float32) / 255.0
            img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

            # pad input
            _, _, orig_H, orig_W = img.shape
            lrtb = utils.get_padding(orig_H, orig_W)
            img = F.pad(img, lrtb, mode="constant", value=0.0)
            img = normalize(img)

            # get intrinsics
            intrins_path = img_path.replace(ext, '.txt')
            if os.path.exists(intrins_path):
                intrins = intrins_from_txt(intrins_path, device=device).unsqueeze(0)
            else:
                intrins = intrins_from_fov(new_fov=60.0, H=orig_H, W=orig_W, device=device).unsqueeze(0)
            intrins[:, 0, 2] += lrtb[0]
            intrins[:, 1, 2] += lrtb[2]

            # ONNX inference
            start_time = time.time()
            ort_inputs = {
                'input_image': img.cpu().numpy(),
                'input_intrinsics': intrins.cpu().numpy()
            }
            ort_outs = ort_session.run(['output_normals'], ort_inputs)
            pred_norm = torch.tensor(ort_outs[0]).to(device)
            end_time = time.time()
            print(f'inference time: {end_time - start_time:.4f} seconds')

            pred_norm = pred_norm[:, :, lrtb[2]:lrtb[2]+orig_H, lrtb[0]:lrtb[0]+orig_W]

            # save to output folder
            target_path = img_path.replace('/img/', '/output/').replace(ext, '.png')

            pred_norm = pred_norm.detach().cpu().permute(0, 2, 3, 1).numpy()
            pred_norm = (((pred_norm + 1) * 0.5) * 255).astype(np.uint8)
            im = Image.fromarray(pred_norm[0,...])
            im.save(target_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant