Commit

bug fix.
Pan Junting committed Jul 2, 2019
1 parent 20d56ca commit bae9b10
Showing 9 changed files with 44 additions and 38 deletions.
30 changes: 15 additions & 15 deletions src/dataset.py
@@ -2,22 +2,22 @@


def get_training_set(opt):

assert opt.datset in ['cityscapes', 'cityscapes_two_path', 'kth']
assert opt.dataset in ['cityscapes', 'cityscapes_two_path', 'kth']

if opt.dataset == 'cityscapes':
from datasets.cityscapes_dataset_w_mask import Cityscapes

train_Dataset = Cityscapes(datapath=CITYSCAPES_TRAIN_DATA_PATH, datalist=CITYSCAPES_TRAIN_DATA_LIST,
size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames)
train_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
datalist=CITYSCAPES_VAL_DATA_LIST,
size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames,
mask_suffix='ssmask.png', returnpath=False)

elif opt.dataset == 'cityscapes_two_path':
from datasets.cityscapes_dataset_w_mask_two_path import Cityscpes
train_Dataset = Cityscapes(datapath=CITYSCAPES_TRAIN_DATA_PATH,
mask_data_path=CITYSCAPES_TRAIN_DATA_SEGMASK_PATH,
datalist=CITYSCAPES_TRAIN_DATA_LIST,
from datasets.cityscapes_dataset_w_mask_two_path import Cityscapes
train_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
datalist=CITYSCAPES_VAL_DATA_LIST,
size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames,
mask_suffix='ssmask.png')
mask_suffix='ssmask.png', returnpath=False)

elif opt.dataset == 'kth':

@@ -30,29 +30,28 @@ def get_training_set(opt):


def get_test_set(opt):

assert opt.dataset in ['cityscapes', 'cityscapes_two_path', 'kth', 'ucf101', 'KITTI']

if opt.dataset == 'cityscapes':
from datasets.cityscapes_dataset_w_mask import Cityscapes
test_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
datalist=CITYSCAPES_VAL_DATA_LIST,
size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames,
size=opt.input_size, split='test', split_num=1, num_frames=opt.num_frames,
mask_suffix='ssmask.png', returnpath=True)

elif opt.dataset == 'cityscapes_two_path':
from datasets.cityscapes_dataset_w_mask_two_path import Cityscapes
test_Dataset = Cityscapes(datapath=CITYSCAPES_VAL_DATA_PATH, mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
datalist=CITYSCAPES_VAL_DATA_LIST,
size=opt.input_size, split='train', split_num=1, num_frames=opt.num_frames,
size=opt.input_size, split='test', split_num=1, num_frames=opt.num_frames,
mask_suffix='ssmask.png', returnpath=True)

elif opt.dataset == 'cityscapes_pix2pixHD':
from cityscapes_dataloader_w_mask_pix2pixHD import Cityscapes
test_Dataset = Cityscapes(datapath=CITYSCAPES_TEST_DATA_PATH,
mask_data_path=CITYSCAPES_VAL_DATA_SEGMASK_PATH,
datalist=CITYSCAPES_VAL_DATA_MASK_LIST,
size= opt.input_size, split='test', split_num=1,
size=opt.input_size, split='test', split_num=1,
num_frames=opt.num_frames, mask_suffix='ssmask.png', returnpath=True)

elif opt.dataset == 'kth':
@@ -70,6 +69,7 @@ def get_test_set(opt):
elif opt.dataset == 'ucf101':
from datasets.ucf101_dataset import UCF101
test_Dataset = UCF101(datapath=os.path.join(UCF_101_DATA_PATH, category),
datalist=os.path.join(UCF_101_DATA_PATH, 'list/test%s.txt' % (opt.category.lower())), returnpath=True)
datalist=os.path.join(UCF_101_DATA_PATH, 'list/test%s.txt' % (opt.category.lower())),
returnpath=True)

return test_Dataset
return test_Dataset
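
Both selectors return standard PyTorch Dataset objects, so callers typically wrap them in a DataLoader. The sketch below is illustrative only: the opt fields beyond those referenced in this file (dataset, input_size, num_frames, category) and their values are assumptions, not taken from this repository.

from argparse import Namespace
from torch.utils.data import DataLoader

from dataset import get_training_set, get_test_set

# Illustrative options; real values come from opts.parse_opts().
opt = Namespace(dataset='cityscapes', input_size=(128, 256), num_frames=5,
                category=None, batch_size=4)

train_set = get_training_set(opt)
test_set = get_test_set(opt)

trainloader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True, num_workers=4)
testloader = DataLoader(test_set, batch_size=1, shuffle=False)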
5 changes: 5 additions & 0 deletions src/opts.py
@@ -13,6 +13,11 @@ def parse_opts():
default=3,
type=int,
help='input image channel (3 for RGB, 1 for Grayscale)')
parser.add_argument(
'--alpha_recon_image',
default=0.85,
type=float,
help='weight of reconstruction loss.')
parser.add_argument(
'--input_size',
default=(128, 256),
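The new --alpha_recon_image option adds a 0.85 weight for the reconstruction loss. How losses.py consumes it is not visible in this diff; a common convention for a 0.85 weight is to blend a structural (SSIM-style) term with an L1 term, so the sketch below should be read as an assumption rather than this repository's definition.

import torch
import torch.nn.functional as F

def ssim_dissimilarity(x, y):
    # Single-scale SSIM over 3x3 neighbourhoods, returned as a dissimilarity
    # in [0, 1] (0 means identical local statistics).
    C1, C2 = 0.01 ** 2, 0.03 ** 2
    mu_x, mu_y = F.avg_pool2d(x, 3, 1, 1), F.avg_pool2d(y, 3, 1, 1)
    sigma_x = F.avg_pool2d(x * x, 3, 1, 1) - mu_x ** 2
    sigma_y = F.avg_pool2d(y * y, 3, 1, 1) - mu_y ** 2
    sigma_xy = F.avg_pool2d(x * y, 3, 1, 1) - mu_x * mu_y
    ssim_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
    ssim_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
    return torch.clamp((1 - ssim_n / ssim_d) / 2, 0, 1)

def recon_loss(pred, target, alpha_recon_image=0.85):
    # Hypothetical use of the option: alpha weights the structural term
    # against plain L1.
    return (alpha_recon_image * ssim_dissimilarity(pred, target).mean()
            + (1 - alpha_recon_image) * torch.abs(pred - target).mean())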
2 changes: 1 addition & 1 deletion src/test_refine.py
@@ -95,7 +95,7 @@ def test(self):

if self.load:
# model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
model_name = '../pretrained_models/refine_genmask_098000.pth.tar'
model_name = '../pretrained_models/cityscapes/refine_genmask_098000.pth.tar'

print ("loading model from {}".format(model_name))

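The test scripts in this commit all point model_name at a checkpoint under ../pretrained_models/cityscapes/. A rough sketch of how such a .pth.tar file is usually restored follows; the checkpoint's internal key names are a guess, and net is only a stand-in for the VAE built in models/.

import torch
import torch.nn as nn

net = nn.Linear(8, 8)  # placeholder; the real model is not part of this diff

ckpt_path = '../pretrained_models/cityscapes/refine_genmask_098000.pth.tar'
checkpoint = torch.load(ckpt_path, map_location='cpu')

# Fall back to treating the whole file as a state dict if the assumed
# 'state_dict' key is absent.
state = checkpoint.get('state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
net.load_state_dict(state, strict=False)
net.eval()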
2 changes: 1 addition & 1 deletion src/test_refine_w_mask.py
@@ -92,7 +92,7 @@ def test(self):

if self.load:
# model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
model_name = '../pretrained_models/refine_genmask_w_mask_098000.pth.tar'
model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_098000.pth.tar'

print ("loading model from {}".format(model_name))

3 changes: 2 additions & 1 deletion src/test_refine_w_mask_two_path.py
@@ -15,6 +15,7 @@
args = parse_opts()
print (args)


def make_save_dir(output_image_dir):
val_cities = ['frankfurt', 'lindau', 'munster']
for city in val_cities:
@@ -92,7 +93,7 @@ def test(self):
print(self.jobname)

if self.load:
model_name = '../pretrained_models/refine_genmask_w_mask_two_path_096000.pth.tar'
model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar'
# model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)

print ("loading model from {}".format(model_name))
3 changes: 2 additions & 1 deletion src/test_refine_w_mask_two_path_iterative.py
@@ -94,7 +94,8 @@ def test(self):
print(self.jobname)

if self.load:
model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
# model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar'

print ("loading model from {}".format(model_name))

13 changes: 6 additions & 7 deletions src/train_refine_multigpu.py
@@ -1,11 +1,12 @@
import torch
from torch.autograd import Variable as Vb
import torch.optim as optim
from torch.utils.data import DataLoader
import os, time, sys

from models.multiframe_genmask import *
from utils import utils
from uitls import ops
from utils import ops
import losses
from dataset import get_training_set, get_test_set
from opts import parse_opts
@@ -14,7 +15,6 @@
print (opt)



class flowgen(object):

def __init__(self, opt):
@@ -58,12 +58,11 @@ def train(self):

vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()
if torch.cuda.device_count() > 1:
vae = nn.DataParallel(vae, opt.sync).cuda()
vae = nn.DataParallel(vae).cuda()

objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww)

print(self.jobname)
cudnn.benchmark = True

optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate)

@@ -113,7 +112,7 @@ def train(self):
mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature,
y_pred_before_refine=y_pred_before_refine)

loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss) / world_size
loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss)

# backward
loss.backward()
@@ -123,7 +122,7 @@ def train(self):
end = time.time()

# print statistics
if iteration % 20 == 0 and rank == 0:
if iteration % 20 == 0:
print(
"iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, "
"recon_loss_before = {:.3f}, flow_loss = {:.6f}, flow_consist = {:.3f}, kl_loss = {:.6f}, "
@@ -140,7 +139,7 @@ def train(self):
# Set to evaluation mode (randomly sample z from the whole distribution)
with torch.no_grad():
vae.eval()
val_sample = iter(self.testloader).next()
val_sample, _, _ = iter(self.testloader).next()

# Read data
data = val_sample.cuda()
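The training scripts now pass only the module to nn.DataParallel. In stock PyTorch the second positional argument is device_ids, so omitting it falls back to replicating the module across all visible GPUs; a minimal sketch of that default pattern, with a placeholder module standing in for the VAE:

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = nn.Linear(16, 16).to(device)  # placeholder for the real model

if torch.cuda.device_count() > 1:
    # With no device_ids, DataParallel replicates the module on every visible
    # GPU, scatters each input batch along dim 0 and gathers outputs on GPU 0.
    net = nn.DataParallel(net)

out = net(torch.randn(4, 16, device=device))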
12 changes: 6 additions & 6 deletions src/train_refine_multigpu_w_mask.py
@@ -1,11 +1,12 @@
import torch
from torch.autograd import Variable as Vb
import torch.optim as optim
from torch.utils.data import DataLoader
import os, time, sys

from models.multiframe_w_mask_genmask import *
from utils import utils
from uitls import ops
from utils import ops
import losses
from dataset import get_training_set, get_test_set
from opts import parse_opts
@@ -58,12 +59,11 @@ def train(self):

vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()
if torch.cuda.device_count() > 1:
vae = nn.DataParallel(vae, opt.sync).cuda()
vae = nn.DataParallel(vae).cuda()

objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww)

print(self.jobname)
cudnn.benchmark = True

optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate)

@@ -114,7 +114,7 @@ def train(self):
mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature,
y_pred_before_refine=y_pred_before_refine)

loss = (flowloss + 2.*reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1*mask_loss)/ world_size
loss = (flowloss + 2.*reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1*mask_loss)

# backward
loss.backward()
@@ -125,7 +125,7 @@ def train(self):
end = time.time()

# print statistics
if iteration % 20 == 0 and rank == 0:
if iteration % 20 == 0:
print("iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, recon_loss_before = {:.3f}, "
"flow_loss = {:.6f}, flow_consist = {:.3f}, "
"kl_loss = {:.6f}, img_sim_loss= {:.3f}, vgg_loss= {:.3f}, mask_loss={:.3f}, time/batch = {:.3f}"
@@ -137,7 +137,7 @@ def train(self):
with torch.no_grad():
vae.eval()

val_sample, val_mask = iter(self.testloader).next()
val_sample, val_mask, _ = iter(self.testloader).next()

# Read data
data = val_sample.cuda()
12 changes: 6 additions & 6 deletions src/train_refine_multigpu_w_mask_two_path.py
@@ -1,11 +1,12 @@
import torch
from torch.autograd import Variable as Vb
import torch.optim as optim
from torch.utils.data import DataLoader
import os, time, sys

from models.multiframe_w_mask_genmask_two_path import *
from utils import utils
from uitls import ops
from utils import ops
import losses
from dataset import get_training_set, get_test_set
from opts import parse_opts
@@ -58,12 +59,11 @@ def train(self):

vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()
if torch.cuda.device_count() > 1:
vae = nn.DataParallel(vae, opt.sync).cuda()
vae = nn.DataParallel(vae).cuda()

objective_func = losses.losses_multigpu_only_mask(opt, vae.module.floww)

print(self.jobname)
cudnn.benchmark = True

optimizer = optim.Adam(vae.parameters(), lr=opt.lr_rate)

@@ -118,7 +118,7 @@ def train(self):
mask_fw, mask_bw, prediction_vgg_feature, gt_vgg_feature,
y_pred_before_refine=y_pred_before_refine)

loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss) / world_size
loss = (flowloss + 2. * reconloss + reconloss_back + reconloss_before + kldloss * self.opt.lamda + flowcon + sim_loss + vgg_loss + 0.1 * mask_loss)

# backward
loss.backward()
@@ -128,7 +128,7 @@ def train(self):
end = time.time()

# print statistics
if iteration % 20 == 0 and rank == 0:
if iteration % 20 == 0:
print(
"iter {} (epoch {}), recon_loss = {:.6f}, recon_loss_back = {:.3f}, "
"recon_loss_before = {:.3f}, flow_loss = {:.6f}, flow_consist = {:.3f}, kl_loss = {:.6f}, "
@@ -145,7 +145,7 @@ def train(self):
# Set to evaluation mode (randomly sample z from the whole distribution)
with torch.no_grad():
vae.eval()
val_sample, val_bg_mask, val_fg_mask = iter(self.testloader).next()
val_sample, val_bg_mask, val_fg_mask, _ = iter(self.testloader).next()

# Read data
data = val_sample.cuda()
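Each training script's validation branch now unpacks the extra values returned by the updated datasets, e.g. val_sample, val_bg_mask, val_fg_mask, _ above. A minimal sketch of that evaluation pattern with a dummy loader follows; note that iter(...).next() is the Python 2 spelling, written next(iter(...)) in Python 3.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

vae = nn.Linear(8, 8)  # placeholder for the real VAE
# Dummy loader yielding (sample, bg_mask, fg_mask, extra) tuples to mimic the
# four-value batches unpacked above.
testloader = DataLoader(
    TensorDataset(torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8), torch.arange(4)),
    batch_size=2)

with torch.no_grad():      # no autograd bookkeeping during validation
    vae.eval()             # disable dropout / use running BatchNorm statistics
    val_sample, val_bg_mask, val_fg_mask, _ = next(iter(testloader))
    prediction = vae(val_sample)
vae.train()                # restore training mode afterwards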
