diff --git a/src/fpm_risk_model/engine.py b/src/fpm_risk_model/engine.py index 57fc074..ae35408 100644 --- a/src/fpm_risk_model/engine.py +++ b/src/fpm_risk_model/engine.py @@ -2,7 +2,7 @@ from typing import Any _BACKEND_ENGINE = "numpy" -_SUPPORTED_ENGINES = ["numpy", "tensorflow", "cupy", "jax"] +_SUPPORTED_ENGINES = ["numpy", "tensorflow", "cupy", "jax", "torch"] def backend(): @@ -19,13 +19,13 @@ def set_backend(library_name): ---------- library_name : str Library name. Default is `numpy`. Options are `numpy`, `tensorflow`, - `cupy` and `jax`. + `cupy`, `jax` and `torch`. """ library_name = library_name.lower() if library_name not in _SUPPORTED_ENGINES: raise ValueError( - "Only `numpy`, `tensorflow`, `cupy` and `jax` are supported, but not " - f"{library_name}" + "Only `numpy`, `tensorflow`, `cupy`, `jax` and `torch` are supported, " + f"but not {library_name}" ) global _BACKEND_ENGINE _BACKEND_ENGINE = library_name @@ -44,22 +44,18 @@ def use_backend(library_name="numpy"): ---------- library_name : str Library name. Default is `numpy`. Options are `numpy`, `tensorflow`, - `cupy` and `jax`. + `cupy`, `jax` and `torch`. """ library_name = library_name.lower() if library_name not in _SUPPORTED_ENGINES: raise ValueError( - "Only `numpy`, `tensorflow`, `cupy` and `jax` are supported, but not " - f"{library_name}" + "Only `numpy`, `tensorflow`, `cupy`, `jax` and `torch` are supported, " + f"but not {library_name}" ) global _BACKEND_ENGINE _original = _BACKEND_ENGINE try: _BACKEND_ENGINE = library_name - if _BACKEND_ENGINE == "tensorflow": - import tensorflow.experimental.numpy as np - - np.experimental_enable_numpy_behavior() yield finally: _BACKEND_ENGINE = _original @@ -79,10 +75,17 @@ def __getattribute__(self, __name: str) -> Any: import numpy as anp elif _BACKEND_ENGINE == "tensorflow": import tensorflow.experimental.numpy as anp + + anp.experimental_enable_numpy_behavior() elif _BACKEND_ENGINE == "cupy": import cupy as anp elif _BACKEND_ENGINE == "jax": import jax.numpy as anp + elif _BACKEND_ENGINE == "torch": + import torch as anp + + anp.array = anp.tensor + anp.ndarray = anp.Tensor else: raise ValueError(f"Cannot recognize backend {_BACKEND_ENGINE}") except ImportError: @@ -114,11 +117,16 @@ def __getattribute__(self, __name: str) -> Any: if _BACKEND_ENGINE == "numpy": import numpy.linalg as alinalg elif _BACKEND_ENGINE == "tensorflow": + import tensorflow.experimental.numpy as anp import tensorflow.linalg as alinalg + + anp.experimental_enable_numpy_behavior() elif _BACKEND_ENGINE == "cupy": import cupy.linalg as alinalg elif _BACKEND_ENGINE == "jax": import jax.numpy.linalg as alinalg + elif _BACKEND_ENGINE == "torch": + import torch.linalg as alinalg else: raise ValueError(f"Cannot recognize backend {_BACKEND_ENGINE}") except ImportError: diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..570739b --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,22 @@ +from fpm_risk_model.engine import LinAlgEngine, NumpyEngine, backend, use_backend + + +def test_use_backend(): + assert backend() == "numpy" + with use_backend("tensorflow"): + assert backend() == "tensorflow" + + +def test_numpy_engine(): + import numpy + + returns = (numpy.random.rand(100, 20) - 0.5) / 10 + with use_backend("numpy"): + np = NumpyEngine() + linalg = LinAlgEngine() + returns = np.array(returns) + mu = np.mean(returns, axis=0) + demean = returns - mu + cov = demean.T @ demean + invcov = linalg.inv(cov) + assert isinstance(invcov, np.ndarray)