Skip to content

Commit

Permalink
fixed noise error.
Browse files Browse the repository at this point in the history
  • Loading branch information
Pan Junting committed Jul 5, 2019
1 parent bae9b10 commit 34877bb
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
6 changes: 3 additions & 3 deletions src/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datasets.dataset_path import *

import os

def get_training_set(opt):
assert opt.dataset in ['cityscapes', 'cityscapes_two_path', 'kth']
Expand Down Expand Up @@ -67,8 +67,8 @@ def get_test_set(opt):
returnpath=True)

elif opt.dataset == 'ucf101':
from datasets.ucf101_dataset import UCF101
test_Dataset = UCF101(datapath=os.path.join(UCF_101_DATA_PATH, category),
from datasets.ucf_dataset import UCF101
test_Dataset = UCF101(datapath=os.path.join(UCF_101_DATA_PATH, opt.category),
datalist=os.path.join(UCF_101_DATA_PATH, 'list/test%s.txt' % (opt.category.lower())),
returnpath=True)

Expand Down
52 changes: 33 additions & 19 deletions src/test_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from opts import parse_opts

args = parse_opts()
print (args)
print(args)


def make_save_dir(output_image_dir):
Expand Down Expand Up @@ -50,8 +50,10 @@ def __init__(self, opt):

test_Dataset = get_test_set(opt)

self.sampledir = os.path.join('../city_scapes_test_results', self.jobname,
self.suffix + '_' + str(self.iter_to_load)+'_'+str(opt.seed))
# self.sampledir = os.path.join('../city_scapes_test_results', self.jobname,
# self.suffix + '_' + str(self.iter_to_load)+'_'+str(opt.seed))

self.sampledir = os.path.join('../ucf101_results', opt.category, self.suffix + '_' + str(opt.seed))

if not os.path.exists(self.sampledir):
os.makedirs(self.sampledir)
Expand Down Expand Up @@ -82,22 +84,31 @@ def test(self):
opt = self.opt

gpu_ids = range(torch.cuda.device_count())
print ('Number of GPUs in use {}'.format(gpu_ids))
print('Number of GPUs in use {}'.format(gpu_ids))

iteration = 0

if torch.cuda.device_count() > 1:
vae = nn.DataParallel(VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine), device_ids=gpu_ids).cuda()
vae = nn.DataParallel(VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine),
device_ids=gpu_ids).cuda()
else:
vae = VAE(hallucination=self.useHallucination, opt=opt, refine=self.refine).cuda()

print(self.jobname)

if self.load:
# model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
model_name = '../pretrained_models/cityscapes/refine_genmask_098000.pth.tar'

print ("loading model from {}".format(model_name))
if opt.dataset == 'cityscapes':
model_name = '../pretrained_models/cityscapes/refine_genmask_098000.pth.tar'

elif opt.dataset == 'ucf101':
model_name = '../pretrained_models/' + opt.dataset + '/' + opt.category.lower() + '_model.pth.tar'

else:
model_name = '../pretrained_models/' + opt.dataset + '/' + opt.category + '_16frames_model.pth.tar'

print("loading model from {}".format(model_name))

state_dict = torch.load(model_name)
if torch.cuda.device_count() > 1:
Expand All @@ -106,8 +117,7 @@ def test(self):
vae.load_state_dict(state_dict['vae'])

z_noise = torch.ones(1, 1024).normal_()
for sample,_, paths in tqdm(iter(self.testloader)):

for sample, _, paths in tqdm(iter(self.testloader)):
# Set to evaluation mode (randomly sample z from the whole distribution)
vae.eval()

Expand All @@ -119,29 +129,33 @@ def test(self):
# data = data.repeat(1, opt.num_frames, 1, 1, 1)

frame1 = data[:, 0, :, :, :]

noise_bg = Vb(torch.randn(frame1.size())).cuda()
z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1))

y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, noise_bg, z_m)
z_m = Vb(z_noise.repeat(frame1.size()[0] * 4 * int(frame1.shape[-1] / 128), 1))

utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, self.sampledir, opt,
eval=True, useMask=True)
y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(frame1, data, noise_bg,
z_m)

utils.save_samples(data, y_pred_before_refine, y_pred, flow, mask_fw, mask_bw, iteration, self.sampledir,
opt,
eval=True, useMask=True)

utils.save_images(self.output_image_dir, data, y_pred, paths, opt)
utils.save_images(self.output_image_before_dir, data, y_pred_before_refine, paths, opt)

data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy()
utils.save_gif(data * 255, opt.num_frames, [8, 4], self.sampledir + '/{:06d}_real.gif'.format(iteration))

utils.save_flows(self.output_fw_flow_dir, flow, paths)
utils.save_flows(self.output_bw_flow_dir, flowback, paths)

utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths)
utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths)
# utils.save_flows(self.output_fw_flow_dir, flow, paths)
# utils.save_flows(self.output_bw_flow_dir, flowback, paths)
#
# utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths)
# utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths)

iteration += 1


if __name__ == '__main__':
a = flowgen(opt=args)
a.test()
a.test()

0 comments on commit 34877bb

Please sign in to comment.