diff --git a/mogrifier/mogrifier.py b/mogrifier/mogrifier.py index c00068c..6e94dc7 100644 --- a/mogrifier/mogrifier.py +++ b/mogrifier/mogrifier.py @@ -16,7 +16,10 @@ class Mogrifier(nn.Module): def __init__(self, dim, iters = 5, factorize_k = None): super().__init__() self.dim = dim - self.weights = nn.ModuleList([weight(dim, dim, factorize_k) for _ in range(iters)]) + self.iters = iters + + self.Q = weight(dim, dim, factorize_k) + self.R = weight(dim, dim, factorize_k) if iters > 1 else None def forward(self, x, h): shape = x.shape @@ -25,11 +28,11 @@ def forward(self, x, h): x, h = map(lambda t: t.reshape(-1, dim), (x, h)) - for ind, W in enumerate(self.weights): + for ind in range(self.iters): if (ind % 2) == 0: - x = 2 * W(h).sigmoid() * x + x = 2 * self.Q(h).sigmoid() * x else: - h = 2 * W(x).sigmoid() * h + h = 2 * self.R(x).sigmoid() * h x, h = map(lambda t: t.reshape(*shape), (x, h)) return x, h diff --git a/setup.py b/setup.py index 73f6499..abf2b74 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'mogrifier', packages = find_packages(), - version = '0.0.2', + version = '0.0.3', license='MIT', description = 'Implementation of Mogrifier circuit from Deepmind', author = 'Phil Wang',