Merge pull request #48 from khalil-research/PG
A different version of PG
LucasBoTang authored Dec 20, 2024
2 parents 02c4f80 + 078e1c7 commit 14df3bc
Showing 2 changed files with 162 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
@@ -66,7 +66,7 @@ To reproduce the experiments in the original paper, please use the code and foll

## Features

- Implement **SPO+** [[1]](https://doi.org/10.1287/mnsc.2020.3922), **DBB** [[3]](https://arxiv.org/abs/1912.02175), **NID** [[7]](https://arxiv.org/abs/2205.15213), **DPO** [[4]](https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html), **PFYL** [[4]](https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html), **NCE** [[5]](https://www.ijcai.org/proceedings/2021/390) and **LTR** [[6]](https://proceedings.mlr.press/v162/mandi22a.htm), **I-MLE** [[8]](https://proceedings.neurips.cc/paper_files/paper/2021/hash/7a430339c10c642c4b2251756fd1b484-Abstract.html), and **AI-MLE** [[9]](https://ojs.aaai.org/index.php/AAAI/article/view/26103).
- Implement **SPO+** [[1]](https://doi.org/10.1287/mnsc.2020.3922), **DBB** [[3]](https://arxiv.org/abs/1912.02175), **NID** [[7]](https://arxiv.org/abs/2205.15213), **DPO** [[4]](https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html), **PFYL** [[4]](https://papers.nips.cc/paper/2020/hash/6bb56208f672af0dd65451f869fedfd9-Abstract.html), **NCE** [[5]](https://www.ijcai.org/proceedings/2021/390), **LTR** [[6]](https://proceedings.mlr.press/v162/mandi22a.htm), **I-MLE** [[8]](https://proceedings.neurips.cc/paper_files/paper/2021/hash/7a430339c10c642c4b2251756fd1b484-Abstract.html), **AI-MLE** [[9]](https://ojs.aaai.org/index.php/AAAI/article/view/26103), and **PG** [[11]](https://arxiv.org/abs/2402.03256).
- Support [Gurobi](https://www.gurobi.com/), [COPT](https://shanshu.ai/copt), and [Pyomo](http://www.pyomo.org/) API
- Support Parallel computing for optimization solver
- Support solution caching [[5]](https://www.ijcai.org/proceedings/2021/390) to speed up training
@@ -221,3 +221,4 @@ if __name__ == "__main__":
* [8] [Niepert, M., Minervini, P., & Franceschi, L. (2021). Implicit MLE: backpropagating through discrete exponential family distributions. Advances in Neural Information Processing Systems, 34, 14567-14579.](https://proceedings.neurips.cc/paper_files/paper/2021/hash/7a430339c10c642c4b2251756fd1b484-Abstract.html)
* [9] [Minervini, P., Franceschi, L., & Niepert, M. (2023, June). Adaptive perturbation-based gradient estimation for discrete latent variable models. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 37, No. 8, pp. 9200-9208).](https://ojs.aaai.org/index.php/AAAI/article/view/26103)
* [10] [Schutte, N., Postek, K., & Yorke-Smith, N. (2023). Robust Losses for Decision-Focused Learning. arXiv preprint arXiv:2310.04328.](https://arxiv.org/abs/2310.04328)
* [11] [Gupta, V., & Huang, M. (2024). Decision-Focused Learning with Directional Gradients. arXiv preprint arXiv:2402.03256.](https://arxiv.org/abs/2402.03256)
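
The file added below implements the PG surrogates of [11] as finite differences of the optimal value of the embedded problem. As a sketch of the math it encodes (notation is mine, not necessarily the paper's): writing $P(z)$ for the optimal objective value under cost vector $z$, $\hat{c}$ for the predicted costs, $c$ for the true costs, and $h > 0$ for the step size,

$$
\ell_{\mathrm{PGF}}(\hat{c};c)=\frac{P(\hat{c}+hc)-P(\hat{c})}{h},\qquad
\ell_{\mathrm{PGB}}(\hat{c};c)=\frac{P(\hat{c})-P(\hat{c}-hc)}{h},\qquad
\ell_{\mathrm{PGC}}(\hat{c};c)=\frac{P(\hat{c}+hc)-P(\hat{c}-hc)}{2h}.
$$

Since the gradient of $P$ at $z$ is the optimal solution at $z$ (envelope theorem), each surrogate has the closed-form gradient used in the backward pass below.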
160 changes: 160 additions & 0 deletions pkg/pyepo/func/pgloss.py
@@ -0,0 +1,160 @@
# -*- coding: utf-8 -*-
"""Untitled0.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/16A6ZqQEV37NcUljQfpR-97nSjFYD392h
"""

#############################################
# Define PG Loss
#############################################

import torch
from torch.autograd import Function
import numpy as np

class PGLossFunction(Function):
"""
A custom autograd function for Policy Gradient (PG) Loss.
Supports "PGB" (Backward Difference), "PGC" (Central Difference), and "PGF" (Forward Difference) variants.
"""

@staticmethod
def forward(ctx, pred_cost, true_cost, mode, h, optmodel):
"""
Forward pass for PG Loss.
Args:
pred_cost (torch.Tensor): Predicted cost vector (batch_size, num_vars).
true_cost (torch.Tensor): True cost vector (batch_size, num_vars).
mode (str): "PGB" for backward difference or "PGC" for central difference.
h (float): Perturbation step size.
optmodel (object): Optimization model for solving the decision problem.
Returns:
torch.Tensor: Batch-wise PG loss.
"""
device = pred_cost.device
batch_size = pred_cost.size(0)
        loss = []
        sols1, sols2 = [], []  # per-sample optimal solutions for the backward pass

# Detach and convert tensors to numpy arrays for optimization
cp = pred_cost.detach().cpu().numpy()
c = true_cost.detach().cpu().numpy()

for i in range(batch_size):
c_hat = cp[i]
c_true = c[i]

if mode == "PGB":
# Backward difference
optmodel.setObj(c_hat)
sol1, obj1 = optmodel.solve()
V_hat = np.dot(sol1, c_true)

optmodel.setObj(c_hat - h * c_true)
sol2, obj2 = optmodel.solve()
V_hat_minus = np.dot(sol2, c_true)

loss.append((obj1 - obj2) / h)

elif mode == "PGC":
# Central difference
optmodel.setObj(c_hat + h * c_true)
sol1, obj1 = optmodel.solve()
V_hat_plus = np.dot(sol1, c_true)

optmodel.setObj(c_hat - h * c_true)
sol2, obj2 = optmodel.solve()
V_hat_minus = np.dot(sol2, c_true)

loss.append((obj1 - obj2) / (2 * h))

elif mode == "PGF":
# Forward difference
optmodel.setObj(c_hat + h * c_true)
sol1, obj1 = optmodel.solve()
V_hat_plus = np.dot(sol1, c_true)

optmodel.setObj(c_hat)
sol2, obj2 = optmodel.solve()
V_hat = np.dot(sol2, c_true) # NOTICE!!! extimated solution * ture cost or estimated objective func.

loss.append((obj1 - obj2) / h)

else:
raise ValueError(f"Unknown mode: {mode}")

        # convert the loss to a tensor; stack per-sample solutions for the backward pass
        loss = torch.FloatTensor(loss).to(device)
        sols1 = torch.FloatTensor(np.array(sols1)).to(device)
        sols2 = torch.FloatTensor(np.array(sols2)).to(device)

        ctx.save_for_backward(sols1, sols2)
        ctx.h = h
        ctx.mode = mode

return loss

@staticmethod
def backward(ctx, grad_output):
"""
Backward pass for PG Loss.
Args:
grad_output (torch.Tensor): Gradient of the loss with respect to its output.
Returns:
Gradients of the loss with respect to inputs: (pred_cost, true_cost, mode, h, optmodel).
"""
        sols1, sols2 = ctx.saved_tensors
        h = ctx.h
        mode = ctx.mode

        # by the envelope theorem, the gradient of each optimal value is the
        # corresponding optimal solution, so the loss gradient is a finite
        # difference of solutions with shape (batch_size, num_vars)
        if mode == "PGC":
            grad = (sols1 - sols2) / (2 * h)
        else:  # "PGB" and "PGF"
            grad = (sols1 - sols2) / h

        # scale each row by the incoming per-sample gradient
        return grad_output.unsqueeze(1) * grad, None, None, None, None


class PGLoss(torch.nn.Module):
"""
A PyTorch module for PG Loss.
Args:
optmodel: Optimization model for solving the decision problem.
mode (str): "PGB" for backward difference or "PGC" for central difference.
h (float): Perturbation step size.
"""

def __init__(self, optmodel, mode, h=0.01):
super(PGLoss, self).__init__()
self.optmodel = optmodel
self.mode = mode
self.h = h
self.pg = PGLossFunction()

def forward(self, pred_cost, true_cost):
"""
Compute the PG Loss.
Args:
pred_cost (torch.Tensor): Predicted cost vector (batch_size, num_vars).
true_cost (torch.Tensor): True cost vector (batch_size, num_vars).
Returns:
torch.Tensor: PG loss for the batch.
"""
loss = self.pg.apply(pred_cost, true_cost, self.mode, self.h, self.optmodel)
return loss
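
A minimal usage sketch (not part of this PR): the predictor, data loader, and `optmodel` below are hypothetical stand-ins; any optimization model exposing the `setObj`/`solve` interface that `PGLossFunction` assumes will do.

```python
import torch

# hypothetical setup: a linear model maps 5 features to 10 cost coefficients;
# optmodel is a user-supplied solver with setObj()/solve()
predictor = torch.nn.Linear(5, 10)
pgloss = PGLoss(optmodel, mode="PGF", h=0.01)
optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-2)

for feats, costs in dataloader:  # hypothetical loader of (features, true costs)
    pred_cost = predictor(feats)
    loss = pgloss(pred_cost, costs).mean()  # average the per-sample PG losses
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```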
