
Augmenting Efficient Surgical Instrument Segmentation in Video with Point Tracking and Segment Anything (SIS-PT-SAM)

News

The work was accepted by the MICCAI 2024 AE-CAI Workshop!

PyTorch implementation of SIS-PT-SAM. Inference runs at 25+ FPS on a single RTX 4060 and 80+ FPS on a single RTX 4090. Point prompts are used to fully fine-tune MobileSAM.

[arXiv]

Getting Started

Dependencies

  • Python 3.11
  • torch 2.1.2
  • torchvision 0.16.2

Environment Setup

Create a Conda environment:

conda create --name sis-pt-sam python=3.11

Activate the Conda environment and install the dependencies:

pip install -r requirements.txt

Datasets

Links to the publicly available datasets used in this work:

Data Preparation

Please reformat each dataset according to the following top-level directory layout.

.
├── ...
├── train
│   ├── imgs
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   └── gts
│       ├── 000000.png
│       ├── 000001.png
│       └── ...
├── val
│   ├── imgs
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   └── gts
│       ├── 000000.png
│       ├── 000001.png
│       └── ...
└── ...
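Given this layout, pairing each image with its ground-truth mask is a matter of matching filenames across the `imgs` and `gts` folders. The sketch below illustrates this, assuming the layout above; the helper `pair_frames` is hypothetical and not part of the repository.

```python
# Sketch (not the repo's code): pair <split>/imgs/*.png with <split>/gts/*.png
# by filename, raising on any unpaired frame.
import tempfile
from pathlib import Path

def pair_frames(split_dir):
    """Return sorted (image, ground-truth mask) path pairs for one split."""
    split_dir = Path(split_dir)
    imgs = {p.name: p for p in (split_dir / "imgs").glob("*.png")}
    gts = {p.name: p for p in (split_dir / "gts").glob("*.png")}
    common = sorted(imgs.keys() & gts.keys())
    missing = (imgs.keys() | gts.keys()) - set(common)
    if missing:
        raise ValueError(f"unpaired frames: {sorted(missing)}")
    return [(imgs[n], gts[n]) for n in common]

# Tiny self-check against a temporary directory mimicking the layout above
with tempfile.TemporaryDirectory() as tmp:
    for sub in ("imgs", "gts"):
        d = Path(tmp) / "train" / sub
        d.mkdir(parents=True)
        for i in range(3):
            (d / f"{i:06d}.png").touch()
    pairs = pair_frames(Path(tmp) / "train")
```

A strict filename match like this surfaces missing or misnamed masks at load time rather than mid-training.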

Training

Download the checkpoints of MobileSAM, CoTracker, and Light HQ-SAM, and put them into ./ckpts/.

To train on a single GPU, run:

python train.py -i ./data/[dataset]/train/ -v ./data/[dataset]/val/ --sam-ckpt ./ckpts/mobile_sam.pt --work-dir [path of the training results] --max-epochs 100 --data-aug --freeze-prompt-encoder --batch-size 4 --learn-rate 1e-5 --dataset [dataset]

For example:

python train.py -i /data/CholecSeg8k/train/ -v /data/CholecSeg8k/val/ --train-from-scratch --work-dir ./results/exp_cholecseg8k --max-epochs 100 --data-aug --freeze-prompt-encoder --batch-size 4 --learn-rate 1e-5 --dataset cholecseg

To train on multiple GPUs, add --multi-gpu and set the device_ids on line 82 of train.py to the GPUs you want to use:

if args.multi_gpu:
    surgicaltool_sam = nn.DataParallel(surgicaltool_sam, device_ids=[0,1,2,3])

Run Online Demo

Prepare the first frame of the video and its corresponding mask. If there is more than one tool to segment, put the masks of the individual tools into one folder.

Then use online_demo.py to run the online demo for a video.

python online_demo.py --video_path [video path] --tracker cotracker --sam_type finetune --tool_number 2 --first_frame_path [path of the first frame of the video] --first_mask_path [path of the first frame mask of the video] --mask_dir_path [folder that contains the mask of each tool in first frame] --save_demo --mode kmedoids --add_support_grid --sam-ckpt ./ckpts/[checkpoint file]
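The `--mode kmedoids` option suggests that point prompts are derived from the first-frame mask by clustering its foreground pixels and using the cluster medoids as prompts, which point tracking then propagates through the video. The sketch below shows that idea with a minimal k-medoids over pixel coordinates; `kmedoid_prompts` is a hypothetical helper, not the repository's actual function.

```python
# Sketch (not the repo's code): pick k point prompts inside a binary tool
# mask as the medoids of k clusters of foreground pixel coordinates.
import numpy as np

def kmedoid_prompts(mask, k=3, iters=10, seed=0):
    """Return k (x, y) point prompts lying on the masked region."""
    ys, xs = np.nonzero(mask)
    pts = np.stack([xs, ys], axis=1).astype(float)
    rng = np.random.default_rng(seed)
    medoids = pts[rng.choice(len(pts), size=k, replace=False)]
    for _ in range(iters):
        # assign every foreground pixel to its nearest medoid
        d = np.linalg.norm(pts[:, None] - medoids[None], axis=2)
        labels = d.argmin(axis=1)
        for j in range(k):
            cluster = pts[labels == j]
            if len(cluster) == 0:
                continue
            # medoid = cluster member minimizing summed distance to the cluster
            dc = np.linalg.norm(cluster[:, None] - cluster[None], axis=2).sum(1)
            medoids[j] = cluster[dc.argmin()]
    return medoids.astype(int)

mask = np.zeros((64, 64), dtype=bool)
mask[10:30, 10:30] = True            # one square "tool" region
prompts = kmedoid_prompts(mask, k=3)
```

Unlike k-means centroids, medoids are always actual mask pixels, so every prompt is guaranteed to land on the tool rather than in the background between parts of it.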

Contact

If you have any problems using this code, open an issue in this repository or contact me at zijianwu@ece.ubc.ca

License

This project is licensed under the MIT License.

Acknowledgments

Thanks to the following awesome works for the inspiration, code snippets, etc.
