-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
131 lines (111 loc) · 6.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import numpy as np
import argparse
import os
import sys
import time
import datetime
from DETSMCL import DETSMCL
import tasks
import datautils
from utils import init_dl_program, name_with_datetime, pkl_save, data_dropout
def save_checkpoint_callback(
    save_every=1,
    unit='epoch',
    save_dir=None,
):
    """Build a training callback that periodically saves model checkpoints.

    Args:
        save_every: Save a checkpoint whenever the epoch/iteration counter is
            a multiple of this value. Must be a positive integer.
        unit: ``'epoch'`` to count ``model.n_epochs`` or ``'iter'`` to count
            ``model.n_iters``.
        save_dir: Directory the checkpoints are written to. When ``None``
            (the default), the module-level ``run_dir`` global is looked up
            each time the callback fires, matching the original behavior.

    Returns:
        A ``callback(model, loss)`` function. ``loss`` is accepted for
        signature compatibility but not used. Checkpoints are written with
        ``model.save`` to ``<dir>/model_<n>.pkl``.

    Raises:
        ValueError: If ``unit`` is not ``'epoch'``/``'iter'`` or
            ``save_every`` is not a positive integer.
    """
    # Raise instead of assert so validation survives ``python -O``.
    if unit not in ('epoch', 'iter'):
        raise ValueError(f"unit must be 'epoch' or 'iter', got {unit!r}")
    # A zero interval would otherwise surface later as ZeroDivisionError
    # deep inside the training loop.
    if save_every < 1:
        raise ValueError(f"save_every must be a positive integer, got {save_every!r}")
    counter_attr = 'n_epochs' if unit == 'epoch' else 'n_iters'

    def callback(model, loss):
        n = getattr(model, counter_attr)
        if n % save_every == 0:
            # Resolve the default directory lazily (when the callback fires,
            # not when it is built) so a global ``run_dir`` assigned after
            # construction is still picked up.
            target_dir = save_dir if save_dir is not None else run_dir
            model.save(f'{target_dir}/model_{n}.pkl')

    return callback
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset', help='The dataset name')
    parser.add_argument('run_name', help='The folder name used to save model, output and evaluation metrics. This can be set to any word')
    parser.add_argument('--loader', type=str, required=True, help='The data loader used to load the experimental data. This can be set to forecast_csv, forecast_csv_univar, forecast_npy, or forecast_npy_univar')
    parser.add_argument('--gpu', type=int, default=0, help='The gpu no. used for training and inference (defaults to 0)')
    parser.add_argument('--batch-size', type=int, default=4, help='The batch size (defaults to 4)')
    parser.add_argument('--lr', type=float, default=0.001, help='The learning rate (defaults to 0.001)')
    parser.add_argument('--repr-dims', type=int, default=320, help='The representation dimension (defaults to 320)')
    parser.add_argument('--max-train-length', type=int, default=3000, help='For sequence with a length greater than <max_train_length>, it would be cropped into some sequences, each of which has a length less than <max_train_length> (defaults to 3000)')
    parser.add_argument('--iters', type=int, default=None, help='The number of iterations')
    parser.add_argument('--epochs', type=int, default=None, help='The number of epochs')
    parser.add_argument('--save-every', type=int, default=None, help='Save the checkpoint every <save_every> iterations/epochs')
    parser.add_argument('--seed', type=int, default=None, help='The random seed')
    parser.add_argument('--max-threads', type=int, default=None, help='The maximum allowed number of threads used by this process')
    parser.add_argument('--eval', action="store_true", help='Whether to perform evaluation after training')
    parser.add_argument('--irregular', type=float, default=0, help='The ratio of missing observations (defaults to 0)')
    # NOTE(review): --momentu is parsed but never used below; it is not
    # forwarded to DETSMCL. The flag name is kept for CLI compatibility.
    parser.add_argument('--momentu', type=float, default=0.999, help='The MoCo momentum (defaults to 0.999)')
    args = parser.parse_args()
    print("Dataset:", args.dataset)
    print("Arguments:", str(args))

    # Every supported loader is a forecasting loader. Map each name to its
    # datautils function and whether univariate data is requested.
    forecast_loaders = {
        'forecast_csv': (datautils.load_forecast_csv, False),
        'forecast_csv_univar': (datautils.load_forecast_csv, True),
        'forecast_npy': (datautils.load_forecast_npy, False),
        'forecast_npy_univar': (datautils.load_forecast_npy, True),
    }
    if args.loader not in forecast_loaders:
        raise ValueError(f"Unknown loader {args.loader}.")
    task_type = 'forecasting'
    # Missing-observation dropout is only supported for classification, and
    # no loader above produces a classification task. Reject it before the
    # (potentially slow) GPU setup and data load rather than after.
    if args.irregular > 0:
        raise ValueError(f"Task type {task_type} is not supported when irregular>0.")

    device = init_dl_program(args.gpu, seed=args.seed, max_threads=args.max_threads)

    print('Loading data... ', end='')
    load_fn, univar = forecast_loaders[args.loader]
    # Only pass `univar` when requested, mirroring the original call sites.
    load_kwargs = {'univar': True} if univar else {}
    data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols = load_fn(args.dataset, **load_kwargs)
    train_data = data[:, train_slice]
    print('done')

    # Kept as a module-level name: the checkpoint callback falls back to the
    # global ``run_dir`` when it fires.
    run_dir = os.path.join('training', args.dataset + '__' + name_with_datetime(args.run_name))
    os.makedirs(run_dir, exist_ok=True)

    config = dict(
        batch_size=args.batch_size,
        lr=args.lr,
        output_dims=args.repr_dims,
        max_train_length=args.max_train_length,
    )
    if args.save_every is not None:
        # Checkpoint per epoch when --epochs is given, otherwise per iteration.
        unit = 'epoch' if args.epochs is not None else 'iter'
        config[f'after_{unit}_callback'] = save_checkpoint_callback(args.save_every, unit)

    t = time.perf_counter()
    # Unpack config so each entry is passed to DETSMCL as a keyword argument.
    model = DETSMCL(
        input_dims=train_data.shape[-1],
        device=device,
        **config,
    )
    model.fit(
        train_data,
        n_epochs=args.epochs,
        n_iters=args.iters,
        verbose=True,
    )
    model.save(os.path.join(run_dir, 'model.pkl'))
    t = time.perf_counter() - t
    print(f"\nTraining time: {datetime.timedelta(seconds=t)}\n")

    if args.eval:
        # Every supported loader is a forecasting loader (see forecast_loaders).
        out, eval_res = tasks.eval_forecasting(model, data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols)
        pkl_save(os.path.join(run_dir, 'out.pkl'), out)
        pkl_save(os.path.join(run_dir, 'eval_res.pkl'), eval_res)
        print('Evaluation result:\n\n', eval_res)
        print('\n')
    print("Finished.")