diff --git a/src/dataset.py b/src/dataset.py
index baeafc4..878c86f 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -1,5 +1,5 @@
 from datasets.dataset_path import *
-
+import os
 
 def get_training_set(opt):
     assert opt.dataset in ['cityscapes', 'cityscapes_two_path', 'kth']
@@ -67,8 +67,8 @@ def get_test_set(opt):
                               returnpath=True)
 
     elif opt.dataset == 'ucf101':
-        from datasets.ucf101_dataset import UCF101
-        test_Dataset = UCF101(datapath=os.path.join(UCF_101_DATA_PATH, category),
+        from datasets.ucf_dataset import UCF101
+        test_Dataset = UCF101(datapath=os.path.join(UCF_101_DATA_PATH, opt.category),
                               datalist=os.path.join(UCF_101_DATA_PATH, 'list/test%s.txt' % (opt.category.lower())),
                               returnpath=True)
 
diff --git a/src/test_refine.py b/src/test_refine.py
index c1a75be..352564a 100644
--- a/src/test_refine.py
+++ b/src/test_refine.py
@@ -13,7 +13,7 @@
 from opts import parse_opts
 
 args = parse_opts()
-print (args)
+print(args)
 
 
 def make_save_dir(output_image_dir):
@@ -50,8 +50,10 @@ def __init__(self, opt):
 
         test_Dataset = get_test_set(opt)
 
-        self.sampledir = os.path.join('../city_scapes_test_results', self.jobname,
-                                      self.suffix + '_' + str(self.iter_to_load)+'_'+str(opt.seed))
+        # self.sampledir = os.path.join('../city_scapes_test_results', self.jobname,
+        #                               self.suffix + '_' + str(self.iter_to_load)+'_'+str(opt.seed))
+
+        self.sampledir = os.path.join('../ucf101_results', opt.category, self.suffix + '_' + str(opt.seed))
 
         if not os.path.exists(self.sampledir):
            os.makedirs(self.sampledir)
@@ -82,12 +84,13 @@ def test(self):
 
        opt = self.opt
        gpu_ids = range(torch.cuda.device_count())
-        print ('Number of GPUs in use {}'.format(gpu_ids))
+        print('Number of GPUs in use {}'.format(gpu_ids))
 
        iteration = 0
 
        if torch.cuda.device_count() > 1:
-            vae = nn.DataParallel(VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine), device_ids=gpu_ids).cuda()
+            vae = nn.DataParallel(VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine),
+                                  device_ids=gpu_ids).cuda()
        else:
            vae = VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine).cuda()
 
@@ -95,9 +98,17 @@
        if self.load:
 
            # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
-            model_name = '../pretrained_models/cityscapes/refine_genmask_098000.pth.tar'
-            print ("loading model from {}".format(model_name))
+            if opt.dataset == 'cityscapes':
+                model_name = '../pretrained_models/cityscapes/refine_genmask_098000.pth.tar'
+
+            elif opt.dataset == 'ucf101':
+                model_name = '../pretrained_models/' + opt.dataset + '/' + opt.category.lower() + '_model.pth.tar'
+
+            else:
+                model_name = '../pretrained_models/' + opt.dataset + '/' + opt.category + '_16frames_model.pth.tar'
+
+            print("loading model from {}".format(model_name))
 
            state_dict = torch.load(model_name)
 
            if torch.cuda.device_count() > 1:
@@ -106,8 +117,7 @@
                vae.load_state_dict(state_dict['vae'])
 
        z_noise = torch.ones(1, 1024).normal_()
 
-        for sample,_, paths in tqdm(iter(self.testloader)):
-
+        for sample, _, paths in tqdm(iter(self.testloader)):
            # Set to evaluation mode (randomly sample z from the whole distribution)
            vae.eval()
@@ -119,13 +129,17 @@
 
            # data = data.repeat(1, opt.num_frames, 1, 1, 1)
            frame1 = data[:, 0, :, :, :]
+
            noise_bg = Vb(torch.randn(frame1.size())).cuda()
 
-            z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1))
-            y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, noise_bg, z_m)
+            z_m = Vb(z_noise.repeat(frame1.size()[0] * 4 * int(frame1.shape[-1] / 128), 1))
 
-            utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, self.sampledir, opt,
-                               eval=True, useMask=True)
+            y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, noise_bg,
+                                                                                             z_m)
+
+            utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, self.sampledir,
+                               opt,
+                               eval=True, useMask=True)
 
            utils.save_images(self.output_image_dir, data, y_pred, paths, opt)
            utils.save_images(self.output_image_before_dir, data, y_pred_before_refine, paths, opt)
@@ -133,15 +147,15 @@
            data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy()
            utils.save_gif(data * 255, opt.num_frames, [8, 4], self.sampledir + '/{:06d}_real.gif'.format(iteration))
 
-            utils.save_flows(self.output_fw_flow_dir, flow, paths)
-            utils.save_flows(self.output_bw_flow_dir, flowback, paths)
-
-            utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths)
-            utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths)
+            # utils.save_flows(self.output_fw_flow_dir, flow, paths)
+            # utils.save_flows(self.output_bw_flow_dir, flowback, paths)
+            #
+            # utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths)
+            # utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths)
 
            iteration += 1
 
 
 if __name__ == '__main__':
    a = flowgen(opt=args)
-    a.test()
\ No newline at end of file
+    a.test()