Skip to content

Commit

Permalink
Update train and utils submodules
Browse files: browse the repository at this point in the history
  • Loading branch information
jasonjewik committed Nov 30, 2024
1 parent 302b3a8 commit fb6ba25
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 30 deletions.
44 changes: 22 additions & 22 deletions src/phenocam_snow/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -47,7 +47,6 @@ def main():
parser.add_argument("--classes", nargs="+", help="The image classes to use.")
args = parser.parse_args()

label_method = "via subdir" # can't do "in notebook" from a script
if args.new and args.existing:
print("Cannot specify both --new and --existing")
elif args.new:
Expand All @@ -56,7 +55,6 @@ def main():
args.learning_rate,
args.weight_decay,
args.site_name,
label_method,
args.n_train,
args.n_test,
args.classes,
Expand All @@ -82,7 +80,6 @@ def train_model_with_new_data(
learning_rate: float,
weight_decay: float,
site_name: str,
label_method: str,
n_train: int,
n_test: int,
classes: list[str],
Expand All @@ -97,8 +94,6 @@ def train_model_with_new_data(
:type weight_decay: float
:param site_name: The name of the PhenoCam site you want.
:type site_name: str
:param label_method: How you wish to label images ("in notebook" or "via subdir").
:type label_method: str
:param n_train: The number of training images to use.
:type n_train: int
:param n_test: The number of testing images to use.
Expand All @@ -109,10 +104,6 @@ def train_model_with_new_data(
:return: The best model obtained during training.
:rtype: PhenoCamResNet
"""
valid_label_methods = ["in notebook", "via subdir"]
if label_method not in valid_label_methods:
raise ValueError(f"{label_method} is not a valid label method")

train_dir = f"{site_name}_train"
test_dir = f"{site_name}_test"
train_labels = f"{train_dir}/labels.csv"
Expand All @@ -121,20 +112,29 @@ def train_model_with_new_data(
data_module = PhenoCamDataModule(
site_name, train_dir, train_labels, test_dir, test_labels
)
base_download_args = {"site_name": site_name}
base_label_args = {
"site_name": site_name,
"categories": classes,
"method": label_method,
}
data_module.prepare_data(
train_download_args=base_download_args
+ {"save_to": train_dir, "n_photos": n_train},
train_label_args=base_label_args
+ {"img_dir": train_dir, "save_to": train_labels},
test_download_args=base_download_args
+ {"save_to": test_dir, "n_photos": n_test},
test_label_args=base_label_args + {"img_dir": test_dir, "save_to": test_labels},
train_download_args={
"site_name": site_name,
"save_to": train_dir,
"n_photos": n_train,
},
train_label_args={
"site_name": site_name,
"categories": classes,
"img_dir": train_dir,
"save_to": train_labels,
},
test_download_args={
"site_name": site_name,
"save_to": test_dir,
"n_photos": n_test,
},
test_label_args={
"site_name": site_name,
"categories": classes,
"img_dir": test_dir,
"save_to": test_labels,
},
)

return _fit_and_eval_model(
Expand Down
17 changes: 9 additions & 8 deletions src/phenocam_snow/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -293,7 +293,8 @@ def label_images(
"""
if type(img_dir) is not Path:
img_dir = Path(img_dir)
assert img_dir.is_dir()
if not img_dir.is_dir():
raise ValueError(f"img_dir {img_dir} is not a directory")

dircats = []
for cat in categories:
Expand All @@ -307,24 +308,24 @@ def label_images(
"Move images into the appropriate sub-directory then press any key to continue."
)

filenames = []
data = []
for dircat in dircats:
filenames.extend([os.path.basename(fpath) for fpath in dircat.glob("*.jpg")])
df = pd.DataFrame(
zip(filenames, categories), columns=["filename", "label"]
).explode("filename")
with open(save_to, "w+") as f:
for fpath in dircat.glob("*.jpg"):
data.append((os.path.basename(fpath), os.path.basename(dircat)))
df = pd.DataFrame(data, columns=["filename", "label"])
with open(img_dir / save_to, "w+") as f:
f.write(f"# Site: {site_name}\n")
f.write("# Categories:\n")
for i, cat in enumerate(categories):
f.write(f"# {i}. {cat}\n")
df.to_csv(save_to, mode="a", index=False)
df.to_csv(img_dir / save_to, mode="a", index=False)

for item in img_dir.glob("*"):
if item.is_dir():
for subitem in sorted(item.glob("*")):
new_path = Path(subitem.resolve().parent.parent).joinpath(subitem.name)
subitem.rename(new_path)
item.rmdir()


def read_labels(labels_file: str | Path) -> pd.DataFrame:
Expand Down

0 comments on commit fb6ba25

Please sign in to comment.