diff --git a/src/dataset.py b/src/dataset.py
index b23694f..baeafc4 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -2,22 +2,22 @@
 def get_training_set(opt):
-
-    assert opt.datset in ['cityscapes', 'cityscapes_two_path', 'kth']
+    assert opt.dataset in ['cityscapes', 'cityscapes_two_path', 'kth']
 
     if opt.dataset == 'cityscapes':
         from datasets.cityscapes_dataset_w_mask import Cityscapes
-        train_Dataset = Cityscapes(datapath=CITYSCAPES_TRAIN_DATA_PATH, datalist=CITYSCAPES_TRAIN_DATA_LIST,
-                                   size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames)
+        train_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
+                                   datalist=CITYSCAPES_VAL_DATA_LIST,
+                                   size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames,
+                                   mask_suffix='ssmask.png', returnpath=False)
 
     elif opt.dataset == 'cityscapes_two_path':
-        from datasets.cityscapes_dataset_w_mask_two_path import Cityscpes
-        train_Dataset = Cityscapes(datapath=CITYSCAPES_TRAIN_DATA_PATH,
-                                   mask_data_path=CITYSCAPES_TRAIN_DATA_SEGMASK_PATH,
-                                   datalist=CITYSCAPES_TRAIN_DATA_LIST,
+        from datasets.cityscapes_dataset_w_mask_two_path import Cityscapes
+        train_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
+                                   datalist=CITYSCAPES_VAL_DATA_LIST,
                                    size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames,
-                                   mask_suffix='ssmask.png')
+                                   mask_suffix='ssmask.png', returnpath=False)
 
     elif opt.dataset == 'kth':
@@ -30,21 +30,20 @@ def get_training_set(opt):
 
 
 def get_test_set(opt):
-
     assert opt.dataset in ['cityscapes', 'cityscapes_two_path', 'kth', 'ucf101', 'KITTI']
 
     if opt.dataset == 'cityscapes':
         from datasets.cityscapes_dataset_w_mask import Cityscapes
         test_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
                                   datalist=CITYSCAPES_VAL_DATA_LIST,
-                                  size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames,
+                                  size=opt.input_size, split='test', split_num=1, num_frames=opt.num_frames,
                                   mask_suffix='ssmask.png', returnpath=True)
 
     elif opt.dataset == 'cityscapes_two_path':
         from datasets.cityscapes_dataset_w_mask_two_path import Cityscapes
         test_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
                                   datalist=CITYSCAPES_VAL_DATA_LIST,
-                                  size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames,
+                                  size=opt.input_size, split='test', split_num=1, num_frames=opt.num_frames,
                                   mask_suffix='ssmask.png', returnpath=True)
 
     elif opt.dataset == 'cityscapes_pix2pixHD':
@@ -52,7 +51,7 @@ def get_test_set(opt):
         test_Dataset = Cityscapes(datapath=CITYSCAPES_TEST_DATA_PATH,
                                   mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
                                   datalist=CITYSCAPES_VAL_DATA_MASK_LIST,
-                                  size= opt.input_size, split='test', split_num=1,
+                                  size=opt.input_size, split='test', split_num=1,
                                   num_frames=opt.num_frames, mask_suffix='ssmask.png', returnpath=True)
 
     elif opt.dataset == 'kth':
@@ -70,6 +69,7 @@ def get_test_set(opt):
     elif opt.dataset == 'ucf101':
         from datasets.ucf101_dataset import UCF101
         test_Dataset = UCF101(datapath=os.path.join(UCF_101_DATA_PATH, category),
-                              datalist=os.path.join(UCF_101_DATA_PATH, 'list/test%s.txt' % (opt.category.lower())), returnpath=True)
+                              datalist=os.path.join(UCF_101_DATA_PATH, 'list/test%s.txt' % (opt.category.lower())),
+                              returnpath=True)
 
-    return test_Dataset
\ No newline at end of file
+    return test_Dataset
diff --git a/src/opts.py b/src/opts.py
index 677041d..7bf906d 100644
--- a/src/opts.py
+++ b/src/opts.py
@@ -13,6 +13,11 @@ def parse_opts():
         default=3,
         type=int,
         help='input image channel (3 for RGB, 1 for Grayscale)')
+    parser.add_argument(
+        '--alpha_recon_image',
+        default=0.85,
+        type=float,
+        help='weight of reconstruction loss.')
     parser.add_argument(
         '--input_size',
         default=(128, 256),
diff --git a/src/test_refine.py b/src/test_refine.py
index 7403ff0..c1a75be 100644
--- a/src/test_refine.py
+++ b/src/test_refine.py
@@ -95,7 +95,7 @@ def test(self):
 
         if self.load:
             # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
-            model_name = '../pretrained_models/refine_genmask_098000.pth.tar'
+            model_name = '../pretrained_models/cityscapes/refine_genmask_098000.pth.tar'
 
             print ("loading model from {}".format(model_name))
 
diff --git a/src/test_refine_w_mask.py b/src/test_refine_w_mask.py
index 22c52f1..a3cc965 100644
--- a/src/test_refine_w_mask.py
+++ b/src/test_refine_w_mask.py
@@ -92,7 +92,7 @@ def test(self):
 
         if self.load:
             # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
-            model_name = '../pretrained_models/refine_genmask_w_mask_098000.pth.tar'
+            model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_098000.pth.tar'
 
             print ("loading model from {}".format(model_name))
 
diff --git a/src/test_refine_w_mask_two_path.py b/src/test_refine_w_mask_two_path.py
index 75c4cde..e244477 100644
--- a/src/test_refine_w_mask_two_path.py
+++ b/src/test_refine_w_mask_two_path.py
@@ -15,6 +15,7 @@
 args = parse_opts()
 print (args)
 
+
 def make_save_dir(output_image_dir):
     val_cities = ['frankfurt', 'lindau', 'munster']
     for city in val_cities:
@@ -92,7 +93,7 @@ def test(self):
         print(self.jobname)
 
         if self.load:
-            model_name = '../pretrained_models/refine_genmask_w_mask_two_path_096000.pth.tar'
+            model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar'
             # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
 
             print ("loading model from {}".format(model_name))
diff --git a/src/test_refine_w_mask_two_path_iterative.py b/src/test_refine_w_mask_two_path_iterative.py
index 1d2556b..1af37b7 100644
--- a/src/test_refine_w_mask_two_path_iterative.py
+++ b/src/test_refine_w_mask_two_path_iterative.py
@@ -94,7 +94,8 @@ def test(self):
         print(self.jobname)
 
         if self.load:
-            model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
+            # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
+            model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar'
 
             print ("loading model from {}".format(model_name))
 
diff --git a/src/train_refine_multigpu.py b/src/train_refine_multigpu.py
index 19eb25e..565cd49 100644
--- a/src/train_refine_multigpu.py
+++ b/src/train_refine_multigpu.py
@@ -1,11 +1,12 @@
 import torch
 from torch.autograd import Variable as Vb
 import torch.optim as optim
+from torch.utils.data import DataLoader
 import os, time, sys
 
 from models.multiframe_genmask import *
 from utils import utils
-from uitls import ops
+from utils import ops
 import losses
 from dataset import get_training_set, get_test_set
 from opts import parse_opts
@@ -14,7 +15,6 @@
 print (opt)
 
 
-
 class flowgen(object):
 
     def __init__(self, opt):
@@ -58,12 +58,11 @@ def train(self):
         vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()
 
         if torch.cuda.device_count() > 1:
-            vae = nn.DataParallel(vae, opt.sync).cuda()
+            vae = nn.DataParallel(vae).cuda()
 
         objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww)
 
         print(self.jobname)
 
-        cudnn.benchmark = True
 
         optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate)
@@ -113,7 +112,7 @@
                                      mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature,
                                      y_pred_before_refine=y_pred_before_refine)
 
-            loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss) / world_size
+            loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss)
 
             # backward
             loss.backward()
@@ -123,7 +122,7 @@
             end = time.time()
 
             # print statistics
-            if iteration % 20 == 0 and rank == 0:
+            if iteration % 20 == 0:
                 print(
                     "iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, "
                     "recon_loss_before = {:.3f}, flow_loss = {:.6f}, flow_consist = {:.3f}, kl_loss = {:.6f}, "
@@ -140,7 +139,7 @@
                 # Set to evaluation mode (randomly sample z from the whole distribution)
                 with torch.no_grad():
                     vae.eval()
-                    val_sample = iter(self.testloader).next()
+                    val_sample, _, _ = iter(self.testloader).next()
 
                     # Read data
                     data = val_sample.cuda()
diff --git a/src/train_refine_multigpu_w_mask.py b/src/train_refine_multigpu_w_mask.py
index 4a7e293..f38b17f 100644
--- a/src/train_refine_multigpu_w_mask.py
+++ b/src/train_refine_multigpu_w_mask.py
@@ -1,11 +1,12 @@
 import torch
 from torch.autograd import Variable as Vb
 import torch.optim as optim
+from torch.utils.data import DataLoader
 import os, time, sys
 
 from models.multiframe_w_mask_genmask import *
 from utils import utils
-from uitls import ops
+from utils import ops
 import losses
 from dataset import get_training_set, get_test_set
 from opts import parse_opts
@@ -58,12 +59,11 @@ def train(self):
         vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()
 
         if torch.cuda.device_count() > 1:
-            vae = nn.DataParallel(vae, opt.sync).cuda()
+            vae = nn.DataParallel(vae).cuda()
 
         objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww)
 
         print(self.jobname)
 
-        cudnn.benchmark = True
 
         optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate)
@@ -114,7 +114,7 @@
                                      mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature,
                                      y_pred_before_refine=y_pred_before_refine)
 
-            loss = (flowloss + 2.*reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1*mask_loss)/ world_size
+            loss = (flowloss + 2.*reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1*mask_loss)
 
             # backward
             loss.backward()
@@ -125,7 +125,7 @@
             end = time.time()
 
             # print statistics
-            if iteration % 20 == 0 and rank == 0:
+            if iteration % 20 == 0:
                 print("iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, recon_loss_before = {:.3f}, "
                       "flow_loss = {:.6f}, flow_consist = {:.3f}, "
                       "kl_loss = {:.6f}, img_sim_loss= {:.3f}, vgg_loss= {:.3f}, mask_loss={:.3f}, time/batch = {:.3f}"
@@ -137,7 +137,7 @@
 
                 with torch.no_grad():
                     vae.eval()
-                    val_sample, val_mask = iter(self.testloader).next()
+                    val_sample, val_mask, _ = iter(self.testloader).next()
 
                     # Read data
                     data = val_sample.cuda()
diff --git a/src/train_refine_multigpu_w_mask_two_path.py b/src/train_refine_multigpu_w_mask_two_path.py
index 1837f96..41b8344 100644
--- a/src/train_refine_multigpu_w_mask_two_path.py
+++ b/src/train_refine_multigpu_w_mask_two_path.py
@@ -1,11 +1,12 @@
 import torch
 from torch.autograd import Variable as Vb
 import torch.optim as optim
+from torch.utils.data import DataLoader
 import os, time, sys
 
 from models.multiframe_w_mask_genmask_two_path import *
 from utils import utils
-from uitls import ops
+from utils import ops
 import losses
 from dataset import get_training_set, get_test_set
 from opts import parse_opts
@@ -58,12 +59,11 @@ def train(self):
         vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()
 
         if torch.cuda.device_count() > 1:
-            vae = nn.DataParallel(vae, opt.sync).cuda()
+            vae = nn.DataParallel(vae).cuda()
 
         objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww)
 
         print(self.jobname)
 
-        cudnn.benchmark = True
 
         optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate)
@@ -118,7 +118,7 @@
                                      mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature,
                                      y_pred_before_refine=y_pred_before_refine)
 
-            loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss) / world_size
+            loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss)
 
             # backward
             loss.backward()
@@ -128,7 +128,7 @@
             end = time.time()
 
             # print statistics
-            if iteration % 20 == 0 and rank == 0:
+            if iteration % 20 == 0:
                 print(
                     "iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, "
                     "recon_loss_before = {:.3f}, flow_loss = {:.6f}, flow_consist = {:.3f}, kl_loss = {:.6f}, "
@@ -145,7 +145,7 @@
                 # Set to evaluation mode (randomly sample z from the whole distribution)
                 with torch.no_grad():
                     vae.eval()
-                    val_sample, val_bg_mask, val_fg_mask = iter(self.testloader).next()
+                    val_sample, val_bg_mask, val_fg_mask, _ = iter(self.testloader).next()
 
                     # Read data
                     data = val_sample.cuda()
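Note (not part of the patch): the loader-unpacking changes above imply that the test split now yields an extra path element per sample when the datasets are built with returnpath=True. The sketch below shows one way a caller might consume the updated two-path test set. The (frames, bg_mask, fg_mask, path) ordering is inferred from the unpacking in train_refine_multigpu_w_mask_two_path.py, the DataLoader settings are arbitrary, and next(iter(...)) is used in place of the older .next() call kept in the scripts; none of this has been checked against datasets/cityscapes_dataset_w_mask_two_path.py.

# Illustrative sketch only, under the assumptions stated above.
from torch.utils.data import DataLoader

from dataset import get_test_set
from opts import parse_opts

opt = parse_opts()
opt.dataset = 'cityscapes_two_path'   # exercise the two-path branch of get_test_set

# Arbitrary loader settings; the training scripts build their own self.testloader.
testloader = DataLoader(get_test_set(opt), batch_size=1, shuffle=False, num_workers=2)

# Assumed per-sample layout when returnpath=True: frames, background mask,
# foreground mask, and the source path (collated into a list by the DataLoader).
frames, bg_mask, fg_mask, paths = next(iter(testloader))
print(frames.shape, bg_mask.shape, fg_mask.shape, paths[0])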