From 1abf49e13d14ba7de4c2645b6db7327623a26dc8 Mon Sep 17 00:00:00 2001 From: John Bradley Date: Thu, 22 Aug 2024 10:36:38 -0400 Subject: [PATCH] Add --cls-file to predict command Fixes #30 --- README.md | 6 ++++-- src/bioclip/__main__.py | 16 ++++++++++++++-- tests/test_main.py | 13 ++++++++++++- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 0c23cdd..b898112 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ bear 1.0 ## Command Line Usage ``` -bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...] +bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS | --cls-file CLS_FILE] [--device DEVICE] image_file [image_file ...] bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...] Commands: @@ -119,7 +119,9 @@ Options: --format=FORMAT format of the output (table or csv) for predict mode [default: csv] --rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species] --k=K number of top predictions to show [default: 5] - --cls=CLS comma separated list of classes to predict, when specified the --rank and --k arguments are not allowed + --cls=CLS comma separated list of classes to predict, when specified the --rank and --k + arguments are not allowed + --cls-file=CLS_FILE path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed --device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu] --output=OUTFILE print output to file OUTFILE [default: stdout] ``` diff --git a/src/bioclip/__main__.py b/src/bioclip/__main__.py index 1361fe4..48dd82a 100644 --- a/src/bioclip/__main__.py +++ b/src/bioclip/__main__.py @@ -83,7 +83,9 @@ def create_parser(): 
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'], help='rank of the classification, default: species (when)') predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5') - predict_parser.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed') + cls_group = predict_parser.add_mutually_exclusive_group(required=False) + cls_group.add_argument('--cls', help='comma separated list of classes to predict, when specified the --rank argument is not allowed') + cls_group.add_argument('--cls-file', help='path to file with list of classes to predict, one per line, when specified the --rank and --k arguments are not allowed') predict_parser.add_argument('--device', **device_arg) predict_parser.add_argument('--model', **model_arg) predict_parser.add_argument('--pretrained', **pretrained_arg) @@ -128,6 +130,13 @@ def parse_args(input_args=None): return args +def create_classes_str(cls_file_path): + """Read a file with one class per line and return a comma separated string of classes; surrounding whitespace is stripped and blank lines are skipped""" + with open(cls_file_path, 'r') as cls_file: + classes = [item.strip() for item in cls_file.readlines()] + return ",".join(item for item in classes if item) + + def main(): args = parse_args() if args.command == 'embed': @@ -137,10 +146,13 @@ def main(): model_str=args.model, pretrained_str=args.pretrained) elif args.command == 'predict': + cls_str = args.cls + if args.cls_file: + cls_str = create_classes_str(args.cls_file) predict(args.image_file, format=args.format, output=args.output, - cls_str=args.cls, + cls_str=cls_str, rank=args.rank, k=args.k, device=args.device, diff --git a/tests/test_main.py b/tests/test_main.py index 9738640..720a019 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,5 +1,7 @@ import unittest -from bioclip.__main__ import parse_args, Rank +from unittest.mock import mock_open, patch +import argparse +from 
bioclip.__main__ import parse_args, Rank, create_classes_str class TestParser(unittest.TestCase): @@ -49,6 +51,10 @@ def test_parse_args(self): args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10']) self.assertEqual(args.k, 10) + args = parse_args(['predict', '--cls-file', 'somefile.txt', 'image.jpg']) + self.assertEqual(args.cls_file, 'somefile.txt') + self.assertIsNone(args.cls) + args = parse_args(['embed', 'image.jpg']) self.assertEqual(args.command, 'embed') self.assertEqual(args.image_file, ['image.jpg']) @@ -60,3 +66,8 @@ def test_parse_args(self): self.assertEqual(args.image_file, ['image.jpg', 'image2.png']) self.assertEqual(args.output, 'data.json') self.assertEqual(args.device, 'cuda') + + def test_create_classes_str(self): + data = "class1\nclass2\nclass3" + with patch("builtins.open", mock_open(read_data=data)): + self.assertEqual(create_classes_str('path/to/file'), 'class1,class2,class3')