-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
52 lines (41 loc) · 2.3 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
import logging
import torch
from train import train_model
def get_args(argv=None):
    """Parse, normalise and log the command-line parameters for model training.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        dict: Mapping of option name (e.g. ``"n_epochs"``) to its parsed value.
        The options listed in ``bool_params`` are converted to ``bool``: True
        when the given string is y/yes/true in any letter case, otherwise False
        (including when the option was omitted). Every other option that was
        not supplied is ``None``, since no defaults are declared.
    """
    parser = argparse.ArgumentParser(description='Parameters for model training')
    parser.add_argument('--n_epochs', type=int, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, help='Number of images in batch')
    parser.add_argument('--ref_images_dir', type=str, help='Path to directory where ref images will be saved')
    parser.add_argument('--download_datasets', type=str, help='Download dataset from Torchvision repo (y/yes/true) or use an already existing dataset')
    parser.add_argument('--root_datasets_dir', type=str, help='Path where dataset should be downloaded or where it is already stored')
    parser.add_argument('--class_name', type=str, help='One of ten classes in CIFAR10 dataset')
    parser.add_argument('--latent_vector_length', type=int, help='Length of random vector which will be transformed into an image by generator')
    parser.add_argument('--init_generator_weights', type=str, help="Init generator's weights using normal distribution (y/yes/true)")
    parser.add_argument('--init_discriminator_weights', type=str, help="Init discriminator's weights using normal distribution (y/yes/true)")
    args = vars(parser.parse_args(argv))

    # Parse string flags to booleans. Matching is case-insensitive so that
    # inputs such as "TRUE" or "YES" are not silently turned into False.
    str_true = {"y", "yes", "true"}
    bool_params = ["download_datasets", "init_generator_weights", "init_discriminator_weights"]
    for param in bool_params:
        value = args[param]
        # An omitted option arrives as None and is treated as False.
        args[param] = value is not None and value.lower() in str_true

    # Log the final (post-conversion) parameters.
    separator = 8 * "-"
    logging.info(separator)
    logging.info("PARAMETERS")
    logging.info(separator)
    for name, value in args.items():
        logging.info("%s: %s", name, value)
    logging.info(separator)
    return args
if __name__ == "__main__":
    # Script entry point: configure logging, read the CLI, choose a device, train.
    logging.basicConfig(level=logging.INFO)
    args = get_args()

    # Use the first CUDA GPU when torch reports one is available; otherwise CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    logging.info(f"Device: {device}")

    # train_model is called positionally after the device; keep this order in
    # sync with its signature.
    train_arg_order = (
        "n_epochs",
        "batch_size",
        "ref_images_dir",
        "download_datasets",
        "root_datasets_dir",
        "class_name",
        "latent_vector_length",
        "init_generator_weights",
        "init_discriminator_weights",
    )
    model = train_model(device, *(args[key] for key in train_arg_order))