This repository has been archived by the owner on Sep 26, 2022. It is now read-only.
forked from CEA-LIST/Basket-Ball-Size-Estimation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
83 lines (55 loc) · 1.94 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import random
import os
import torch
from torch import nn
import numpy as np
from model import BallSizeModel
from dataset import BallSizeDataset
# Single seed shared by every random number generator the script relies on.
SEED = 4212

# Seed NumPy, Python's ``random`` module and PyTorch with the same value so
# that repeated runs draw the same random numbers.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Ask PyTorch to use deterministic algorithm implementations; per the PyTorch
# documentation, operations without a deterministic variant raise an error.
torch.use_deterministic_algorithms(True)
def seed_worker(worker_id: int) -> None:
    """Seed NumPy and ``random`` from the current PyTorch initial seed.

    Passed to the ``DataLoader`` as ``worker_init_fn`` so that each worker
    process re-seeds its NumPy and Python RNGs from PyTorch's seed.

    Args:
        worker_id: Index of the worker, supplied by the ``DataLoader``.
            It is not used; the seed comes from ``torch.initial_seed()``.
    """
    # torch.initial_seed() is a 64-bit value; fold it into 32 bits before
    # handing it to the NumPy and Python seeding functions.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
# Dedicated generator handed to the DataLoader; seeded with 0.
# ``manual_seed`` returns the generator itself, so creation and seeding chain.
g = torch.Generator().manual_seed(0)
def _run_epoch(model, dataLoader, criterion, optimizer, device) -> float:
    """Run one optimisation pass over ``dataLoader`` and return the mean batch loss.

    Raises:
        ValueError: If ``dataLoader`` yields no batches, since the average
            loss would be undefined.
    """
    totalLoss = 0.0
    numBatches = 0
    # Every batch is a 4-tuple; only the last two entries are used here
    # (the input images and the target ball sizes).
    for _, _, imgs, ballSizes in dataLoader:
        imgs = imgs.to(device)
        ballSizes = ballSizes.to(device)

        estSize = model(imgs)
        loss = criterion(estSize, ballSizes)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() converts the scalar loss tensor to a Python float, so no
        # autograd graph is kept alive by the running total.
        totalLoss += loss.item()
        numBatches += 1

    if numBatches == 0:
        raise ValueError("the data loader produced no batches; cannot compute an average loss")
    return totalLoss / numBatches


def train(
    learningRate: float = 0.0001,
    numEpochs: int = 200,
    batchSize: int = 4,
    datasetPath: str = "basketball-instants-dataset/ball_dataset_trainval.pickle",
    outputDir: str = "output",
    device: str = "cuda",
) -> None:
    """Train ``BallSizeModel`` with an L1 loss and save periodic checkpoints.

    The model is optimised with AdamW. The learning rate is multiplied by 0.1
    once two thirds of the epochs have elapsed. A ``state_dict`` checkpoint is
    written to ``<outputDir>/model_<epoch>.pth`` every 10 epochs, and the
    average loss is printed after each epoch.

    All parameters default to the values the script originally hard-coded,
    so calling ``train()`` with no arguments behaves as before.

    Args:
        learningRate: Initial learning rate for AdamW.
        numEpochs: Number of passes over the dataset.
        batchSize: Number of samples per batch.
        datasetPath: Pickle file passed to ``BallSizeDataset``.
        outputDir: Directory for checkpoints; created if missing.
        device: Torch device the model and batches are moved to.

    Raises:
        ValueError: If the data loader yields no batches.
    """
    os.makedirs(outputDir, exist_ok=True)

    model = BallSizeModel().to(device)
    criterion = nn.L1Loss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learningRate)

    # NOTE(review): the second positional argument (True) presumably selects
    # the training configuration of the dataset -- confirm in dataset.py.
    dataset = BallSizeDataset(datasetPath, True)
    # seed_worker and the module-level generator ``g`` fix the seeds of the
    # worker processes and of the shuffling order.
    dataLoader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batchSize,
        shuffle=True,
        num_workers=4,
        worker_init_fn=seed_worker,
        generator=g,
    )

    # Drop the learning rate after two thirds of training. Clamp to 1 so a
    # very small numEpochs never yields step_size == 0.
    stepSize = max(1, int(numEpochs * 2 / 3))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=stepSize, gamma=0.1)

    for epoch in range(numEpochs):
        avgLoss = _run_epoch(model, dataLoader, criterion, optimizer, device)
        print(f"epoch [{epoch + 1}/{numEpochs}], avg loss:{avgLoss:.6f}")

        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), os.path.join(outputDir, f"model_{epoch + 1}.pth"))

        # StepLR counts epochs, so it is stepped once per epoch, after the
        # optimizer updates of that epoch.
        scheduler.step()
def main() -> None:
    """Command-line entry point: parse arguments, then run training.

    The parser defines no options of its own; parsing still provides
    ``--help`` and rejects unexpected command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Ball size training script")
    parser.parse_args()
    train()


if __name__ == "__main__":
    main()