diff --git a/src/phenocam_snow/train.py b/src/phenocam_snow/train.py index 2b617cd..e58ae37 100644 --- a/src/phenocam_snow/train.py +++ b/src/phenocam_snow/train.py @@ -47,7 +47,6 @@ def main(): parser.add_argument("--classes", nargs="+", help="The image classes to use.") args = parser.parse_args() - label_method = "via subdir" # can't do "in notebook" from a script if args.new and args.existing: print("Cannot specify both --new and --existing") elif args.new: @@ -56,7 +55,6 @@ def main(): args.learning_rate, args.weight_decay, args.site_name, - label_method, args.n_train, args.n_test, args.classes, @@ -82,7 +80,6 @@ def train_model_with_new_data( learning_rate: float, weight_decay: float, site_name: str, - label_method: str, n_train: int, n_test: int, classes: list[str], @@ -97,8 +94,6 @@ def train_model_with_new_data( :type weight_decay: float :param site_name: The name of the PhenoCam site you want. :type site_name: str - :param label_method: How you wish to label images ("in notebook" or "via subdir"). - :type label_method: str :param n_train: The number of training images to use. :type n_train: int :param n_test: The number of testing images to use. @@ -109,10 +104,6 @@ def train_model_with_new_data( :return: The best model obtained during training. :rtype: PhenoCamResNet """ - valid_label_methods = ["in notebook", "via subdir"] - if label_method not in valid_label_methods: - raise ValueError(f"{label_method} is not a valid label method") - train_dir = f"{site_name}_train" test_dir = f"{site_name}_test" train_labels = f"{train_dir}/labels.csv" @@ -121,20 +112,29 @@ def train_model_with_new_data( data_module = PhenoCamDataModule( site_name, train_dir, train_labels, test_dir, test_labels ) - base_download_args = {"site_name": site_name} - base_label_args = { - "site_name": site_name, - "categories": classes, - "method": label_method, - } data_module.prepare_data( - train_download_args=base_download_args - + {"save_to": train_dir, "n_photos": n_train}, - train_label_args=base_label_args - + {"img_dir": train_dir, "save_to": train_labels}, - test_download_args=base_download_args - + {"save_to": test_dir, "n_photos": n_test}, - test_label_args=base_label_args + {"img_dir": test_dir, "save_to": test_labels}, + train_download_args={ + "site_name": site_name, + "save_to": train_dir, + "n_photos": n_train, + }, + train_label_args={ + "site_name": site_name, + "categories": classes, + "img_dir": train_dir, + "save_to": train_labels, + }, + test_download_args={ + "site_name": site_name, + "save_to": test_dir, + "n_photos": n_test, + }, + test_label_args={ + "site_name": site_name, + "categories": classes, + "img_dir": test_dir, + "save_to": test_labels, + }, ) return _fit_and_eval_model( diff --git a/src/phenocam_snow/utils.py b/src/phenocam_snow/utils.py index c93a43e..2f44ef4 100644 --- a/src/phenocam_snow/utils.py +++ b/src/phenocam_snow/utils.py @@ -293,7 +293,8 @@ def label_images( """ if type(img_dir) is not Path: img_dir = Path(img_dir) - assert img_dir.is_dir() + if not img_dir.is_dir(): + raise ValueError(f"img_dir {img_dir} is not a directory") dircats = [] for cat in categories: @@ -307,24 +308,24 @@ def label_images( "Move images into the appropriate sub-directory then press any key to continue." ) - filenames = [] + data = [] for dircat in dircats: - filenames.extend([os.path.basename(fpath) for fpath in dircat.glob("*.jpg")]) - df = pd.DataFrame( - zip(filenames, categories), columns=["filename", "label"] - ).explode("filename") - with open(save_to, "w+") as f: + for fpath in dircat.glob("*.jpg"): + data.append((os.path.basename(fpath), os.path.basename(dircat))) + df = pd.DataFrame(data, columns=["filename", "label"]) + with open(img_dir / save_to, "w+") as f: f.write(f"# Site: {site_name}\n") f.write("# Categories:\n") for i, cat in enumerate(categories): f.write(f"# {i}. {cat}\n") - df.to_csv(save_to, mode="a", index=False) + df.to_csv(img_dir / save_to, mode="a", index=False) for item in img_dir.glob("*"): if item.is_dir(): for subitem in sorted(item.glob("*")): new_path = Path(subitem.resolve().parent.parent).joinpath(subitem.name) subitem.rename(new_path) + item.rmdir() def read_labels(labels_file: str | Path) -> pd.DataFrame: