Skip to content

Commit

Permalink
Add --cls-file to predict command
Browse files Browse the repository at this point in the history
Fixes #30
  • Loading branch information
johnbradley committed Aug 22, 2024
1 parent 452c5c6 commit 1abf49e
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ bear 1.0

## Command Line Usage
```
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...]
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS | --cls-file CLS_FILE] [--device DEVICE] image_file [image_file ...]
bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...]
Commands:
Expand All @@ -119,7 +119,9 @@ Options:
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
--k=K number of top predictions to show [default: 5]
--cls=CLS comma separated list of classes to predict, when specified the --rank and --k arguments are not allowed
--cls=CLS comma separated list of classes to predict, when specified the --rank and --k
arguments are not allowed
--cls-file CLS_FILE path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
--output=OUTFILE print output to file OUTFILE [default: stdout]
```
Expand Down
16 changes: 14 additions & 2 deletions src/bioclip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def create_parser():
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
help='rank of the classification, default: species (when)')
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
predict_parser.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed')
cls_group = predict_parser.add_mutually_exclusive_group(required=False)
cls_group.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed')
cls_group.add_argument('--cls-file', help='path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed')
predict_parser.add_argument('--device', **device_arg)
predict_parser.add_argument('--model', **model_arg)
predict_parser.add_argument('--pretrained', **pretrained_arg)
Expand Down Expand Up @@ -128,6 +130,13 @@ def parse_args(input_args=None):
return args


def create_classes_str(cls_file_path):
"""Reads a file with one class per line and returns a comma separated string of classes"""
with open(cls_file_path, 'r') as cls_file:
cls_str = [item.strip() for item in cls_file.readlines()]
return ",".join(cls_str)


def main():
args = parse_args()
if args.command == 'embed':
Expand All @@ -137,10 +146,13 @@ def main():
model_str=args.model,
pretrained_str=args.pretrained)
elif args.command == 'predict':
cls_str = args.cls
if args.cls_file:
cls_str = create_classes_str(args.cls_file)
predict(args.image_file,
format=args.format,
output=args.output,
cls_str=args.cls,
cls_str=cls_str,
rank=args.rank,
k=args.k,
device=args.device,
Expand Down
13 changes: 12 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest
from bioclip.__main__ import parse_args, Rank
from unittest.mock import mock_open, patch
import argparse
from bioclip.__main__ import parse_args, Rank, create_classes_str


class TestParser(unittest.TestCase):
Expand Down Expand Up @@ -49,6 +51,10 @@ def test_parse_args(self):
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
self.assertEqual(args.k, 10)

args = parse_args(['predict', '--cls-file', 'somefile.txt', 'image.jpg'])
self.assertEqual(args.cls_file, 'somefile.txt')
self.assertEqual(args.cls, None)

args = parse_args(['embed', 'image.jpg'])
self.assertEqual(args.command, 'embed')
self.assertEqual(args.image_file, ['image.jpg'])
Expand All @@ -60,3 +66,8 @@ def test_parse_args(self):
self.assertEqual(args.image_file, ['image.jpg', 'image2.png'])
self.assertEqual(args.output, 'data.json')
self.assertEqual(args.device, 'cuda')

def test_create_classes_str(self):
data = "class1\nclass2\nclass3"
with patch("builtins.open", mock_open(read_data=data)) as mock_file:
self.assertEqual(create_classes_str('path/to/file'), 'class1,class2,class3')

0 comments on commit 1abf49e

Please sign in to comment.