train.py
import argparse
import glob

from datasets import concatenate_datasets, Dataset
from training import utils
from training.train import Trainer

parser = argparse.ArgumentParser(
    description="Process the dataset and train the model",
)
parser.add_argument(
    "--process_ds_path",
    type=str,
    default=None,
    help="Path to the raw dataset to process",
)
parser.add_argument(
    "--chunked_ds_path",
    type=str,
    default=None,
    help="Path to a chunked dataset previously saved by this script",
)
parser.add_argument(
    "--model_path",
    type=str,
    default="./model",
    help="Path to save the trained model",
)
args = parser.parse_args()

if args.process_ds_path:
    # Process the raw dataset in chunks of 1000 examples, saving each
    # processed chunk to disk so a later run can reload it via
    # --chunked_ds_path instead of reprocessing.
    dataset = utils.gather_dataset(args.process_ds_path)
    trainer = Trainer(dataset)
    chunk_ds = []
    n_chunks = len(dataset) // 1000
    for i in range(n_chunks):
        ds = trainer.process_dataset(dataset, i)
        ds.save_to_disk(f"./dataset/process/{i}")
        chunk_ds.append(ds)
    # Process and save the remaining examples (the final partial chunk).
    ds = trainer.process_dataset(dataset, -1)
    ds.save_to_disk(f"./dataset/process/{n_chunks}")
    chunk_ds.append(ds)
    dataset = concatenate_datasets(chunk_ds)
    trainer.dataset = dataset.train_test_split(test_size=0.05)
elif args.chunked_ds_path:
    # Reload previously processed chunks from disk; chunks are stored in
    # numbered subdirectories (0, 1, ...) as written by the branch above.
    chunk_ds = []
    nb_chunks = len(glob.glob(f"{args.chunked_ds_path}/*"))
    for i in range(nb_chunks):
        ds = Dataset.load_from_disk(f"{args.chunked_ds_path}/{i}")
        chunk_ds.append(ds)
    dataset = concatenate_datasets(chunk_ds)
    dataset = dataset.train_test_split(test_size=0.05)
    trainer = Trainer(dataset)
else:
    raise ValueError("You must provide either --process_ds_path or --chunked_ds_path")

trainer.train()
trainer.save_model(args.model_path)
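
# Example invocations (a sketch: "./dataset/raw" is an illustrative input path,
# while "./dataset/process" matches the chunk directory this script writes):
#
#   python train.py --process_ds_path ./dataset/raw
#   python train.py --chunked_ds_path ./dataset/process --model_path ./model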