Skip to content

Commit

Permalink
fix implementation, Q and R is shared across iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 5, 2020
1 parent 633afff commit 96686f2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions mogrifier/mogrifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ class Mogrifier(nn.Module):
def __init__(self, dim, iters = 5, factorize_k = None):
super().__init__()
self.dim = dim
self.weights = nn.ModuleList([weight(dim, dim, factorize_k) for _ in range(iters)])
self.iters = iters

self.Q = weight(dim, dim, factorize_k)
self.R = weight(dim, dim, factorize_k) if iters > 1 else None

def forward(self, x, h):
shape = x.shape
Expand All @@ -25,11 +28,11 @@ def forward(self, x, h):

x, h = map(lambda t: t.reshape(-1, dim), (x, h))

for ind, W in enumerate(self.weights):
for ind in range(self.iters):
if (ind % 2) == 0:
x = 2 * W(h).sigmoid() * x
x = 2 * self.Q(h).sigmoid() * x
else:
h = 2 * W(x).sigmoid() * h
h = 2 * self.R(x).sigmoid() * h

x, h = map(lambda t: t.reshape(*shape), (x, h))
return x, h
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'mogrifier',
packages = find_packages(),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'Implementation of Mogrifier circuit from Deepmind',
author = 'Phil Wang',
Expand Down

0 comments on commit 96686f2

Please sign in to comment.