
Object detection training and inference pipeline using CenterNet algorithm. Built on PyTorch Lightning and ONNX Runtime.


ugurcanozalp/centernet-lightning


CenterNet Object Detection

Object detection training/inference pipeline using the CenterNet algorithm. Training uses PyTorch Lightning and inference uses ONNX Runtime. For now, only ResNet backbones are available; other backbones, such as the Hourglass network, are planned.
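In CenterNet (Zhou et al., "Objects as Points", cited below), each object is represented by a peak in a per-class center heatmap, and the box size and sub-pixel center offset are regressed at that peak. The following is a minimal NumPy sketch of that decoding step, not this repository's implementation. It assumes a heatmap of shape (num_classes, H, W), size and offset maps of shape (2, H, W), an output stride of 4, box sizes already in input-image pixels (conventions vary between implementations), and (x1, y1, x2, y2) output boxes. The names `decode_centers` and `max_pool_3x3` are hypothetical. As in the paper, a 3x3 max-pool comparison replaces box-level NMS.

```python
import numpy as np


def max_pool_3x3(x: np.ndarray) -> np.ndarray:
    """3x3 max pooling with stride 1 and -inf padding over the last two axes of (C, H, W)."""
    padded = np.pad(x, ((0, 0), (1, 1), (1, 1)), constant_values=-np.inf)
    h, w = x.shape[1:]
    # Stack the 9 shifted views and take the elementwise maximum.
    windows = [padded[:, dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)]
    return np.max(np.stack(windows), axis=0)


def decode_centers(heatmap, size, offset, k=100, stride=4):
    """Turn CenterNet-style output maps into (x1, y1, x2, y2) boxes, scores and labels.

    heatmap: (C, H, W) class-wise center scores in [0, 1]
    size:    (2, H, W) box width/height, assumed to be in input-image pixels
    offset:  (2, H, W) sub-pixel center offset, in output-map cells
    """
    # Keep only local maxima: cells that are not the max of their 3x3 window become 0.
    peaks = np.where(max_pool_3x3(heatmap) == heatmap, heatmap, 0.0)
    _, h, w = heatmap.shape
    flat = peaks.reshape(-1)
    k = min(k, flat.size)
    top = np.argsort(flat)[::-1][:k]            # indices of the k highest peaks
    scores = flat[top]
    labels = top // (h * w)                     # class index of each peak
    ys, xs = np.divmod(top % (h * w), w)        # row and column on the output map
    cx = (xs + offset[0, ys, xs]) * stride      # center in input-image pixels
    cy = (ys + offset[1, ys, xs]) * stride
    bw, bh = size[0, ys, xs], size[1, ys, xs]
    boxes = np.stack([cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], axis=1)
    return boxes, scores, labels
```

Because the peak test is a single pooling pass, decoding stays cheap and fully vectorized, which is one reason CenterNet avoids anchor boxes and IoU-based NMS.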

Installation

git clone https://github.com/ugurcanozalp/centernet-lightning
cd centernet-lightning
pip install -e .

Training

PyTorch Lightning is used for training. First, place the data files in the data folder. To train a ResNet-18 model on the bolts and nuts dataset, run the following command:

python -m scripts.train_bolts_nuts --gpus 1 --max_epochs 20
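During training, CenterNet supervises the heatmap with a Gaussian bump centered on each ground-truth box center, rendered on the downsampled output grid; where Gaussians of the same class overlap, the paper takes the elementwise maximum. The sketch below shows that target rendering in NumPy. It is a simplified illustration, not this repository's code: in the paper the standard deviation adapts to object size, whereas here the radius is passed in directly, and `sigma = diameter / 6` is an assumed choice. `draw_gaussian` is a hypothetical helper name.

```python
import numpy as np


def draw_gaussian(heatmap: np.ndarray, cx: int, cy: int, radius: int) -> np.ndarray:
    """Splat a 2D Gaussian with peak value 1 at (cx, cy) onto an (H, W) heatmap, in place.

    Overlaps are merged with an elementwise maximum, so every object center stays
    exactly 1.0 and nearby pixels receive a soft (reduced-penalty) target.
    """
    diameter = 2 * radius + 1
    sigma = diameter / 6  # assumed choice; the paper uses an object-size-adaptive sigma
    ys, xs = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    gaussian = np.exp(-(xs * xs + ys * ys) / (2 * sigma * sigma))

    h, w = heatmap.shape
    # Clip the Gaussian patch where it would cross the heatmap borders.
    left, right = min(cx, radius), min(w - cx, radius + 1)
    top, bottom = min(cy, radius), min(h - cy, radius + 1)
    region = heatmap[cy - top:cy + bottom, cx - left:cx + right]  # a view into heatmap
    patch = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    np.maximum(region, patch, out=region)
    return heatmap
```

For a 512x512 input and stride 4, for example, one such (128, 128) map would be rendered per class, with each box center divided by the stride before splatting.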

To test the model, run the following command:

python -m scripts.test_bolts_nuts --gpus 1

Example of training ResNet-18 on the bolts and nuts dataset with default hyperparameters:

[Figures: Training Loss; mAP over Validation Dataset]
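The mAP metric mentioned above counts a detection as a true positive when its intersection-over-union (IoU) with a ground-truth box exceeds a threshold. The threshold(s) this repository uses are not stated here. As background, below is a minimal NumPy sketch of pairwise IoU, assuming boxes in (x1, y1, x2, y2) format; `box_iou` is a hypothetical name, not part of this package.

```python
import numpy as np


def box_iou(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise IoU between boxes a (N, 4) and b (M, 4), both in (x1, y1, x2, y2) format."""
    # Intersection rectangle for every (i, j) pair via broadcasting to (N, M).
    x1 = np.maximum(a[:, None, 0], b[None, :, 0])
    y1 = np.maximum(a[:, None, 1], b[None, :, 1])
    x2 = np.minimum(a[:, None, 2], b[None, :, 2])
    y2 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter
    return inter / np.maximum(union, 1e-12)  # guard against zero-area boxes
```

Under the common PASCAL VOC criterion, for instance, a prediction matched to a ground-truth box with IoU >= 0.5 counts as a true positive.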

Inference

For inference, load one of the trained checkpoints (the example below loads one from the checkpoints folder):

import torch
from PIL import Image
from centernet import CenterNet

model = CenterNet.load_from_checkpoint("checkpoints/centernet_resnet18.pt.ckpt") # Load the trained model.
model.eval() # switch layers such as batch norm to inference behavior
image = Image.open("images/test_0336.jpg")
batch = model.preprocess(image).unsqueeze(0) # resize and convert to a torch tensor, then add a batch dimension
with torch.no_grad(): # gradients are not needed for inference
    batch_ids, boxes, scores, labels = model(batch)
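The model returns flat outputs for the whole batch, and judging by the variable names, `batch_ids` tells which image each detection belongs to; check the model code to confirm this layout. Assuming it, and boxes in (x1, y1, x2, y2) order, the sketch below splits detections per image and drops low-confidence ones. `split_by_image` and the 0.3 default threshold are hypothetical choices for illustration.

```python
import numpy as np


def split_by_image(batch_ids, boxes, scores, labels, num_images, min_score=0.3):
    """Group flat detections into one list of (box, score, label) tuples per image.

    Assumes all four inputs are aligned along their first axis. Torch tensors can be
    converted first with `.cpu().numpy()`.
    """
    batch_ids, boxes = np.asarray(batch_ids), np.asarray(boxes)
    scores, labels = np.asarray(scores), np.asarray(labels)
    keep = scores >= min_score  # drop low-confidence detections
    per_image = [[] for _ in range(num_images)]
    for b, box, s, l in zip(batch_ids[keep], boxes[keep], scores[keep], labels[keep]):
        per_image[int(b)].append((box.tolist(), float(s), int(l)))
    return per_image
```

Filtering by score before grouping keeps the loop short when the model emits a fixed number of candidates per image.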

ONNX Runtime Inference

To use ONNX Runtime, first export the model with the export script (scripts/export.py):

python -m scripts.export --ckpt checkpoints/centernet_resnet18.pt.ckpt --quantized
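Judging by the `_quantized.onnx` file name used in the inference example, the `--quantized` flag produces a quantized model. How export.py quantizes is not described here. As general background, affine (asymmetric) 8-bit quantization stores each float as an integer plus a shared scale and zero point; the NumPy sketch below illustrates that mapping only and is not the scheme the script necessarily uses. `quantize_uint8` and `dequantize_uint8` are hypothetical names.

```python
import numpy as np


def quantize_uint8(x: np.ndarray):
    """Affine quantization of a float array to uint8.

    Returns the quantized values plus the (scale, zero_point) needed to dequantize.
    """
    lo, hi = min(float(x.min()), 0.0), max(float(x.max()), 0.0)  # range must include 0
    scale = (hi - lo) / 255.0 if hi > lo else 1.0
    zero_point = int(round(-lo / scale))  # integer that represents 0.0 exactly
    q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point


def dequantize_uint8(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map uint8 values back to approximate floats."""
    return (q.astype(np.float32) - zero_point) * scale
```

The round-trip error for values inside the range is at most half a quantization step (`scale / 2`), which is why quantized models usually lose only a little accuracy while using 4x less memory than float32.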

Then run inference as follows:

import numpy as np
from PIL import Image
from centernet import ObjectDetector
image = Image.open("images/test_0336.jpg") 
image = np.asarray(image) # PIL image to numpy array
detector = ObjectDetector("deployments/centernet_resnet18_quantized.onnx")
batch_ids, boxes, scores, labels = detector([image])
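`np.asarray` on a PIL image gives a `uint8` array in height x width x channels (HWC) layout, while convolutional models exported from PyTorch to ONNX usually expect a float32 batch in NCHW layout. `ObjectDetector` presumably performs this conversion internally. The sketch below only illustrates the layout change; the `[0, 1]` scaling is an assumption, and the real preprocessing may also resize and normalize with dataset statistics. `to_nchw_batch` is a hypothetical helper.

```python
import numpy as np


def to_nchw_batch(images):
    """Stack same-sized HWC uint8 images into a float32 NCHW batch scaled to [0, 1]."""
    batch = np.stack([np.asarray(img, dtype=np.float32) / 255.0 for img in images])
    # (N, H, W, C) -> (N, C, H, W); make the result contiguous for the runtime.
    return np.ascontiguousarray(batch.transpose(0, 3, 1, 2))
```

With ONNX Runtime directly, an array like this would be passed as a value in the input dictionary of `InferenceSession.run`.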

ONNX real-time demo

Use the demo.py script to run a real-time demo with the ONNX model, and modify it to suit your needs.

python demo.py
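When adapting the demo for real-time use, it helps to measure per-frame latency. The sketch below times any callable that takes a list of frames, such as the `ObjectDetector` instance from the ONNX example; `measure_fps` and its `warmup` parameter are hypothetical, and a dummy detector stands in so the example runs on its own.

```python
import time
from statistics import median
from typing import Callable, Sequence


def measure_fps(detect: Callable[[list], object], frames: Sequence, warmup: int = 3) -> float:
    """Return frames per second based on the median per-frame latency of `detect`.

    The first `warmup` calls are excluded, since the first runs of an inference
    session are often slower (lazy initialization, memory allocation).
    """
    latencies = []
    for i, frame in enumerate(frames):
        start = time.perf_counter()
        detect([frame])  # one-image batch, as in the ONNX example
        elapsed = time.perf_counter() - start
        if i >= warmup:
            latencies.append(elapsed)
    if not latencies:
        raise ValueError("need more frames than warmup iterations")
    return 1.0 / max(median(latencies), 1e-9)  # avoid division by zero


if __name__ == "__main__":
    def fake_detector(batch):
        """Stand-in for a real model: sleeps about 10 ms per call."""
        time.sleep(0.01)

    print(f"{measure_fps(fake_detector, frames=[None] * 20):.1f} FPS")
```

The median is used instead of the mean so that occasional slow frames (for example, from other processes) do not dominate the estimate.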

The current trained model produces the following output on the bolts and nuts test set:

[Figure: Expected Output]

Possible future improvements

This package is new, so some features are still to be added:

  • Implement more augmentations in the dataset classes (waiting for a new version of torchvision).
  • Support the COCO dataset.
  • Implement backbones other than ResNet, such as Hourglass and Swin Transformer.
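As an example of the kind of augmentation the first item refers to, a horizontal flip must mirror the box coordinates along with the pixels. A minimal NumPy sketch, assuming HWC images and (x1, y1, x2, y2) boxes in continuous pixel coordinates; `hflip` is a hypothetical helper, not part of this package.

```python
import numpy as np


def hflip(image: np.ndarray, boxes: np.ndarray):
    """Mirror an HWC image left-right and update (x1, y1, x2, y2) boxes to match."""
    width = image.shape[1]
    flipped = image[:, ::-1].copy()  # reverse the column (width) axis
    new_boxes = boxes.astype(np.float32, copy=True)
    # x1' = W - x2 and x2' = W - x1, so that x1' <= x2' still holds after mirroring.
    new_boxes[:, 0] = width - boxes[:, 2]
    new_boxes[:, 2] = width - boxes[:, 0]
    return flipped, new_boxes
```

If boxes use inclusive integer pixel indices instead of continuous coordinates, the mapping becomes `W - 1 - x`; applying the flip twice should return the original image and boxes either way.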

References

Citation

@inproceedings{zhou2019objects,
    title={Objects as Points},
    author={Zhou, Xingyi and Wang, Dequan and Kr{\"a}henb{\"u}hl, Philipp},
    booktitle={arXiv preprint arXiv:1904.07850},
    year={2019}
}

