Skip to content

Commit

Permalink
feat: support pytorch in engine
Browse files Browse the repository at this point in the history
  • Loading branch information
gavincyi committed Sep 26, 2023
1 parent ab767a2 commit 1d9eb07
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
30 changes: 19 additions & 11 deletions src/fpm_risk_model/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any

_BACKEND_ENGINE = "numpy"
_SUPPORTED_ENGINES = ["numpy", "tensorflow", "cupy", "jax"]
_SUPPORTED_ENGINES = ["numpy", "tensorflow", "cupy", "jax", "torch"]


def backend():
Expand All @@ -19,13 +19,13 @@ def set_backend(library_name):
----------
library_name : str
Library name. Default is `numpy`. Options are `numpy`, `tensorflow`,
`cupy` and `jax`.
`cupy`, `jax` and `torch`.
"""
library_name = library_name.lower()
if library_name not in _SUPPORTED_ENGINES:
raise ValueError(
"Only `numpy`, `tensorflow`, `cupy` and `jax` are supported, but not "
f"{library_name}"
"Only `numpy`, `tensorflow`, `cupy`, `jax` and `torch` are supported, "
f"but not {library_name}"
)
global _BACKEND_ENGINE
_BACKEND_ENGINE = library_name
Expand All @@ -44,22 +44,18 @@ def use_backend(library_name="numpy"):
----------
library_name : str
Library name. Default is `numpy`. Options are `numpy`, `tensorflow`,
`cupy` and `jax`.
`cupy`, `jax` and `torch`.
"""
library_name = library_name.lower()
if library_name not in _SUPPORTED_ENGINES:
raise ValueError(
"Only `numpy`, `tensorflow`, `cupy` and `jax` are supported, but not "
f"{library_name}"
"Only `numpy`, `tensorflow`, `cupy`, `jax` and `torch` are supported, "
f"but not {library_name}"
)
global _BACKEND_ENGINE
_original = _BACKEND_ENGINE
try:
_BACKEND_ENGINE = library_name
if _BACKEND_ENGINE == "tensorflow":
import tensorflow.experimental.numpy as np

np.experimental_enable_numpy_behavior()
yield
finally:
_BACKEND_ENGINE = _original
Expand All @@ -79,10 +75,17 @@ def __getattribute__(self, __name: str) -> Any:
import numpy as anp
elif _BACKEND_ENGINE == "tensorflow":
import tensorflow.experimental.numpy as anp

anp.experimental_enable_numpy_behavior()
elif _BACKEND_ENGINE == "cupy":
import cupy as anp
elif _BACKEND_ENGINE == "jax":
import jax.numpy as anp
elif _BACKEND_ENGINE == "torch":
import torch as anp

anp.array = anp.tensor
anp.ndarray = anp.Tensor
else:
raise ValueError(f"Cannot recognize backend {_BACKEND_ENGINE}")
except ImportError:
Expand Down Expand Up @@ -114,11 +117,16 @@ def __getattribute__(self, __name: str) -> Any:
if _BACKEND_ENGINE == "numpy":
import numpy.linalg as alinalg
elif _BACKEND_ENGINE == "tensorflow":
import tensorflow.experimental.numpy as anp
import tensorflow.linalg as alinalg

anp.experimental_enable_numpy_behavior()
elif _BACKEND_ENGINE == "cupy":
import cupy.linalg as alinalg
elif _BACKEND_ENGINE == "jax":
import jax.numpy.linalg as alinalg
elif _BACKEND_ENGINE == "torch":
import torch.linalg as alinalg
else:
raise ValueError(f"Cannot recognize backend {_BACKEND_ENGINE}")
except ImportError:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from fpm_risk_model.engine import LinAlgEngine, NumpyEngine, backend, use_backend


def test_use_backend():
assert backend() == "numpy"
with use_backend("tensorflow"):
assert backend() == "tensorflow"


def test_numpy_engine():
import numpy

returns = (numpy.random.rand(100, 20) - 0.5) / 10
with use_backend("numpy"):
np = NumpyEngine()
linalg = LinAlgEngine()
returns = np.array(returns)
mu = np.mean(returns, axis=0)
demean = returns - mu
cov = demean.T @ demean
invcov = linalg.inv(cov)
assert isinstance(invcov, np.ndarray)

0 comments on commit 1d9eb07

Please sign in to comment.