-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_models.py
66 lines (54 loc) · 2.31 KB
/
load_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import tensorflow as tf
def load_generator(g_params=None, is_g_clone=False, ckpt_dir=None, custom_cuda=True):
    """Build a generator and optionally restore its weights from a checkpoint.

    The network is built by running one forward pass on dummy inputs, so its
    variables exist before any checkpoint is restored into them.

    Args:
        g_params: Hyper-parameter dict passed unchanged to ``Generator``. It
            must contain at least ``'z_dim'`` and ``'labels_dim'``, which are
            used here to shape the dummy inputs. When None, a default config is
            used: z/w dim 512, no labels, 8 mapping layers, resolutions
            4..256.
        is_g_clone: If True, the generator is registered in the checkpoint
            under the name ``'g_clone'``; otherwise under ``'generator'``. Only
            used when ``ckpt_dir`` is given.
        ckpt_dir: Directory searched by ``tf.train.CheckpointManager`` for the
            latest checkpoint. When None, no restore is attempted.
        custom_cuda: If True, ``Generator`` is imported from
            ``stylegan2.generator``; otherwise from ``stylegan2_ref.generator``.
            NOTE(review): presumably the custom-CUDA-op vs. reference
            implementation; confirm.

    Returns:
        The built ``Generator``. Its weights come from the latest checkpoint in
        ``ckpt_dir`` if one exists; otherwise it is freshly initialised.
    """
    if custom_cuda:
        from stylegan2.generator import Generator
    else:
        from stylegan2_ref.generator import Generator

    if g_params is None:
        g_params = {
            'z_dim': 512,
            'w_dim': 512,
            'labels_dim': 0,
            'n_mapping': 8,
            'resolutions': [4, 8, 16, 32, 64, 128, 256],
            'featuremaps': [512, 512, 512, 512, 512, 256, 128],
        }

    # Batch-of-one dummy inputs: calling the model once creates its variables.
    test_latent = tf.ones((1, g_params['z_dim']), dtype=tf.float32)
    test_labels = tf.ones((1, g_params['labels_dim']), dtype=tf.float32)

    # build generator model
    generator = Generator(g_params)
    _ = generator([test_latent, test_labels])

    if ckpt_dir is not None:
        # tf.train.Checkpoint matches saved objects by name, so this keyword
        # must match the name used when the checkpoint was written.
        if is_g_clone:
            ckpt = tf.train.Checkpoint(g_clone=generator)
        else:
            ckpt = tf.train.Checkpoint(generator=generator)
        manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)
        latest = manager.latest_checkpoint
        if latest:
            # expect_partial() silences warnings about checkpointed values that
            # this Checkpoint object does not track.
            ckpt.restore(latest).expect_partial()
            print(f'Generator restored from {latest}')
        else:
            # Previously this case was silent and random weights were returned.
            print(f'Warning: no checkpoint found in {ckpt_dir}; '
                  'returning a freshly initialized generator')
    return generator
def load_discriminator(d_params=None, ckpt_dir=None, custom_cuda=True):
    """Build a discriminator and optionally restore its weights from a checkpoint.

    The network is built by running one forward pass on dummy inputs, so its
    variables exist before any checkpoint is restored into them.

    Args:
        d_params: Hyper-parameter dict passed unchanged to ``Discriminator``.
            It must contain at least ``'resolutions'`` and ``'labels_dim'``.
            The last entry of ``'resolutions'`` sets the spatial size of the
            dummy image. When None, a default config is used: no labels,
            resolutions 4..256.
        ckpt_dir: Directory searched by ``tf.train.CheckpointManager`` for the
            latest checkpoint. When None, no restore is attempted.
        custom_cuda: If True, ``Discriminator`` is imported from
            ``stylegan2.discriminator``; otherwise from
            ``stylegan2_ref.discriminator``.

    Returns:
        The built ``Discriminator``. Its weights come from the latest
        checkpoint in ``ckpt_dir`` if one exists; otherwise it is freshly
        initialised.
    """
    if custom_cuda:
        from stylegan2.discriminator import Discriminator
    else:
        from stylegan2_ref.discriminator import Discriminator

    if d_params is None:
        d_params = {
            'labels_dim': 0,
            'resolutions': [4, 8, 16, 32, 64, 128, 256],
            'featuremaps': [512, 512, 512, 512, 512, 256, 128],
        }

    # Batch-of-one dummy image shaped (1, 3, res, res). This looks like
    # channels-first input; confirm against Discriminator.
    res = d_params['resolutions'][-1]
    test_images = tf.ones((1, 3, res, res), dtype=tf.float32)
    test_labels = tf.ones((1, d_params['labels_dim']), dtype=tf.float32)

    # build discriminator model
    discriminator = Discriminator(d_params)
    _ = discriminator([test_images, test_labels])

    if ckpt_dir is not None:
        # tf.train.Checkpoint matches saved objects by name ('discriminator').
        ckpt = tf.train.Checkpoint(discriminator=discriminator)
        manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)
        latest = manager.latest_checkpoint
        if latest:
            # expect_partial() silences warnings about checkpointed values that
            # this Checkpoint object does not track.
            ckpt.restore(latest).expect_partial()
            print(f'Discriminator restored from {latest}')
        else:
            # Previously this case was silent and random weights were returned.
            print(f'Warning: no checkpoint found in {ckpt_dir}; '
                  'returning a freshly initialized discriminator')
    return discriminator