Skip to content

Commit

Permalink
Implement --include-subfolders argument in train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iver56 committed Jun 30, 2019
1 parent 2c60838 commit 4e4c026
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The frame filenames should have zero-padded frame numbers, for example like this
If you have multiple sequences of frames (i.e. from different videos/scenes/shots), you can have different prefixes in the frame filenames, like this:
* firstvideo00001.png, firstvideo00002.png, firstvideo00003.png, ..., secondvideo00001.png, secondvideo00002.png, secondvideo00003.png, ...

Alternatively, the different frame sequences can reside in different subfolders. For that to work, you have to use the `--include-subfolders` argument.

## Apply video colorization to a folder of PNG frames

`python -m tcvc.apply --input-path /path/to/images/ --input-style line_art`
Expand Down
6 changes: 4 additions & 2 deletions tcvc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from tcvc.dataset import DatasetFromFolder


def get_dataset(root_dir, use_line_art=True):
return DatasetFromFolder(root_dir, use_line_art)
def get_dataset(root_dir, use_line_art=True, include_subfolders=False):
return DatasetFromFolder(
root_dir, use_line_art, include_subfolders=include_subfolders
)


def create_iterator(sample_size, sample_dataset):
Expand Down
1 change: 1 addition & 0 deletions tcvc/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, image_dir, use_line_art=True, include_subfolders=False):
self.image_file_paths = get_image_file_paths(
image_dir, include_subfolders=include_subfolders
)
assert len(self.image_file_paths) > 0
transform_list = [ToTensor()]
self.transform = Compose(transform_list)

Expand Down
24 changes: 20 additions & 4 deletions tcvc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from tcvc.util import stitch_images, postprocess

if __name__ == "__main__":
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ[
"CUDA_VISIBLE_DEVICES"
] = "0" # Ensure that we only use one GPU, not multiple

# Training settings
parser = argparse.ArgumentParser(
Expand All @@ -30,13 +32,19 @@
required=True,
help="Path to a folder that contains the training set (image frames)",
)
parser.add_argument(
"--include-subfolders",
dest="include_subfolders",
action="store_true",
help="Include images from subfolders in the specified dataset path.",
)
parser.add_argument(
"--input-style",
dest="input_style",
type=str,
choices=["line_art", "greyscale"],
help="line_art (canny edge detection) or greyscale",
default="line_art",
default="greyscale",
)
parser.add_argument("--logfile", required=False, default="training_logs.dat")
parser.add_argument("--checkpoint", required=False, help="load pre-trained?")
Expand Down Expand Up @@ -104,9 +112,17 @@
torch.cuda.manual_seed(opt.seed)

print("===> Loading datasets")
train_set = get_dataset(opt.dataset, use_line_art=opt.input_style == "line_art")
train_set = get_dataset(
opt.dataset,
use_line_art=opt.input_style == "line_art",
include_subfolders=opt.include_subfolders,
)
# TODO: Add a separate argument for test set path. Do not use the same paths for training and testing
test_set = get_dataset(opt.dataset, use_line_art=opt.input_style == "line_art")
test_set = get_dataset(
opt.dataset,
use_line_art=opt.input_style == "line_art",
include_subfolders=opt.include_subfolders,
)

training_data_loader = DataLoader(
dataset=train_set,
Expand Down

0 comments on commit 4e4c026

Please sign in to comment.