`torchortho` is a PyTorch library for learnable activation functions based on:
- Hermite Polynomials 🧙♂️
- Fourier Series 〰
- Tropical Polynomials & Rational Functions 🌴
These adaptive activations dynamically adjust during training, offering improved expressivity, better gradient flow, and enhanced generalization for vision and language models.
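Schematically, each family is a truncated degree-*d* expansion whose coefficients are trained jointly with the network weights. The forms below are a simplified sketch; the paper's exact parameterizations add normalization and initialization details:

```math
\begin{aligned}
\phi_{\text{Hermite}}(x) &= \sum_{k=0}^{d} a_k\, H_k(x) \\
\phi_{\text{Fourier}}(x) &= a_0 + \sum_{k=1}^{d} \bigl(a_k \cos(kx) + b_k \sin(kx)\bigr) \\
\phi_{\text{Tropical}}(x) &= \max_{0 \le k \le d}\, (a_k + kx)
\end{aligned}
```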
This library is based on the paper:
📄 Learnable Polynomial, Trigonometric, and Tropical Activations (Khalfaoui-Hassani & Kesselheim, 2025).
For experimental results, check our repos:
- Vision models (ConvNeXt with `torchortho` activations): 🔗 GitHub
- Language models (GPT-2 with `torchortho` activations): 🔗 GitHub
Install from PyPI:

```bash
pip install torchortho
```

or install directly from GitHub:

```bash
pip install git+https://github.com/K-H-Ismail/torchortho.git
```
You can use `torchortho` activations just like any other PyTorch activation:

```python
import torch
from torchortho import HermiteActivation

# Define a learnable Hermite activation
degree = 5
activation = HermiteActivation(degree)

# Forward pass
x = torch.rand(7, 4, 3, 2)
y = activation(x)

# Compute gradients
loss = y.sum()
loss.backward()

print("Gradients of activation coefficients:", activation.coefficients.grad)
print("Output:", y)
```
The same activations drop into custom modules like any other layer:

```python
import torch
import torch.nn as nn
from torchortho import FourierActivation


class CustomMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.activation = FourierActivation(degree=4)  # Learnable Fourier activation
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        return self.fc2(x)


# Initialize the model
model = CustomMLP(input_dim=10, hidden_dim=32, output_dim=1)
x = torch.randn(5, 10)
output = model(x)
print("Model Output:", output)
```
Unlike static activations such as ReLU or GELU, `torchortho` functions adapt during training, letting the model learn an activation shape suited to the task at hand.
| Activation Type | Strengths |
|---|---|
| Hermite Activation | Adaptive polynomial approximation, variance-preserving, smooth optimization |
| Fourier Activation | Captures periodic structures in data (useful for NLP, physics-based models, and time series) |
| Tropical Polynomial Activation | Convex activation for structured learning (e.g., decision boundaries, optimization landscapes) |
| Rational Activation | Generalizes standard activation functions (e.g., Padé approximants can replicate ReLU, GELU, or even SwiGLU) |
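The documented classes cover the first two rows directly; a short sketch that evaluates each learned curve pointwise. The tropical and rational variants from the table are not shown in this README, so their class names are left as an assumption in the comments — check the package's exports for the exact names:

```python
import torch
from torchortho import HermiteActivation, FourierActivation
# The tropical and rational classes exist per the table above; their names are
# assumptions here, so verify them against torchortho's exports before importing.

x = torch.linspace(-3, 3, steps=7)
for act in (HermiteActivation(5), FourierActivation(4)):
    # Evaluate the (randomly initialized) learned activation on a small grid
    print(type(act).__name__, act(x).tolist())
```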
- Better function approximation → Increases expressivity for deep networks.
- Variance-preserving initialization → Ensures stable training, avoiding vanishing/exploding gradients (see the sanity check after this list).
- More flexible than ReLU/SwiGLU → Adapts activation behavior based on data.
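A quick sanity check of the variance-preserving claim — a minimal sketch, assuming the default initialization, that compares input and output standard deviations for a freshly created Hermite activation:

```python
import torch
from torchortho import HermiteActivation

activation = HermiteActivation(5)
x = torch.randn(100_000)

with torch.no_grad():
    y = activation(x)

# With variance-preserving initialization, the two values should be close.
print(f"input std:  {x.std().item():.3f}")
print(f"output std: {y.std().item():.3f}")
```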
The effectiveness of `torchortho` activations has been validated on large-scale deep learning benchmarks:
✅ Image Classification (ConvNeXt-T on ImageNet-1K)
- Replacing GELU with `torchortho` activations improves top-1 accuracy (see the drop-in replacement sketch below).
✅ Language Modeling (GPT-2 on OpenWebText)
- Learnable activations reduce perplexity compared to GELU-based models.
For full benchmarks, see the vision and language model repositories linked above.
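As a sketch of what "replacing GELU" can look like in practice — this is not the exact recipe from the benchmark repos, just a minimal drop-in pattern using the documented `HermiteActivation` constructor:

```python
import torch.nn as nn
from torchortho import HermiteActivation


def replace_gelu(module: nn.Module, degree: int = 5) -> nn.Module:
    """Recursively swap every nn.GELU for a learnable Hermite activation."""
    for name, child in module.named_children():
        if isinstance(child, nn.GELU):
            setattr(module, name, HermiteActivation(degree))
        else:
            replace_gelu(child, degree)
    return module


# Example: an MLP block of the kind found in ConvNeXt / GPT-2
block = nn.Sequential(nn.Linear(96, 384), nn.GELU(), nn.Linear(384, 96))
print(replace_gelu(block))
```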
This project is licensed under the GPL-3.0 License. See LICENSE for details.
We welcome contributions! Feel free to submit issues, open PRs, or suggest improvements.
For questions or collaborations, reach out via GitHub Issues.
If you use `torchortho` in your research, please cite the following paper:
```bibtex
@article{khalfaoui2025learnable,
  title={Learnable polynomial, trigonometric, and tropical activations},
  author={Khalfaoui-Hassani, Ismail and Kesselheim, Stefan},
  journal={arXiv preprint arXiv:2502.01247},
  year={2025}
}
```