Skip to content

Commit

Permalink
code and samples
Browse files Browse the repository at this point in the history
  • Loading branch information
apple2373 committed Apr 25, 2019
1 parent eede1a5 commit 2f2b3a3
Show file tree
Hide file tree
Showing 33 changed files with 4,341 additions and 0 deletions.
84 changes: 84 additions & 0 deletions dataloaders/.ipynb_checkpoints/ImageListDataset-checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from PIL import Image
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')

class ImageListDataset(Dataset):
def __init__(self, path_label_list, img_root=None,
transform=None,
target_transform=None,label_exist = True,
loader=pil_loader):
self.img_root = img_root
self.data = path_label_list
self.label_exist = label_exist
if self.label_exist==False:
self.data = [ [item] for item in path_label_list]

self.transform = transform
self.target_transform = target_transform
self.loader = loader

def __getitem__(self, i):
'''
if label exists, get (img,label_idx) pair of i-th data point
if label does not exit, just return image tensor of i-th data point
img is already preprocessed
label_idx start from 0 incrementally so can be used for cnn input directly
'''
if self.label_exist:
return self.get_img(i), self.get_label_idx(i)
else:
return self.get_img(i)

def get_img_path(self,i):
'''
get img_path of i-th data point
'''
img_path = self.data[i][0]
if self.img_root is not None:
img_path = os.path.join(self.img_root, img_path)
return img_path

def get_img(self,i):
'''
get img array of i-th data point
self.transform is applied if exists
'''
img = self.loader(self.get_img_path(i))
if self.transform is not None:
img = self.transform(img)
return img

def get_label(self,i):
'''
get label of i-th data point as it is.
'''
assert self.label_exist
return self.data[i][1]

def get_label_idx(self,i):
'''
get label idx, which start from 0 incrementally
self.target_transform is applied if exists
'''
label = self.get_label(i)
if self.target_transform is not None:
if isinstance(self.target_transform, dict):
label_idx = self.target_transform[label]
else:
label_idx = self.target_transform(label)
else:
label_idx = int(label)
return label_idx

def __len__(self):
return len(self.data)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import glob
from .ImageListDataset import ImageListDataset
from torchvision import transforms
from torch.utils.data import DataLoader

def setup_dataloader(name,h=128,w=128,batch_size=4,num_workers=4):
'''
instead of setting up dataloader that read raw image from file,
let's use store all images on cpu memmory
because this is for small dataset
'''
if name == "face":
img_path_list = glob.glob("./data/face/*.png")
elif name=="anime":
img_path_list = glob.glob("./data/anime/*.png")
else:
raise NotImplementedError("Unknown dataset %s"%name)

assert len(img_path_list) > 0

transform = transforms.Compose([
transforms.Resize( min(h,w) ),
transforms.CenterCrop( (h,w) ),
transforms.ToTensor(),
])

img_path_list = [[path,i] for i,path in enumerate(sorted(img_path_list))]
dataset = ImageListDataset(img_path_list,transform=transform)

return DataLoader([data for data in dataset],batch_size=batch_size,
shuffle=True,num_workers=num_workers)
84 changes: 84 additions & 0 deletions dataloaders/ImageListDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from PIL import Image
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')

class ImageListDataset(Dataset):
def __init__(self, path_label_list, img_root=None,
transform=None,
target_transform=None,label_exist = True,
loader=pil_loader):
self.img_root = img_root
self.data = path_label_list
self.label_exist = label_exist
if self.label_exist==False:
self.data = [ [item] for item in path_label_list]

self.transform = transform
self.target_transform = target_transform
self.loader = loader

def __getitem__(self, i):
'''
if label exists, get (img,label_idx) pair of i-th data point
if label does not exit, just return image tensor of i-th data point
img is already preprocessed
label_idx start from 0 incrementally so can be used for cnn input directly
'''
if self.label_exist:
return self.get_img(i), self.get_label_idx(i)
else:
return self.get_img(i)

def get_img_path(self,i):
'''
get img_path of i-th data point
'''
img_path = self.data[i][0]
if self.img_root is not None:
img_path = os.path.join(self.img_root, img_path)
return img_path

def get_img(self,i):
'''
get img array of i-th data point
self.transform is applied if exists
'''
img = self.loader(self.get_img_path(i))
if self.transform is not None:
img = self.transform(img)
return img

def get_label(self,i):
'''
get label of i-th data point as it is.
'''
assert self.label_exist
return self.data[i][1]

def get_label_idx(self,i):
'''
get label idx, which start from 0 incrementally
self.target_transform is applied if exists
'''
label = self.get_label(i)
if self.target_transform is not None:
if isinstance(self.target_transform, dict):
label_idx = self.target_transform[label]
else:
label_idx = self.target_transform(label)
else:
label_idx = int(label)
return label_idx

def __len__(self):
return len(self.data)
31 changes: 31 additions & 0 deletions dataloaders/setup_dataloader_smallgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import glob
from .ImageListDataset import ImageListDataset
from torchvision import transforms
from torch.utils.data import DataLoader

def setup_dataloader(name,h=128,w=128,batch_size=4,num_workers=4):
'''
instead of setting up dataloader that read raw image from file,
let's use store all images on cpu memmory
because this is for small dataset
'''
if name == "face":
img_path_list = glob.glob("./data/face/*.png")
elif name=="anime":
img_path_list = glob.glob("./data/anime/*.png")
else:
raise NotImplementedError("Unknown dataset %s"%name)

assert len(img_path_list) > 0

transform = transforms.Compose([
transforms.Resize( min(h,w) ),
transforms.CenterCrop( (h,w) ),
transforms.ToTensor(),
])

img_path_list = [[path,i] for i,path in enumerate(sorted(img_path_list))]
dataset = ImageListDataset(img_path_list,transform=transform)

return DataLoader([data for data in dataset],batch_size=batch_size,
shuffle=True,num_workers=num_workers)
299 changes: 299 additions & 0 deletions ipython/.ipynb_checkpoints/example-checkpoint.ipynb

Large diffs are not rendered by default.

299 changes: 299 additions & 0 deletions ipython/example.ipynb

Large diffs are not rendered by default.

93 changes: 93 additions & 0 deletions loss/.ipynb_checkpoints/AdaBIGGANLoss-checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from .Vgg16PerceptualLoss import Vgg16PerceptualLoss

class AdaBIGGANLoss(nn.Module):
def __init__(self,perceptual_loss = "vgg",
scale_per=0.001,
scale_emd=0.1,
scale_reg=0.02,
normalize_img = True,
normalize_per = False,
dist_per = "l1",
):
'''
perceptual_loss: preceptural loss
perceptual_facter:
'''
super(AdaBIGGANLoss,self).__init__()
if perceptual_loss == "vgg":
self.perceptual_loss = Vgg16PerceptualLoss(loss_func=dist_per)
else:
self.perceptual_loss = perceptual_loss
self.scale_per = scale_per
self.scale_emd = scale_emd
self.scale_reg = scale_reg
self.normalize_img = normalize_img
self.normalize_perceptural = normalize_per

def earth_mover_dist(self,z):
"""
taken from https://github.com/nogu-atsu/SmallGAN/blob/f604cd17516963d8eec292f3faddd70c227b609a/gen_models/ada_generator.py#L150-L162
earth mover distance between z and standard normal distribution
"""
dim_z = z.shape[1]
n = z.shape[0]#batchsize
t = torch.randn((n * 10,dim_z),device=z.device)
dot = torch.matmul(z, t.permute(-1, -2))

#in the original implementation transb=True
#so we want to do t = t.swapaxes(-1, -2)
#from https://github.com/chainer/chainer/blob/c2cf7fb9c49cf98a94caf453f644d612ace45625/chainer/functions/math/matmul.py#L
#then swapaxes is .permute
#from https://discuss.pytorch.org/t/swap-axes-in-pytorch/970

dist = torch.sum(z ** 2, dim=1, keepdim=True) - 2 * dot + torch.sum(t ** 2, dim=1)

return torch.mean(dist.min(dim=0)[0]) + torch.mean(dist.min(dim=1)[0])

def l1_reg(self,W):
#https://github.com/nogu-atsu/SmallGAN/blob/2293700dce1e2cd97e25148543532814659516bd/gen_models/ada_generator.py#L146-L148
#NOTE: I think this should be implemented as weight decay in the optimizer. It's not beatiful code to pass W into loss function.
return torch.mean( W ** 2 )

def forward(self,x,y,z,W):
#from IPython import embed;embed()
'''
x:generated image. shape is (batch,channel,h,w)
y:target image. shape is (batch,channel,h,w)
z: seed image embeddings (BEFORE adding the noise of eps). shape is (batch,embedding_dim)
W: model.linear.weight
see the equation (3) in the paper
'''

# F.mse_loss is L2 loss
# F.l1_loss is L1 loss

#pytorch regards an image as a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
#(see transforms.ToTensor() for details)
#but the model output uses tanh so x is ranging (-1 to 1)
#so let's rescale y to (-1 to 1) from (0 to 1)
#chainer implementation use (-1,1) for loss computation, so i didn't do the other way around (i.e. scale x to (0,1))
image_loss = F.l1_loss(x, 2.0*(y - 0.5) )
if self.normalize_img:
loss = image_loss/image_loss.item()
else:
loss = image_loss
#rescaled to 1 in the chainer code
#see https://github.com/nogu-atsu/SmallGAN/blob/2293700dce1e2cd97e25148543532814659516bd/gen_models/ada_generator.py#L68-L69

for ploss in self.perceptual_loss(img1=x,img2=y,img1_minmax=(-1,1),img2_minmax=(0,1)):
if self.normalize_perceptural:
loss += self.scale_per*ploss/ploss.item()
else:
loss += self.scale_per*ploss

loss += self.scale_emd*self.earth_mover_dist(z)

loss += self.scale_reg*self.l1_reg(W)

return loss
Loading

0 comments on commit 2f2b3a3

Please sign in to comment.