-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathoptimizer.py
32 lines (27 loc) · 1.02 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from pydgn.training.callback.optimizer import Optimizer
from pydgn.training.event.handler import EventHandler
class CGMMOptimizer(Optimizer):
    """
    Optimizer callback for CGMM models.

    It runs the model's M-step at the end of every training epoch and, during
    evaluation epochs, forwards the state's "return_node_embeddings" flag to the
    model (resetting it to False once evaluation finishes).
    """

    def __init__(self, **kwargs):
        # Keyword-only construction; every option is handled by the base Optimizer.
        super(CGMMOptimizer, self).__init__(**kwargs)

    def on_eval_epoch_start(self, state):
        """
        Copy the state's "return_node_embeddings" flag onto the model, so the model
        knows whether to compute statistics during this evaluation epoch.

        :param state: the shared State object; ``state.return_node_embeddings`` is
            read and assigned to ``state.model.return_node_embeddings``
        """
        wanted = state.return_node_embeddings
        state.model.return_node_embeddings = wanted

    def on_eval_epoch_end(self, state):
        """
        Turn the model's "return_node_embeddings" flag back off after evaluation.

        Not strictly necessary, since the next evaluation epoch overwrites the flag
        in ``on_eval_epoch_start``, but a known-False value may help when debugging.

        :param state: the shared State object whose ``model`` flag is reset
        """
        state.model.return_node_embeddings = False

    def on_training_epoch_end(self, state):
        """
        Update the model parameters by invoking its M-step at the end of the
        training epoch.

        :param state: the shared State object; ``state.model.m_step()`` is called
        """
        # NOTE(review): this overrides the base-class hook without calling super();
        # confirm Optimizer.on_training_epoch_end does nothing CGMM relies on.
        model = state.model
        model.m_step()