main.py
import datetime
import io
import os

import numpy as np
import torch
import tensorflow as tf
import torch.optim as optim
import tqdm

from cache_tensorboard import log_hit_rates, log_scalar
from common.utils import create_directory
from evaluator import cache_hit_rate_evaluator, evaluate
from utils import as_batches, save_pickle
from cache_policy_model import CachePolicyModel
from configuration import config
from generator import train_data_generator
from baselines.common import schedules


def schedule_from_config(config):
    """Create a schedule from a configuration dictionary."""
    if config["type"] == "linear":
        return schedules.LinearSchedule(config["num_steps"], config["final"], config["initial"])
    elif config["type"] == "constant":
        return schedules.ConstantSchedule(config["value"])
    else:
        raise ValueError(f"Unknown schedule type: {config['type']}")
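
# Example (assumed) shapes for the schedule config consumed above; the real
# values come from configuration.config["dagger_schedule"], which also carries
# the "update_frequency" key used in main():
#   {"type": "linear", "num_steps": ..., "initial": ..., "final": ...}
#   {"type": "constant", "value": ...}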


def main():
    # Create a datetime string for saving models
    experiment_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    print("Experiment ID:", experiment_id)

    # Create experiment directory
    experiment_dir = os.path.join(config["experiment"]["base_dir"], 'tensorboard', experiment_id)
    create_directory(experiment_dir, overwrite=True)

    # Create tensorboard writer
    tb_writer = tf.summary.create_file_writer(experiment_dir)

    update_frequency = config["dagger_schedule"]["update_frequency"]
    batch_size = config["training"]["batch_size"]
    collection_multiplier = config["training"]["collection_multiplier"]
    max_examples = (update_frequency * batch_size * collection_multiplier)
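    # max_examples bounds how much data is collected per DAgger round: enough
    # for update_frequency batches of batch_size sequences, scaled by
    # collection_multiplier (presumably to over-collect so batches never run dry).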

    # Create dagger schedule
    dagger_schedule = schedule_from_config(config["dagger_schedule"])
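    # The DAgger schedule presumably controls how much the oracle (Belady)
    # policy vs. the partially trained model drives data collection, annealed
    # over training steps.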

    # Process everything on GPU if available
    device = torch.device("cpu")
    if torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        device = torch.device("cuda:0")
    # elif torch.backends.mps.is_available():
    #     device = torch.device("mps")
    #     torch.set_default_device(device)
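    # NOTE: torch.set_default_tensor_type is deprecated in newer PyTorch
    # releases; torch.set_default_dtype / torch.set_default_device are the
    # suggested replacements on recent versions.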
print("Device:", device)
# Initialize the model and optimizer
model = CachePolicyModel.from_config(config["model"]).to(device)
optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
# Initialize the step counter
step = 0
get_step = lambda: step
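    # get_step is handed to the data generator so it (and the DAgger schedule)
    # can read the current training step lazily rather than a stale copy.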
    total_steps = config["training"]["total_steps"]

    # Create the checkpoint directory
    checkpoint_dir = os.path.join(config["training"]["checkpoint_dir"], experiment_id)
    create_directory(checkpoint_dir)

    # Save the configuration
    config_save_path = os.path.join(checkpoint_dir, "config.pkl")
    save_pickle(config, config_save_path)

    with tqdm.tqdm(total=total_steps) as pbar:
        # Create training datasets generator
        training_datasets = train_data_generator(config["dataset"],
                                                 dagger_schedule,
                                                 get_step,
                                                 model,
                                                 max_examples)

        # Train the model
        for dataset, cache_hit_rates in training_datasets:
            # Log the hit rates
            # log_hit_rates("cache_hit_rates/train_belady_policy", cache_hit_rates, get_step())

            print("Training...")
            sequence_length = config["training"]["sequence_length"]
            warmup_period = sequence_length // 2
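            # The first half of each sequence serves as a warm-up window,
            # presumably to build up model state before the loss is counted.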

            # Generate batches from dataset
            for batch_num, batch in enumerate(as_batches([dataset], batch_size, sequence_length)):
                optimizer.zero_grad()
                loss = model.loss(batch, warmup_period)
                loss.backward()
                optimizer.step()
                pbar.update(1)
                step += 1

                # Log the loss
                if step % config["training"]["log_loss_frequency"] == 0 and step != 0:
                    loss_cpu = loss.cpu()
                    log_scalar(tb_writer, "loss/reuse_distance", loss_cpu.detach().numpy(), step)
                    print(f"Step: {step}, loss: {loss_cpu.detach().numpy()}")

                # Save the model
                if step % config["training"]["save_frequency"] == 0 and step != 0:
                    save_path = os.path.join(checkpoint_dir, f"model_{step}.ckpt")
                    with open(save_path, "wb") as save_file:
                        checkpoint_buffer = io.BytesIO()
                        torch.save(model.state_dict(), checkpoint_buffer)
                        print(f"Saving model at step {step}")
                        save_file.write(checkpoint_buffer.getvalue())

                # Evaluate the model
                # if step % config["training"]["evaluation_frequency"] == 0 and step != 0:
                #     hit_rates = next(cache_hit_rate_evaluator(config["dataset"],
                #                                               model,
                #                                               None,
                #                                               config["training"]["evaluation_size"]))
                #     print(f"Hit rates: {np.mean(hit_rates)}, step: {step}")
                #     log_hit_rates("cache_hit_rates/train", hit_rates, get_step())

                # Stop training once the step counter reaches the total number of steps
                if step >= total_steps:
                    return

                # Break out of the inner loop to collect the next dataset
                if batch_num >= config["dagger_schedule"]["update_frequency"]:
                    break

    # Evaluate the final model
    evaluate(experiment_id, multi_process=True)
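    # Note: the early `return` in the batch loop exits main() directly, so this
    # final evaluation only runs if the training data generator is exhausted
    # before total_steps is reached.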


if __name__ == "__main__":
    main()