forked from nikitakit/sabertooth
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_pretraining.py
253 lines (217 loc) · 8.89 KB
/
run_pretraining.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# Copyright 2020 The Sabertooth Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run masked LM/next sentence masked_lm pre-training for BERT."""
import datetime
import glob
import itertools
import json
import os
import shutil
import jax
import jax.numpy as jnp
import numpy as np
from absl import app, flags
from flax.training import checkpoints
from ml_collections.config_flags import config_flags
from tensorflow.io import gfile
import data
import modeling
import training
import pdb
# Global absl flag registry; main() reads FLAGS.config and get_output_dir()
# reads FLAGS.output_dir.
FLAGS = flags.FLAGS

# Where checkpoints (and config.json / eval results) go. When left unset,
# get_output_dir() picks a timestamped default directory.
flags.DEFINE_string(
    name="output_dir",
    default=None,
    help="The output directory where the model checkpoints will be written.",
)

# Registers --config; main() reads the resulting hyperparameters as FLAGS.config.
config_flags.DEFINE_config_file("config", None, "Hyperparameter configuration")
def get_output_dir(config):
    """Return the directory that checkpoints and run artifacts are written to.

    The ``--output_dir`` flag wins whenever it is set. Otherwise a
    timestamped directory ``~/sabertooth/output/pretrain_YYYYmmdd_HHMM`` is
    chosen (with ``~`` expanded), and that choice is announced on stdout.

    Args:
      config: the run configuration. It is currently ignored and is accepted
        only so callers can pass it uniformly.

    Returns:
      The output directory path as a string.
    """
    del config  # Unused.

    if FLAGS.output_dir is not None:
        return FLAGS.output_dir

    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    default_dir = os.path.expanduser(
        os.path.join("~", "sabertooth", "output", f"pretrain_{stamp}")
    )
    print()
    print("No --output_dir specified")
    print("Using default output_dir:", default_dir, flush=True)
    return default_dir
def get_initial_params(model, switch, init_checkpoint=None):
    """Produce the starting parameter tree for ``model``.

    Args:
      model: the model object. It must provide ``init`` (Flax-style) and
        ``params_from_checkpoint``.
      switch: forwarded unchanged to ``model.init`` as ``switch=``.
      init_checkpoint: if truthy, parameters are loaded from this checkpoint
        instead of being freshly initialized.

    Returns:
      The checkpoint parameters when ``init_checkpoint`` is given. Otherwise
      the ``"params"`` collection returned by a jitted ``model.init`` call on
      all-zero int32 placeholders of shape ``(1, 128)``.
    """
    if init_checkpoint:
        # NOTE(review): `model` is passed both as receiver and first argument;
        # presumably params_from_checkpoint is a static/class-style helper.
        return model.params_from_checkpoint(model, init_checkpoint)

    def _init_variables():
        placeholder = jnp.zeros((1, 128), dtype=jnp.int32)
        # The seed is drawn from NumPy's global RNG, so it is not reproducible
        # across runs unless NumPy is seeded elsewhere.
        rng = jax.random.PRNGKey(np.random.randint(2**16))
        return model.init(
            rng,
            input_ids=placeholder,
            input_mask=placeholder,
            type_ids=placeholder,
            masked_lm_positions=placeholder,
            switch=switch,
            deterministic=True,
        )

    return jax.jit(_init_variables)()["params"]
def compute_pretraining_loss_and_metrics(apply_fn, variables, batch, step, rngs):
    """Run the pre-training forward pass and return ``(loss, metrics)``.

    ``apply_fn`` is called with ``variables``, then these batch features in
    this order: input_ids, input_mask, token_type_ids, masked_lm_positions,
    masked_lm_ids, masked_lm_weights, next_sentence_label. After them comes
    ``step`` (positional), and ``rngs`` is passed as a keyword.

    ``apply_fn`` must return a mapping with a ``"loss"`` entry. That entry is
    returned first and the whole mapping second, as auxiliary metrics.

    NOTE(review): main() hands a boolean switch to the train step as its
    third argument. Whether that value is what arrives here as ``step``
    depends on training.create_train_step, which is not visible; confirm.
    """
    feature_order = (
        "input_ids",
        "input_mask",
        "token_type_ids",
        "masked_lm_positions",
        "masked_lm_ids",
        "masked_lm_weights",
        "next_sentence_label",
    )
    outputs = apply_fn(
        variables,
        *(batch[feature] for feature in feature_order),
        step,
        rngs=rngs,
    )
    return outputs["loss"], outputs
def compute_pretraining_stats(apply_fn, variables, batch):
    """Compute per-batch evaluation statistics for pre-training.

    Runs ``apply_fn`` deterministically to get masked-LM and next-sentence
    logits. It then combines the metrics from
    ``modeling.BertForPreTraining.compute_metrics`` with raw correct/total
    counts, which callers can sum across batches to get accuracies.

    Returns:
      A dict containing ``masked_lm_correct``, ``masked_lm_total``,
      ``next_sentence_correct`` and ``next_sentence_total``, plus every entry
      from ``compute_metrics``. On a key collision the ``compute_metrics``
      value wins.
    """
    mlm_logits, nsp_logits = apply_fn(
        variables,
        batch["input_ids"],
        batch["input_mask"],
        batch["token_type_ids"],
        batch["masked_lm_positions"],
        deterministic=True,
    )
    model_stats = modeling.BertForPreTraining.compute_metrics(
        mlm_logits,
        nsp_logits,
        batch["masked_lm_ids"],
        batch["masked_lm_weights"],
        batch["next_sentence_label"],
    )

    # Targets and weights are flattened to 1-D before comparison.
    # NOTE(review): this assumes mlm_logits is already (num_predictions, vocab)
    # so that argmax(-1) lines up with the flattened ids; confirm in modeling.
    mlm_targets = batch["masked_lm_ids"].reshape((-1,))
    mlm_weights = batch["masked_lm_weights"].reshape((-1,))
    mlm_hits = (mlm_logits.argmax(-1) == mlm_targets) * mlm_weights

    nsp_targets = batch["next_sentence_label"].reshape((-1,))
    nsp_hits = nsp_logits.argmax(-1) == nsp_targets

    counts = {
        "masked_lm_correct": jnp.sum(mlm_hits),
        "masked_lm_total": jnp.sum(batch["masked_lm_weights"]),
        "next_sentence_correct": jnp.sum(nsp_hits),
        "next_sentence_total": jnp.sum(jnp.ones_like(nsp_targets)),
    }
    return {**counts, **model_stats}
def _write_model_metadata(output_dir, config):
    """Save ``config.json`` and ``sentencepiece.model`` next to the checkpoints.

    Each file is written only if it does not already exist. All access goes
    through ``gfile``, consistent with how ``output_dir`` is created
    (``gfile.makedirs``) and how eval results are written (``gfile.GFile``),
    so every output of the run uses the same filesystem layer.
    """
    config_path = os.path.join(output_dir, "config.json")
    if not gfile.exists(config_path):
        with gfile.GFile(config_path, "w") as f:
            json.dump({"model_type": "bert", **config.model}, f)
    tokenizer_path = os.path.join(output_dir, "sentencepiece.model")
    if not gfile.exists(tokenizer_path):
        gfile.copy(config.tokenizer, tokenizer_path)


def _run_eval(state, data_pipeline, config, output_dir):
    """Evaluate on up to ``config.max_eval_steps`` batches, then print and save metrics.

    Metrics are printed one per line as ``name = value`` in sorted order. The
    same lines are written to ``<output_dir>/eval_results.txt``.
    """
    eval_iter = data_pipeline.get_inputs(batch_size=config.eval_batch_size)
    eval_iter = itertools.islice(eval_iter, config.max_eval_steps)
    eval_fn = training.create_eval_fn(
        compute_pretraining_stats, sample_feature_name="input_ids"
    )
    eval_stats = eval_fn(state, eval_iter)
    eval_metrics = {
        "loss": jnp.mean(eval_stats["loss"]),
        "masked_lm_loss": jnp.mean(eval_stats["masked_lm_loss"]),
        "next_sentence_loss": jnp.mean(eval_stats["next_sentence_loss"]),
        "masked_lm_accuracy": jnp.sum(eval_stats["masked_lm_correct"])
        / jnp.sum(eval_stats["masked_lm_total"]),
        "next_sentence_accuracy": jnp.sum(eval_stats["next_sentence_correct"])
        / jnp.sum(eval_stats["next_sentence_total"]),
    }

    eval_results = []
    for name, val in sorted(eval_metrics.items()):
        line = f"{name} = {val:.06f}"
        print(line, flush=True)
        eval_results.append(line)

    eval_results_path = os.path.join(output_dir, "eval_results.txt")
    with gfile.GFile(eval_results_path, "w") as f:
        f.write("".join(line + "\n" for line in eval_results))


def main(argv):
    """Pre-train and/or evaluate BERT as directed by ``FLAGS.config``.

    Builds the model and optimizer and restores the latest checkpoint in the
    output directory, if any. It then trains when ``config.do_train`` is set
    and evaluates when ``config.do_eval`` is set.

    Args:
      argv: leftover command-line arguments. Anything beyond the program name
        is rejected.

    Raises:
      app.UsageError: if extra positional arguments were given.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    config = FLAGS.config
    # Expand the glob patterns exactly once. The same list feeds the data
    # pipeline below, so the files reported here are the files actually read.
    input_files = [
        path for pattern in config.input_files for path in glob.glob(pattern)
    ]
    assert input_files, "No input files!"
    print(f"Training with {len(input_files)} input files, including:")
    print(f" - {input_files[0]}")

    model = modeling.BertForPreTraining(config=config.model)
    # Models using the "LinEVAMHA" attention type are initialized with switch=True.
    switch = config.model.attention_type == "LinEVAMHA"
    initial_params = get_initial_params(
        model, switch, init_checkpoint=config.init_checkpoint
    )
    tx = training.create_optimizer(
        optimizer=config.optimizer,
        b1=config.adam_beta1,
        b2=config.adam_beta2,
        eps=config.adam_epsilon,
        weight_decay=config.weight_decay,
        max_grad_norm=config.max_grad_norm,
        learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        total_steps=config.num_train_steps,
    )
    state = training.TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=tx,
        train_rngs={"dropout": jax.random.PRNGKey(np.random.randint(2**16))},
        history=training.MetricHistory(),
    )
    del initial_params  # the state takes ownership of all params

    output_dir = get_output_dir(config)
    gfile.makedirs(output_dir)

    # Restore from a local checkpoint, if one exists.
    state = checkpoints.restore_checkpoint(output_dir, state)
    start_step = int(state.step)
    state = state.replicate()

    data_pipeline = data.PretrainingDataPipeline(
        input_files,
        config.tokenizer,
        max_seq_length=config.max_seq_length,
        max_predictions_per_seq=config.max_predictions_per_seq,
    )

    if config.do_train:
        train_batch_size = config.train_batch_size
        num_processes = jax.process_count()
        if num_processes > 1:
            assert (
                train_batch_size % num_processes == 0
            ), "train_batch_size must be divisible by number of processes"
            # Each process feeds its own share of the global batch.
            train_batch_size = train_batch_size // num_processes
        train_iter = data_pipeline.get_inputs(
            batch_size=train_batch_size, training=True
        )
        train_step_fn = training.create_train_step(compute_pretraining_loss_and_metrics)

        # From this step onward the train step receives True as its third
        # argument. The default of 84000 reproduces the previously hard-coded
        # threshold (0.7 * 120000). Set `switch_step` in the config to override.
        switch_step = config.get("switch_step", 84000)

        for step, batch in zip(range(start_step, config.num_train_steps), train_iter):
            state = train_step_fn(state, batch, step >= switch_step)
            is_last_step = step == config.num_train_steps - 1
            if jax.process_index() == 0 and (
                step % config.save_checkpoints_steps == 0 or is_last_step
            ):
                checkpoints.save_checkpoint(output_dir, state.unreplicate(), step)
                _write_model_metadata(output_dir, config)

        # With the current Rust data pipeline code, running more than one pipeline
        # at a time will lead to a hang. A simple workaround is to fully delete the
        # training pipeline before potentially starting another for evaluation.
        del train_iter

    if config.do_eval:
        _run_eval(state, data_pipeline, config, output_dir)
if __name__ == "__main__":
    # Delegate to absl's app.run, which invokes main with the leftover argv.
    # main itself rejects any extra positional arguments.
    app.run(main)