Implement Damped Lagrangian Formulation #93

Open · wants to merge 3 commits into base: master
2 changes: 1 addition & 1 deletion cooper/__init__.py
@@ -16,7 +16,7 @@
warnings.warn("Could not retrieve cooper version!")

from cooper.constrained_optimizer import ConstrainedOptimizer
-from cooper.lagrangian_formulation import LagrangianFormulation
+from cooper.lagrangian_formulation import LagrangianFormulation, DampedLagrangianFormulation
from cooper.problem import CMPState, ConstrainedMinimizationProblem
from cooper.state_logger import StateLogger

82 changes: 82 additions & 0 deletions cooper/lagrangian_formulation.py
@@ -328,6 +328,88 @@ def _populate_gradients(
violation_for_update.backward(inputs=dual_vars)


class DampedLagrangianFormulation(LagrangianFormulation):

"""
Provides utilities for computing the Damped-Lagrangian
proposed by :cite:t:`platt1987constrained` and associated with a
``ConstrainedMinimizationProblem`` and for populating the
gradients for the primal and dual parameters.

Args:
cmp: ``ConstrainedMinimizationProblem`` we aim to solve and which gives
rise to the Lagrangian.
damping_coefficient: Coefficient used for the damping term of the
multipliers.
ineq_init: Initialization values for the inequality multipliers.
eq_init: Initialization values for the equality multipliers.
aug_lag_coefficient: Coefficient used for the augmented Lagrangian.
"""

def __init__(
self,
cmp: ConstrainedMinimizationProblem,
damping_coefficient: float,
ineq_init: Optional[torch.Tensor] = None,
eq_init: Optional[torch.Tensor] = None,
aug_lag_coefficient: float = 0.0,
):
super().__init__(cmp, ineq_init, eq_init, aug_lag_coefficient)

if damping_coefficient <= 0:
raise ValueError("Damping coefficient must be strictly positive")

self.damping_coefficient = damping_coefficient

def weighted_violation(
self, cmp_state: CMPState, constraint_type: str
) -> torch.Tensor:
"""
Computes the dot product between the current damped multipliers and the
constraint violations of type ``constraint_type``. If proxy-constraints
are provided in the :py:class:`.CMPState`, the proxy (differentiable)
constraints are used for the returned dot product, which drives the
primal updates, while the dot product with the non-proxy (usually
non-differentiable) constraints is stored under ``self.state_update``
and is later used to update the multipliers.

Args:
cmp_state: current ``CMPState``
constraint_type: type of constraint to be used
"""

defect = getattr(cmp_state, constraint_type + "_defect")
has_defect = defect is not None

proxy_defect = getattr(cmp_state, "proxy_" + constraint_type + "_defect")
has_proxy_defect = proxy_defect is not None

if not has_proxy_defect:
# If not given proxy constraints, then the regular defects are
# used for computing gradients and evaluating the multipliers
proxy_defect = defect

if not has_defect:
# We should always have at least the regular defects, if not, then
# the problem instance does not have `constraint_type` constraints
proxy_violation = torch.tensor([0.0], device=cmp_state.loss.device)
else:
multipliers = getattr(self, constraint_type + "_multipliers")()
# Compute the damped multipliers: each multiplier is shifted by the damping
# coefficient times the current (detached) defect to dampen oscillations
damped_multiplier = (multipliers - self.damping_coefficient * defect.detach())

# We compute (primal) gradients of this object
proxy_violation = torch.sum(damped_multiplier.detach() * proxy_defect)

# This is the violation of the "actual" constraint. We use this
# to update the value of the multipliers by lazily filling the
# multiplier gradients in `_populate_gradients`
violation_for_update = torch.sum(damped_multiplier * defect.detach())
self.state_update.append(violation_for_update)

return proxy_violation


class ProxyLagrangianFormulation(BaseLagrangianFormulation):
"""
Placeholder class for the proxy-Lagrangian formulation proposed by
32 changes: 32 additions & 0 deletions docs/source/lagrangian_formulation.rst
@@ -84,6 +84,38 @@ the existence of a pure Nash equilibrium is guaranteed :cite:p:`vonNeumann1928th
.. autoclass:: LagrangianFormulation
:members:


Damped-Lagrangian Formulation
-----------------------------

The Damped-Lagrangian Formulation modifies the traditional Lagrangian approach to address the oscillations that can occur when the optimization repeatedly switches between violating and satisfying a constraint.
This methodology, termed the *Modified Differential Method of Multipliers*, was originally proposed by John C. Platt and Alan H. Barr :cite:p:`platt1987constrained`
and is revisited in contemporary discussions on making machine learning algorithms more tunable, as highlighted in a blog post :cite:p:`engraved2024tunable`.

Overview
^^^^^^^^
The standard Lagrangian multiplier approach can lead to oscillatory behavior: the Lagrangian multiplier :math:`\lambda` keeps growing while a constraint is violated.
Once the constraint is met, the now-large :math:`\lambda` can push the solution away from the constraint boundary, resulting in oscillations around the optimal solution.
The Damped-Lagrangian formulation introduces a damping mechanism to attenuate these oscillations and improve convergence.
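
To make the oscillation concrete, consider minimizing :math:`f(x)` subject to :math:`g(x) = 0`.
The basic differential method of multipliers performs gradient descent on the primal variables and gradient ascent on the multiplier,

.. math::

    \dot{x} = -\nabla f(x) - \lambda \, \nabla g(x), \qquad \dot{\lambda} = g(x),

and these coupled dynamics can behave like an undamped oscillator around the constraint boundary :cite:p:`platt1987constrained`.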

Theoretical Background
^^^^^^^^^^^^^^^^^^^^^^
The principal adjustment is to introduce a damping term, proportional to the current constraint defect, into the multiplier used by the Lagrangian, analogous to damping in a physical oscillatory system,
which prevents excessive fluctuations and promotes stability.
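
Concretely, writing :math:`c` for the ``damping_coefficient`` and :math:`g(x)` for the constraint defect, the ``weighted_violation`` method of ``DampedLagrangianFormulation`` in this PR replaces :math:`\lambda` with the damped multiplier

.. math::

    \tilde{\lambda} = \lambda - c \, g(x)

when forming the Lagrangian term that drives the primal update, while the multiplier itself is still updated using the raw defect :math:`g(x)`.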

Advantages
^^^^^^^^^^
1. **Reduced Oscillations**: Introduces damping to minimize oscillatory behaviors, leading to more stable convergence.
2. **Flexibility**: Effective across both convex and concave Pareto fronts.
3. **Hyper-parameter Tuning**: The damping hyper-parameter allows for balancing quick convergence with adequate exploration of the solution space.

Considerations
^^^^^^^^^^^^^^
- **Hyper-parameter Selection**: Introduces a new hyper-parameter, the damping factor, which necessitates careful tuning to optimize convergence dynamics without altering the final solution.

.. autoclass:: DampedLagrangianFormulation
:members:
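
A minimal usage sketch, mirroring the unit test added in this PR (the
``DummyCMP`` problem below is only a stand-in for a real
``ConstrainedMinimizationProblem``):

.. code-block:: python

    import torch
    import cooper

    class DummyCMP(cooper.ConstrainedMinimizationProblem):
        def __init__(self):
            super().__init__(is_constrained=True)

        def closure(self):
            pass

    cmp = DummyCMP()
    formulation = cooper.DampedLagrangianFormulation(cmp, damping_coefficient=10.0)

    # Create multipliers for the constraints present in the current CMP state
    cmp.state = cooper.CMPState(eq_defect=torch.tensor([1.0]))
    formulation.create_state(cmp.state)

    # With zero-initialized multipliers, the damped weighted violation equals
    # -damping_coefficient * sum(defect ** 2) = -10.0 for this state
    violation = formulation.weighted_violation(cmp.state, constraint_type="eq")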

Proxy-Lagrangian Formulation
----------------------------

13 changes: 13 additions & 0 deletions docs/source/references.bib
@@ -69,3 +69,16 @@ @inproceedings{reddi2018amsgrad
year = {2018},
url = {https://openreview.net/forum?id=r1laEnA5Ym}
}

@inproceedings{platt1987constrained,
  title = {Constrained differential optimization},
  author = {Platt, John C. and Barr, Alan H.},
  booktitle = {Neural Information Processing Systems},
  year = {1987}
}

@misc{engraved2024tunable,
  title = {How We Can Make Machine Learning Algorithms Tunable},
  author = {{Engraved Blog}},
  year = {2024},
  howpublished = {\url{https://www.engraved.blog/how-we-can-make-machine-learning-algorithms-tunable/}},
  note = {Accessed: 2024-06-05}
}
45 changes: 44 additions & 1 deletion tests/test_lagrangian_formulation.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python

"""Tests for Lagrangian Formulation class."""
"""Tests for Lagrangian Formulations."""

import torch

@@ -28,3 +28,46 @@ def closure(self):
)
lf.create_state(cmp.state)
assert (lf.ineq_multipliers is not None) and (lf.eq_multipliers is not None)


def test_damped_lagrangian_formulation():
class DummyCMP(cooper.ConstrainedMinimizationProblem):
def __init__(self):
super().__init__(is_constrained=True)

def closure(self):
pass

cmp = DummyCMP()
damping_coefficient = 10.0

# Check that constraint multipliers are created correctly
lf = cooper.DampedLagrangianFormulation(cmp, damping_coefficient)
cmp.state = cooper.CMPState(eq_defect=torch.tensor([1.0]))
lf.create_state(cmp.state)

assert (lf.ineq_multipliers is None) and (lf.eq_multipliers is not None)

# Check that the damping coefficient is set correctly
lf = cooper.DampedLagrangianFormulation(cmp, damping_coefficient)
cmp.state = cooper.CMPState(
eq_defect=torch.tensor([1.0]), ineq_defect=torch.tensor([1.0, 1.2])
)

lf.create_state(cmp.state)
assert (lf.ineq_multipliers is not None) and (lf.eq_multipliers is not None)

# Check correct violation of constraints
lf = cooper.DampedLagrangianFormulation(cmp, damping_coefficient)
cmp.state = cooper.CMPState(
eq_defect=torch.tensor([1.0]), ineq_defect=torch.tensor([1.0, 1.2])
)

lf.create_state(cmp.state)

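# With the default (zero) initialization of the multipliers created by
# `create_state`, the damped multiplier is -damping_coefficient * defect,
# so each weighted violation reduces to -damping_coefficient * sum(defect ** 2)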
## Check that the weighted violation is correct for equality constraints
assert torch.allclose(lf.weighted_violation(cmp.state, constraint_type="eq"), torch.tensor([-10.0]))

## Check that the weighted violation is correct for inequality constraints
assert torch.allclose(lf.weighted_violation(cmp.state, constraint_type="ineq"), torch.tensor([-24.4]))