Skip to content

Commit

Permalink
may try it out
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 8, 2024
1 parent fdb234c commit 88b2be4
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 58 deletions.
25 changes: 15 additions & 10 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# This workflows will upload a Python Package using Twine when a release is created
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
release:
types: [created]
types: [published]

jobs:
deploy:
Expand All @@ -21,11 +26,11 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
20 changes: 12 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,31 @@ $ pip install mogrifier
import torch
from mogrifier import Mogrifier

m = Mogrifier(
mogrify = Mogrifier(
dim = 512,
dim_hidden = 256,
iters = 5, # number of iterations, defaults to 5 as paper recommended for LSTM
factorize_k = 16 # factorize weight matrices into (dim x k) and (k x dim), if specified
)

x = torch.randn(1, 16, 512)
h = torch.randn(1, 16, 512)
h = torch.randn(1, 16, 256)

x_out, h_out = m(x, h) # (1, 16, 512), (1, 16, 512)
out, hidden_out = mogrify(x, h) # (1, 16, 512), (1, 16, 256)

assert out.shape == x.shape
assert hidden_out.shape == h.shape
```

## Citation

```bibtex
@inproceedings{Melis2020Mogrifier,
title={Mogrifier LSTM},
author={Gábor Melis and Tomáš Kočiský and Phil Blunsom},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=SJe5P6EYvS}
title = {Mogrifier LSTM},
author = {Gábor Melis and Tomáš Kočiský and Phil Blunsom},
booktitle = {International Conference on Learning Representations},
year = {2020},
url = {https://openreview.net/forum?id=SJe5P6EYvS}
}
```

Expand Down
96 changes: 75 additions & 21 deletions mogrifier/mogrifier.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,92 @@
from __future__ import annotations

import torch
from torch import nn
from torch import nn, Tensor
from torch.nn import Module

from einops import pack, unpack

# constants

Linear = nn.Linear

def weight(dim_in, dim_out, factorize_k = None):
if factorize_k is None:
return nn.Linear(dim_in, dim_out, bias = False)
def exists(v):
return v is not None

assert factorize_k < dim_in and factorize_k < dim_out, 'k must be of relative lower rank'
def default(v, d):
return v if exists(v) else d

# maybe factorized projection

def weight(
dim_in,
dim_out,
k: int | None = None
):
if not exists(k):
return Linear(dim_in, dim_out)

assert k < dim_in and k < dim_out, 'k must be of relative lower rank'

return nn.Sequential(
nn.Linear(dim_in, factorize_k, bias = False),
nn.Linear(factorize_k, dim_out, bias = False)
Linear(dim_in, k),
Linear(k, dim_out)
)

class Mogrifier(nn.Module):
def __init__(self, dim, iters = 5, factorize_k = None):
# main class

class Mogrifier(Module):
def __init__(
self,
dim: int,
iters = 5,
factorize_k: int | None = None,
dim_hidden: int | None = None,
hidden_factorize_k: int | None = None
):
super().__init__()
assert iters > 1

self.dim = dim

dim_hidden = default(dim_hidden, dim)
self.dim_hidden = dim_hidden

self.iters = iters

self.Q = weight(dim, dim, factorize_k)
self.R = weight(dim, dim, factorize_k) if iters > 1 else None
self.Q = nn.Sequential(
weight(dim_hidden, dim, factorize_k),
nn.Sigmoid()
)

factorize_k = default(hidden_factorize_k, factorize_k)

def forward(self, x, h):
shape = x.shape
*_, dim = shape
assert dim == self.dim, f'mogrifier accepts a dimension of {self.dim}'
self.R = nn.Sequential(
weight(dim, dim_hidden, factorize_k),
nn.Sigmoid()
)

x, h = map(lambda t: t.reshape(-1, dim), (x, h))
def forward(
self,
inputs: Tensor,
hiddens: Tensor,
iters: int | None = None
):
iters = default(iters, self.iters)

assert inputs.shape[-1] == self.dim
assert hiddens.shape[-1] == self.dim_hidden
assert inputs.shape[:-2] == hiddens.shape[:-2]

(inputs, packed_shape), (hiddens, _) = tuple(pack([t], '* d') for t in (inputs, hiddens))

for ind in range(self.iters):
if (ind % 2) == 0:
x = 2 * self.Q(h).sigmoid() * x
is_even = (ind % 2) == 0

if is_even:
inputs = 2 * self.Q(hiddens) * inputs
else:
h = 2 * self.R(x).sigmoid() * h
hiddens = 2 * self.R(inputs) * hiddens

x, h = map(lambda t: t.reshape(*shape), (x, h))
return x, h
inputs, hiddens = tuple(unpack(t, packed_shape, '* d')[0] for t in (inputs, hiddens))
return inputs, hiddens
44 changes: 25 additions & 19 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from setuptools import setup, find_packages

setup(
name = 'mogrifier',
packages = find_packages(),
version = '0.0.3',
license='MIT',
description = 'Implementation of Mogrifier circuit from Deepmind',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/mogrifier',
keywords = ['artificial intelligence', 'natural language processing'],
install_requires=[
'torch'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
name = 'mogrifier',
packages = find_packages(),
version = '0.0.4',
license='MIT',
description = 'Implementation of Mogrifier circuit from Deepmind',
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/mogrifier',
keywords = [
'artificial intelligence',
'natural language processing',
'improved conditioning'
],
install_requires=[
'einops>=0.8',
'torch'
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)

0 comments on commit 88b2be4

Please sign in to comment.