-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathtrain.py
105 lines (93 loc) · 3.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import pickle
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" # suppress info-level logs
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
from dataset import prepare_dataset
from augmentations import RandomResizedCrop, RandomColorJitter
from algorithms import SimCLR, NNCLR, DCCLR, BarlowTwins, HSICTwins, TWIST, MoCo, DINO
tf.get_logger().setLevel("WARN") # suppress info-level logs
# Training schedule and model width.
num_epochs = 30
steps_per_epoch = 200
width = 128

# Per-algorithm hyperparameters, keyed by the algorithm class itself so the
# right settings are picked up automatically when `Algorithm` is switched.
hyperparams = {
    SimCLR: {"temperature": 0.1},
    NNCLR: {"temperature": 0.1, "queue_size": 10000},
    DCCLR: {"temperature": 0.1},
    BarlowTwins: {"redundancy_reduction_weight": 10.0},
    HSICTwins: {"redundancy_reduction_weight": 3.0},
    TWIST: {},
    MoCo: {"momentum_coeff": 0.99, "temperature": 0.1, "queue_size": 10000},
    DINO: {"momentum_coeff": 0.9, "temperature": 0.1, "sharpening": 0.5},
}


def _build_augmenter(name, crop_scale, jitter, hue):
    """Return an image-augmentation pipeline for 96x96 RGB inputs.

    `crop_scale` is the area range of the random resized crop, `jitter` the
    shared brightness/contrast/saturation strength, and `hue` the hue-shift
    strength of the color jitter.
    """
    return keras.Sequential(
        [
            layers.Input(shape=(96, 96, 3)),
            preprocessing.Rescaling(1 / 255),
            preprocessing.RandomFlip("horizontal"),
            RandomResizedCrop(scale=crop_scale, ratio=(3 / 4, 4 / 3)),
            RandomColorJitter(
                brightness=jitter, contrast=jitter, saturation=jitter, hue=hue
            ),
        ],
        name=name,
    )


# Load the STL10 dataset; the batch size is chosen by the dataset helper.
batch_size, train_dataset, labeled_train_dataset, test_dataset = prepare_dataset(
    steps_per_epoch
)

# Select which self-supervised algorithm to train.
Algorithm = SimCLR

# Assemble the model: a strong augmenter for the contrastive objective, a
# weaker one for the linear-probe classifier, a small conv encoder, a
# two-layer projection head, and a linear probe, plus the per-algorithm
# hyperparameters selected above.
model = Algorithm(
    contrastive_augmenter=_build_augmenter(
        "contrastive_augmenter", crop_scale=(0.2, 1.0), jitter=0.5, hue=0.2
    ),
    classification_augmenter=_build_augmenter(
        "classification_augmenter", crop_scale=(0.5, 1.0), jitter=0.2, hue=0.1
    ),
    encoder=keras.Sequential(
        [
            layers.Input(shape=(96, 96, 3)),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    ),
    projection_head=keras.Sequential(
        [
            layers.Input(shape=(width,)),
            layers.Dense(width, activation="relu"),
            layers.Dense(width),
        ],
        name="projection_head",
    ),
    linear_probe=keras.Sequential(
        [
            layers.Input(shape=(width,)),
            layers.Dense(10),
        ],
        name="linear_probe",
    ),
    **hyperparams[Algorithm],
)

# Separate optimizers: one for the contrastive objective, one for the probe.
model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
)

# Train, validating the linear probe on the test split after every epoch.
history = model.fit(train_dataset, epochs=num_epochs, validation_data=test_dataset)

# Persist the training curves (loss/accuracy per epoch) for later comparison.
with open(f"{Algorithm.__name__}.pkl", "wb") as write_file:
    pickle.dump(history.history, write_file)