-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
executable file
·120 lines (102 loc) · 2.93 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import sys, os
import numpy as np
import torch
# Request the PyTorch backend for DeepXDE. This is set before `import deepxde`
# on the following lines, presumably because DeepXDE reads it at import time.
# NOTE(review): current DeepXDE docs name this variable DDE_BACKEND;
# DDEBACKEND appears to be a legacy alias -- confirm for the pinned version.
os.environ['DDEBACKEND'] = 'pytorch'
import deepxde
import mre_pinn
from mre_pinn.utils import main
from mre_pinn.training.losses import msae_loss
@main
def train(
    # data settings
    xarray_dir='data/BIOQIC/fem_box',
    example_id='60',
    frequency='auto',
    noise_ratio=0.0,
    anatomical=False,
    # pde settings
    pde_name='hetero',
    # baseline settings
    savgol_filter=False,
    # model settings
    omega=30,
    n_layers=5,
    n_hidden=128,
    activ_fn='ss',
    polar_input=False,
    # training settings
    optimizer='adam',
    learning_rate=1e-4,
    u_loss_weight=1.0,
    mu_loss_weight=0.0,
    a_loss_weight=0.0,
    pde_loss_weight=1e-16,
    pde_warmup_iters=10000,
    pde_init_weight=1e-18,
    pde_step_iters=5000,
    pde_step_factor=10,
    n_points=1024,
    n_iters=100000,
    # testing settings
    test_every=1000,
    save_every=10000,
    save_prefix=None
):
    """Train an MRE-PINN on one example and evaluate it.

    Steps performed:
      1. Load the example's xarrays, resolve the frequency, and optionally
         add Gaussian noise.
      2. Run the AHI and FEM baseline evaluations on the (possibly noisy)
         example.
      3. Build the wave-equation PDE and the MREPINN network.
      4. Compile and train an MREPINNModel with a TestEvaluator callback,
         then run one final test evaluation and print its metrics.

    Args:
        xarray_dir: Directory passed to MREExample.load_xarrays.
        example_id: Example ID passed to MREExample.load_xarrays.
        frequency: 'auto' to read it from ``example.wave.frequency``;
            otherwise any value accepted by ``float()``.
        noise_ratio: If > 0, passed to ``example.add_gaussian_noise``.
        anatomical: Forwarded as ``anat`` to MREExample.load_xarrays.
        pde_name: Name passed to WaveEquation.from_name; the value
            'hetero' also sets ``hetero=True`` for the FEM baseline.
        savgol_filter: Forwarded to both baseline evaluations.
        omega, n_layers, n_hidden, polar_input: Forwarded to MREPINN.
        activ_fn: Currently not used by this function (see NOTE below).
        optimizer: Optimizer name passed to ``model.compile``.
        learning_rate: Passed as ``lr`` to ``model.compile``.
        u_loss_weight, mu_loss_weight, a_loss_weight, pde_loss_weight:
            Passed, in this order, as ``loss_weights`` to MREPINNModel.
        pde_warmup_iters, pde_init_weight, pde_step_iters, pde_step_factor:
            Forwarded to MREPINNModel (presumably a schedule for the PDE
            loss weight -- confirm in MREPINNModel).
        n_points: Forwarded to MREPINNModel.
        n_iters: Number of iterations passed to ``model.train``.
        test_every, save_every, save_prefix: Forwarded to TestEvaluator.

    Returns:
        None. Prints the network, the final test metrics and 'Done'.
    """
    # load the training data
    example = mre_pinn.data.MREExample.load_xarrays(
        xarray_dir=xarray_dir,
        example_id=example_id,
        anat=anatomical
    )
    if frequency == 'auto':  # infer from data
        frequency = float(example.wave.frequency.item())
    else:
        frequency = float(frequency)

    # noise is added before the baselines run, so they see the same data
    if noise_ratio > 0:
        example.add_gaussian_noise(noise_ratio)

    mre_pinn.baseline.eval_ahi_baseline(
        example, frequency=frequency, savgol_filter=savgol_filter
    )
    mre_pinn.baseline.eval_fem_baseline(
        example,
        frequency=frequency,
        hetero=(pde_name == 'hetero'),
        savgol_filter=savgol_filter
    )

    # define PDE that we want to solve; its `omega` is the data frequency,
    # unlike the model's `omega` hyperparameter below
    pde = mre_pinn.pde.WaveEquation.from_name(
        pde_name, omega=frequency, detach=True
    )

    # define the model architecture
    # NOTE(review): `activ_fn` is accepted but never forwarded to MREPINN;
    # confirm whether MREPINN takes an activation argument before wiring it in.
    pinn = mre_pinn.model.MREPINN(
        example,
        omega=omega,
        n_layers=n_layers,
        n_hidden=n_hidden,
        polar_input=polar_input
    )
    print(pinn)

    # compile model and configure training settings
    model = mre_pinn.training.MREPINNModel(
        example, pinn, pde,
        loss_weights=[u_loss_weight, mu_loss_weight, a_loss_weight, pde_loss_weight],
        pde_warmup_iters=pde_warmup_iters,
        pde_step_iters=pde_step_iters,
        pde_step_factor=pde_step_factor,
        pde_init_weight=pde_init_weight,
        n_points=n_points
    )
    # honor the caller's `optimizer` instead of hard-coding 'adam'
    # (the default is still 'adam', so existing runs are unchanged)
    model.compile(optimizer=optimizer, lr=learning_rate, loss=msae_loss)
    model.benchmark(100)

    test_eval = mre_pinn.testing.TestEvaluator(
        test_every=test_every,
        save_every=save_every,
        save_prefix=save_prefix
    )

    # train the model
    model.train(n_iters, display_every=10, callbacks=[test_eval])

    # final test evaluation
    print('Final test evaluation')
    test_eval.test()
    print(test_eval.metrics)
    print('Done')