-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstyle_transfer.py
300 lines (217 loc) · 9.84 KB
/
style_transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import argparse
from pathlib import Path
from typing import List, Union
import numpy as np
import torch
from PIL import Image
from torch import Tensor, nn, optim
from torch.nn import functional as F
from torch.optim import optimizer
from torchvision import models
from torchvision import transforms as T
from torchvision.utils import save_image
from tqdm.auto import tqdm
from datetime import datetime as dt
# Channel-wise mean / std handed to `T.Normalize` when loading images and to
# `NormalizeInverse` when converting results back for saving.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Layers whose Gram matrices make up the style loss, mapped to the per-layer
# weight applied in `style_loss_func`.
# Gatys et al. variant?
STYLE_LAYERS_DEFAULT = dict(
    conv1_1=0.75,
    conv2_1=0.5,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)

# Layers whose raw activations are compared in the content loss.
CONTENT_LAYERS_DEFAULT = ('conv5_2',)

# Multipliers of the individual terms of the total loss.
CONTENT_WEIGHT = 8  # "alpha" in the literature (default: 8)
STYLE_WEIGHT = 70  # "beta" in the literature (default: 70)
TV_WEIGHT = 10  # (default: 10)

# Target length of the smaller image edge (`T.Resize`) and the side of the
# square crop used for the side-by-side preview (`T.CenterCrop`).
IMG_SIZE = 512

# Step size passed as `lr` to the Adam / AdamW / SGD optimizers
# (the LBFGS branch does not pass it).
LEARNING_RATE = 0.004
class NormalizeInverse(object):
    """Undo a channel-wise ``(x - mean) / std`` normalization.

    The callable applies ``x * std + mean`` to every slice along the first
    dimension of the given tensor, pairing slices with the entries of
    ``mean`` and ``std`` in order (surplus slices or surplus statistics are
    ignored, as with ``zip``).

    Args:
        mean: Sequence of per-channel means used by the forward normalization.
        std: Sequence of per-channel standard deviations used by the forward
            normalization.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor: Tensor) -> Tensor:
        """Return a de-normalized copy of ``tensor``; the argument is left untouched.

        Working on a copy matters because callers typically pass
        ``x.detach().squeeze().cpu()``, which can share storage with ``x``.
        Modifying that storage in place would silently de-normalize the
        caller's tensor as well, and calling the transform twice on the same
        tensor would apply the inverse twice.
        """
        restored = tensor.clone()
        for channel, m, s in zip(restored, self.mean, self.std):
            # Inverse of the normalize step `t.sub_(m).div_(s)`.
            channel.mul_(s).add_(m)
        return restored
def get_features(image: Tensor, model: nn.Module, layers=None):
    """Run ``image`` through ``model`` layer by layer and collect named conv outputs.

    Convolutions are named ``conv{block}_{index}``: ``index`` counts the
    convolutions seen since the last pooling layer, and ``block`` starts at 1
    and is incremented by every max/avg pooling layer.

    Args:
        image: Input fed to the first layer of ``model``.
        model: Iterable of layers applied in order (e.g. an ``nn.Sequential``).
        layers: Container of layer names to collect. When ``None``, the keys
            of ``STYLE_LAYERS_DEFAULT`` plus ``CONTENT_LAYERS_DEFAULT`` are used.

    Returns:
        Dict mapping each requested name that was encountered to the output
        of that convolution.

    Raises:
        Exception: If a layer is neither a convolution, a pooling layer,
            a batch norm nor a ReLU.
    """
    wanted = layers
    if wanted is None:
        wanted = tuple(STYLE_LAYERS_DEFAULT) + CONTENT_LAYERS_DEFAULT

    collected = {}
    block_idx, conv_idx = 1, 0
    activation = image
    for module in model:
        activation = module(activation)

        if isinstance(module, nn.Conv2d):
            # Build the paper-style name of this convolution and keep its
            # output if the caller asked for it.
            # NOTE(review): the stored object is the convolution's output
            # tensor itself; if a following layer operates in place (e.g. an
            # in-place ReLU) the stored entry would reflect that -- confirm
            # this is intended.
            conv_idx += 1
            label = f'conv{block_idx}_{conv_idx}'
            if label in wanted:
                collected[label] = activation
            continue

        if isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)):
            # In VGG, each block ends with a max/avg pooling layer.
            block_idx, conv_idx = block_idx + 1, 0
            continue

        if not isinstance(module, (nn.BatchNorm2d, nn.ReLU)):
            raise Exception(f'Unknown layer: {module}')

    return collected
def gram_matrix(input: Tensor, normalize=False) -> Tensor:
    """Compute the Gram matrix of a 4-D feature tensor.

    The tensor of size ``(b, ch, h, w)`` is flattened to a
    ``(b * ch, h * w)`` matrix ``F`` and ``F @ F.T`` is returned, i.e. the
    inner products between all pairs of flattened feature maps.

    Args:
        input: Feature tensor with exactly four dimensions.
        normalize: When true, divide the result by the total number of
            elements in ``input`` (``b * ch * h * w``).

    Returns:
        Tensor of size ``(b * ch, b * ch)``.
    """
    (b, ch, h, w) = input.size()
    # Resize F_XL into \hat F_XL. `reshape` (rather than `view`) also accepts
    # non-contiguous inputs such as permuted or channels-last tensors, for
    # which `view` raises a RuntimeError.
    features = input.reshape(b * ch, h * w)
    # Compute the Gram product.
    gram = torch.mm(features, features.t())
    if normalize:
        # 'Normalize' the Gram matrix by the number of elements in the
        # feature tensor: equivalent to gram.div(b * ch * h * w).
        gram = gram / input.nelement()
    return gram
# Pre-processing applied to every loaded image: resize, convert to a tensor,
# then normalize with the IMAGENET_MEAN / IMAGENET_STD constants.
transform = T.Compose([
    # Smaller edge of the image will be matched to `IMG_SIZE`
    T.Resize(size=IMG_SIZE),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Reverses the normalization step of `transform` so a tensor can be saved
# as a regular image.
inv_transform = T.Compose([
    NormalizeInverse(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Same as `inv_transform`, followed by a square IMG_SIZE x IMG_SIZE center
# crop; used for the content / style / result preview.
inv_transform_preview = T.Compose([
    inv_transform,
    T.CenterCrop(size=(IMG_SIZE, IMG_SIZE)),
])
def load_image(img_path: Union[str, Path]) -> Tensor:
    """Load an image file and return it as a pre-processed batch of one.

    Args:
        img_path: Path of the image file to open.

    Returns:
        The RGB image after the module-level ``transform`` pipeline, with a
        leading batch dimension of size 1 added.
    """
    # Use the image as a context manager so the underlying file handle is
    # closed deterministically; `convert` returns an independent copy of the
    # pixel data, so it remains usable after the file is closed.
    with Image.open(img_path) as source:
        rgb = source.convert('RGB')
    return transform(rgb).unsqueeze(0)
def content_loss_func(target_features, content_features):
    """Sum the mean squared errors between target and content activations.

    Args:
        target_features: Mapping of layer name to the activation of the
            image being optimized; must contain every key of
            ``content_features``.
        content_features: Mapping of layer name to the reference activation.

    Returns:
        The sum of the per-layer MSE values, or the float ``0.0`` when
        ``content_features`` is empty.
    """
    per_layer = (
        F.mse_loss(target_features[name], reference)
        for name, reference in content_features.items()
    )
    return sum(per_layer, 0.0)
def style_loss_func(target_features, style_features, precomputed_style_grams):
    """Accumulate the weighted Gram-matrix loss over the style layers.

    For every layer name in ``style_features`` the (un-normalized) Gram
    matrix of the target activation is compared with the precomputed style
    Gram matrix via MSE, multiplied by the layer's weight from
    ``STYLE_LAYERS_DEFAULT`` and divided by ``d * h * w`` of the target
    activation.

    Args:
        target_features: Mapping of layer name to the 4-D activation of the
            image being optimized.
        style_features: Mapping whose keys select the layers to evaluate
            (only the keys are used here).
        precomputed_style_grams: Mapping of layer name to the Gram matrix of
            the style image at that layer.

    Returns:
        The summed loss, or the float ``0.0`` when ``style_features`` is empty.
    """
    def layer_term(name):
        activation = target_features[name]
        current_gram = gram_matrix(activation)
        reference_gram = precomputed_style_grams[name]
        _, depth, height, width = activation.shape
        weighted = STYLE_LAYERS_DEFAULT[name] * F.mse_loss(current_gram, reference_gram)
        return weighted / (depth * height * width)

    return sum(map(layer_term, style_features), 0.0)
def total_variance_loss_func(target):
    """Return the mean absolute difference between neighbouring pixels.

    The result is the sum of two L1 terms: one between elements adjacent
    along the last dimension and one between elements adjacent along the
    second-to-last dimension of the 4-D ``target``.
    """
    left, right = target[:, :, :, :-1], target[:, :, :, 1:]
    upper, lower = target[:, :, :-1, :], target[:, :, 1:, :]
    along_last_dim = F.l1_loss(left, right)
    along_second_last_dim = F.l1_loss(upper, lower)
    return along_last_dim + along_second_last_dim
if __name__ == '__main__':
    # Script entry point: parse the command line, optimize the pixels of a
    # "target" image so that its VGG19 features match the content image and
    # its Gram matrices match the style image, then write the result and a
    # content / style / result preview as timestamped JPEG files into the
    # current directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', metavar='PATH', required=True)
    parser.add_argument('--style', metavar='PATH', required=True)
    parser.add_argument('--epochs', type=int, metavar='N', default=7000, help='number of train epochs (default: 7000)')
    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disable CUDA acceleration')
    parser.add_argument('--optimizer', choices=['adam', 'adamw', 'lbfgs', 'sgd'], default='adam', help='select optimizer (default: adam)')
    parser.add_argument('--init', choices=['input', 'noise'], default='input', help='select start image (default: input)')
    #parser.add_argument('--output', '-o', help='image output')
    #parser.add_argument('--animation', action='store_true', default=False, help='intermediate images into animated GIF')
    #parser.add_argument('--no-normalization', action='store_true', default=False, help='disable intermediate image color normalization')
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device('cuda' if args.cuda else 'cpu')
    if args.cuda:
        # Allow CuDNN internal benchmarking for architecture-specific optimizations
        torch.backends.cudnn.benchmark = True

    # We will use frozen pretrained VGG neural network for feature extraction
    # In the original paper, authors have used VGG19 (without bn)
    model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
    # Authors in the original paper suggested use of AvgPool instead of MaxPool for more pleasing results.
    # However changing the pooling also affects activation, so the input needs to be scaled (not implemented).
    #for i, layer in enumerate(model):
    #    if isinstance(layer, torch.nn.MaxPool2d):
    #        model[i] = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
    model = model.eval().requires_grad_(False).to(device)

    # The "content" image on which we apply style
    content = load_image(args.input).to(device)
    # The "style" image from which we obtain style
    style = load_image(args.style).to(device)

    # The "target" image to store outcome. `content` already lives on
    # `device`, so both initializations create the tensor there directly and
    # `target` is guaranteed to be a leaf tensor the optimizer can update.
    if args.init == 'input':
        target = content.clone().requires_grad_(True)
    elif args.init == 'noise':
        target = torch.rand_like(content).requires_grad_(True)
    else:
        raise Exception(f'Init type "{args.init}" is not implemented!')

    # Precompute content features, style features, and style gram matrices.
    # They are constants of the optimization, so no autograd graph is needed.
    with torch.no_grad():
        content_features = get_features(content, model, CONTENT_LAYERS_DEFAULT)
        style_features = get_features(style, model, STYLE_LAYERS_DEFAULT)
        style_grams = {
            layer: gram_matrix(style_features[layer])
            for layer in style_features
        }

    def compute_losses():
        # Evaluate the weighted loss terms for the current `target`.
        # Returns (total, content, style, tv). Shared by the LBFGS closure
        # and the first-order loop so both optimize the same objective.
        target_features = get_features(target, model)
        content_loss = CONTENT_WEIGHT * content_loss_func(target_features, content_features)
        style_loss = STYLE_WEIGHT * style_loss_func(target_features, style_features, style_grams)
        tv_loss = TV_WEIGHT * total_variance_loss_func(target)
        total_loss = content_loss + style_loss + tv_loss
        return total_loss, content_loss, style_loss, tv_loss

    selected_optimizer = args.optimizer.lower()
    if selected_optimizer == 'lbfgs':
        # LBFGS optimizer has a bit different API from others where it uses closure()
        optimizer = optim.LBFGS([target], max_iter=args.epochs, line_search_fn='strong_wolfe')

        def closure():
            total_loss, _, _, _ = compute_losses()
            if torch.is_grad_enabled():
                optimizer.zero_grad(set_to_none=True)
            if total_loss.requires_grad:
                total_loss.backward()
            return total_loss

        # A single step performs up to `max_iter` internal iterations.
        optimizer.step(closure)
    else:
        if selected_optimizer == 'adam':
            optimizer = optim.Adam([target], lr=LEARNING_RATE)
        elif selected_optimizer == 'sgd':
            optimizer = optim.SGD([target], lr=LEARNING_RATE)
        elif selected_optimizer == 'adamw':
            optimizer = optim.AdamW([target], lr=LEARNING_RATE)
        else:
            raise Exception(f'Use of optimizer "{args.optimizer}" not implemented!')

        pbar = tqdm(range(args.epochs))
        for _ in pbar:
            optimizer.zero_grad(set_to_none=True)
            total_loss, content_loss, style_loss, tv_loss = compute_losses()
            # The graph is rebuilt from `target` on every iteration and the
            # precomputed features carry no graph, so `retain_graph=True` is
            # not needed; it would only keep intermediate buffers alive.
            total_loss.backward()
            optimizer.step()
            pbar.set_postfix_str(
                f'total_loss={total_loss.item():.2f} '
                f'content_loss={content_loss.item():.2f} '
                f'style_loss={style_loss.item():.2f} '
                f'tv_loss={tv_loss.item():.2f} '
            )
            #with torch.no_grad():
            #    target = torch.clamp(target, 0.0, 1.0)

    def as_cpu_image(batch):
        # Drop the batch dimension and return an independent CPU copy.
        # `detach()`, `squeeze()` and (for CPU tensors) `cpu()` all share
        # storage with `batch`; an inverse transform that works in place
        # would otherwise modify `batch` itself, so that converting the same
        # tensor twice (output + preview) would de-normalize it twice.
        return batch.detach().squeeze(0).cpu().clone()

    timestamp = dt.now().strftime('%Y%m%dT%H%M%S')
    # Store the outcome
    save_image(inv_transform(as_cpu_image(target)), f'./{timestamp}-output.jpg')
    # Store content + style + target image for impression
    save_image([
        inv_transform_preview(as_cpu_image(content)),
        inv_transform_preview(as_cpu_image(style)),
        inv_transform_preview(as_cpu_image(target)),
    ], f'./{timestamp}-transition.jpg')