Commit
Showing 33 changed files with 4,341 additions and 0 deletions.
84 changes: 84 additions & 0 deletions
dataloaders/.ipynb_checkpoints/ImageListDataset-checkpoint.py
@@ -0,0 +1,84 @@
import numpy as np
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from PIL import Image


def pil_loader(path):
    # Open the path as a file object to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class ImageListDataset(Dataset):
    def __init__(self, path_label_list, img_root=None,
                 transform=None,
                 target_transform=None, label_exist=True,
                 loader=pil_loader):
        self.img_root = img_root
        self.data = path_label_list
        self.label_exist = label_exist
        if not self.label_exist:
            self.data = [[item] for item in path_label_list]

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, i):
        '''
        If a label exists, return the (img, label_idx) pair of the i-th data point;
        if no label exists, return only the image tensor of the i-th data point.
        img is already preprocessed.
        label_idx starts from 0 and increments, so it can be fed to a CNN directly.
        '''
        if self.label_exist:
            return self.get_img(i), self.get_label_idx(i)
        else:
            return self.get_img(i)

    def get_img_path(self, i):
        '''
        Get the image path of the i-th data point.
        '''
        img_path = self.data[i][0]
        if self.img_root is not None:
            img_path = os.path.join(self.img_root, img_path)
        return img_path

    def get_img(self, i):
        '''
        Get the image array of the i-th data point.
        self.transform is applied if it exists.
        '''
        img = self.loader(self.get_img_path(i))
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_label(self, i):
        '''
        Get the label of the i-th data point as-is.
        '''
        assert self.label_exist
        return self.data[i][1]

    def get_label_idx(self, i):
        '''
        Get the label index, which starts from 0 and increments.
        self.target_transform is applied if it exists.
        '''
        label = self.get_label(i)
        if self.target_transform is not None:
            if isinstance(self.target_transform, dict):
                label_idx = self.target_transform[label]
            else:
                label_idx = self.target_transform(label)
        else:
            label_idx = int(label)
        return label_idx

    def __len__(self):
        return len(self.data)
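For reference, a minimal usage sketch of ImageListDataset on its own (not part of this commit). It assumes the module is importable as dataloaders.ImageListDataset, and the file names, labels, and img_root below are made up for illustration; the dict form of target_transform maps raw labels to integer indices.

from torchvision import transforms

from dataloaders.ImageListDataset import ImageListDataset

# Hypothetical (path, label) pairs; any image files on disk would work.
path_label_list = [["cat_001.png", "cat"], ["dog_001.png", "dog"]]

transform = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop((128, 128)),
    transforms.ToTensor(),
])

# A dict target_transform maps the raw string labels to 0-based indices.
dataset = ImageListDataset(path_label_list, img_root="./data",
                           transform=transform,
                           target_transform={"cat": 0, "dog": 1})

img, label_idx = dataset[0]  # img: FloatTensor of shape (3, 128, 128), label_idx: 0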
31 changes: 31 additions & 0 deletions
dataloaders/.ipynb_checkpoints/setup_dataloader_smallgan-checkpoint.py
@@ -0,0 +1,31 @@
import glob

from .ImageListDataset import ImageListDataset
from torchvision import transforms
from torch.utils.data import DataLoader


def setup_dataloader(name, h=128, w=128, batch_size=4, num_workers=4):
    '''
    Instead of setting up a dataloader that reads raw images from file on every
    access, keep all images in CPU memory, since the datasets here are small.
    '''
    if name == "face":
        img_path_list = glob.glob("./data/face/*.png")
    elif name == "anime":
        img_path_list = glob.glob("./data/anime/*.png")
    else:
        raise NotImplementedError("Unknown dataset %s" % name)

    assert len(img_path_list) > 0

    transform = transforms.Compose([
        transforms.Resize(min(h, w)),
        transforms.CenterCrop((h, w)),
        transforms.ToTensor(),
    ])

    # Pair each (sorted) image path with its index, which serves as the label.
    img_path_list = [[path, i] for i, path in enumerate(sorted(img_path_list))]
    dataset = ImageListDataset(img_path_list, transform=transform)

    # Materialize the whole dataset into a list so every image is preloaded
    # into CPU memory before it is handed to the DataLoader.
    return DataLoader([data for data in dataset], batch_size=batch_size,
                      shuffle=True, num_workers=num_workers)
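And a rough sketch of how setup_dataloader might be called for training, assuming the module is importable as dataloaders.setup_dataloader_smallgan (as the checkpoint path suggests) and that ./data/face/ contains PNG images; the batch size and loop body are illustrative only.

from dataloaders.setup_dataloader_smallgan import setup_dataloader

# Every image is already preloaded into CPU memory, so num_workers can stay small.
loader = setup_dataloader("face", h=128, w=128, batch_size=4, num_workers=0)

for imgs, label_idxs in loader:
    # imgs: FloatTensor of shape (4, 3, 128, 128) in [0, 1]; label_idxs: LongTensor of shape (4,)
    print(imgs.shape, label_idxs)
    break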
@@ -0,0 +1,93 @@
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from .Vgg16PerceptualLoss import Vgg16PerceptualLoss


class AdaBIGGANLoss(nn.Module):
    def __init__(self, perceptual_loss="vgg",
                 scale_per=0.001,
                 scale_emd=0.1,
                 scale_reg=0.02,
                 normalize_img=True,
                 normalize_per=False,
                 dist_per="l1",
                 ):
        '''
        perceptual_loss: perceptual loss module ("vgg" selects the default Vgg16PerceptualLoss)
        scale_per / scale_emd / scale_reg: weights for the perceptual, earth mover, and regularization terms
        '''
        super(AdaBIGGANLoss, self).__init__()
        if perceptual_loss == "vgg":
            self.perceptual_loss = Vgg16PerceptualLoss(loss_func=dist_per)
        else:
            self.perceptual_loss = perceptual_loss
        self.scale_per = scale_per
        self.scale_emd = scale_emd
        self.scale_reg = scale_reg
        self.normalize_img = normalize_img
        self.normalize_perceptural = normalize_per

    def earth_mover_dist(self, z):
        """
        Earth mover distance between z and a standard normal distribution.
        Taken from https://github.com/nogu-atsu/SmallGAN/blob/f604cd17516963d8eec292f3faddd70c227b609a/gen_models/ada_generator.py#L150-L162
        """
        dim_z = z.shape[1]
        n = z.shape[0]  # batch size
        t = torch.randn((n * 10, dim_z), device=z.device)
        # The original Chainer implementation uses transb=True, i.e. t.swapaxes(-1, -2)
        # (https://github.com/chainer/chainer/blob/c2cf7fb9c49cf98a94caf453f644d612ace45625/chainer/functions/math/matmul.py#L),
        # which corresponds to .permute in PyTorch
        # (https://discuss.pytorch.org/t/swap-axes-in-pytorch/970).
        dot = torch.matmul(z, t.permute(-1, -2))

        dist = torch.sum(z ** 2, dim=1, keepdim=True) - 2 * dot + torch.sum(t ** 2, dim=1)

        return torch.mean(dist.min(dim=0)[0]) + torch.mean(dist.min(dim=1)[0])

    def l1_reg(self, W):
        # https://github.com/nogu-atsu/SmallGAN/blob/2293700dce1e2cd97e25148543532814659516bd/gen_models/ada_generator.py#L146-L148
        # NOTE: this could arguably be implemented as weight decay in the optimizer;
        # passing W into the loss function is not the cleanest design.
        return torch.mean(W ** 2)

    def forward(self, x, y, z, W):
        '''
        x: generated image, shape (batch, channel, h, w)
        y: target image, shape (batch, channel, h, w)
        z: seed image embeddings (BEFORE adding the noise eps), shape (batch, embedding_dim)
        W: model.linear.weight
        See equation (3) in the paper.
        '''
        # F.mse_loss is the L2 loss; F.l1_loss is the L1 loss.

        # PyTorch represents an image as a FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        # (see transforms.ToTensor()), but the model output uses tanh, so x ranges over (-1, 1).
        # Rescale y from (0, 1) to (-1, 1); the Chainer implementation computes the loss in (-1, 1),
        # so x is not rescaled to (0, 1) instead.
        image_loss = F.l1_loss(x, 2.0 * (y - 0.5))
        if self.normalize_img:
            # Rescaled to 1 in the Chainer code, see
            # https://github.com/nogu-atsu/SmallGAN/blob/2293700dce1e2cd97e25148543532814659516bd/gen_models/ada_generator.py#L68-L69
            loss = image_loss / image_loss.item()
        else:
            loss = image_loss

        for ploss in self.perceptual_loss(img1=x, img2=y, img1_minmax=(-1, 1), img2_minmax=(0, 1)):
            if self.normalize_perceptural:
                loss += self.scale_per * ploss / ploss.item()
            else:
                loss += self.scale_per * ploss

        loss += self.scale_emd * self.earth_mover_dist(z)

        loss += self.scale_reg * self.l1_reg(W)

        return loss
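A rough usage sketch for the loss (not from the repository): the tensors below are random stand-ins with purely illustrative shapes, and it assumes Vgg16PerceptualLoss from this commit can be constructed as-is (with its VGG16 weights available).

import torch

# Illustrative shapes only: 4 generated images in (-1, 1), 4 targets in [0, 1],
# 120-dim seed embeddings, and a stand-in for model.linear.weight.
x = torch.tanh(torch.randn(4, 3, 128, 128, requires_grad=True))  # generator output
y = torch.rand(4, 3, 128, 128)                                   # target batch from the dataloader
z = torch.randn(4, 120)                                          # embeddings before the eps noise
W = torch.randn(148, 120, requires_grad=True)                    # model.linear.weight stand-in

criterion = AdaBIGGANLoss(scale_per=0.001, scale_emd=0.1, scale_reg=0.02)
loss = criterion(x, y, z, W)
loss.backward()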