-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptim.py
36 lines (26 loc) · 918 Bytes
/
optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import List

# Support running both as a package module and as a standalone script:
# prefer the package-relative import, fall back to the top-level one.
# Catch only ImportError — a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and mask real bugs inside the modules.
try:
    from .parameter import Parameter
    from .autograd import zero_grad
except ImportError:
    from parameter import Parameter
    from autograd import zero_grad
class Optimizer:
    """Abstract base class for gradient-based optimizers.

    Holds the list of parameters to update; concrete subclasses must
    implement ``zero_grad`` and ``step``.
    """

    def __init__(self, parameters: List[Parameter]) -> None:
        # Parameters whose gradients this optimizer manages (kept by
        # reference — no copy is made).
        self.parameters = parameters

    def zero_grad(self):
        """Reset the gradients of all managed parameters."""
        raise NotImplementedError

    def step(self):
        """Apply one optimization update to the managed parameters."""
        raise NotImplementedError

    def load_parameters(self, parameters):
        """Replace the managed parameter list with *parameters*."""
        self.parameters = parameters
class SGD(Optimizer):
    """Vanilla stochastic gradient descent: ``p.data -= lr * p.gradient``."""

    def __init__(self, parameters: List[Parameter], lr: float = 0.01) -> None:
        super().__init__(parameters)
        # Step size applied to every parameter on each update.
        self.lr = lr

    def zero_grad(self):
        """Clear the gradient of every managed parameter."""
        for param in self.parameters:
            zero_grad(param)

    def step(self):
        """Take one descent step along the negative gradient direction."""
        for param in self.parameters:
            param.data = param.data - self.lr * param.gradient