Merge pull request #48 from khalil-research/PG

A different version of PG

Showing 2 changed files with 162 additions and 1 deletion.
@@ -0,0 +1,160 @@
# -*- coding: utf-8 -*-
"""Untitled0.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/16A6ZqQEV37NcUljQfpR-97nSjFYD392h
"""

#############################################
# Define PG Loss
#############################################

import numpy as np
import torch
from torch.autograd import Function

class PGLossFunction(Function):
    """
    A custom autograd function for Policy Gradient (PG) Loss.

    Supports the "PGB" (backward difference), "PGC" (central difference),
    and "PGF" (forward difference) variants.
    """

    @staticmethod
    def forward(ctx, pred_cost, true_cost, mode, h, optmodel):
        """
        Forward pass for PG Loss.

        Args:
            pred_cost (torch.Tensor): Predicted cost vector (batch_size, num_vars).
            true_cost (torch.Tensor): True cost vector (batch_size, num_vars).
            mode (str): "PGB" (backward), "PGC" (central), or "PGF" (forward) difference.
            h (float): Perturbation step size.
            optmodel (object): Optimization model for solving the decision problem.

        Returns:
            torch.Tensor: Batch-wise PG loss.
        """
        device = pred_cost.device
        batch_size = pred_cost.size(0)
        loss = []
        sols1, sols2 = [], []  # per-sample solutions for the backward pass

        # Detach and convert tensors to numpy arrays for optimization
        cp = pred_cost.detach().cpu().numpy()
        c = true_cost.detach().cpu().numpy()

        for i in range(batch_size):
            c_hat = cp[i]
            c_true = c[i]

            if mode == "PGB":
                # Backward difference
                optmodel.setObj(c_hat)
                sol1, obj1 = optmodel.solve()

                optmodel.setObj(c_hat - h * c_true)
                sol2, obj2 = optmodel.solve()

                loss.append((obj1 - obj2) / h)

            elif mode == "PGC":
                # Central difference
                optmodel.setObj(c_hat + h * c_true)
                sol1, obj1 = optmodel.solve()

                optmodel.setObj(c_hat - h * c_true)
                sol2, obj2 = optmodel.solve()

                loss.append((obj1 - obj2) / (2 * h))

            elif mode == "PGF":
                # Forward difference
                optmodel.setObj(c_hat + h * c_true)
                sol1, obj1 = optmodel.solve()

                optmodel.setObj(c_hat)
                sol2, obj2 = optmodel.solve()

                # NOTE: the loss uses the perturbed optimal objectives obj1/obj2;
                # an alternative (flagged in the original) would score each
                # estimated solution against the true cost, i.e. np.dot(sol, c_true).
                loss.append((obj1 - obj2) / h)

            else:
                raise ValueError(f"Unknown mode: {mode}")

            # Collect this sample's solutions so backward() covers the whole batch
            sols1.append(sol1)
            sols2.append(sol2)

        # Convert the loss to a tensor, stack the per-sample solutions, and
        # save everything the backward pass needs
        loss = torch.tensor(loss, dtype=torch.float32, device=device)
        sol1 = torch.tensor(np.stack(sols1), dtype=torch.float32, device=device)
        sol2 = torch.tensor(np.stack(sols2), dtype=torch.float32, device=device)

        ctx.save_for_backward(sol1, sol2)
        ctx.optmodel = optmodel
        ctx.h = h
        ctx.mode = mode

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for PG Loss.

        Args:
            grad_output (torch.Tensor): Gradient of the loss with respect to its output.

        Returns:
            Gradients with respect to the forward inputs
            (pred_cost, true_cost, mode, h, optmodel); only pred_cost has one.
        """
        sol1, sol2 = ctx.saved_tensors
        h = ctx.h
        mode = ctx.mode

        # Finite difference of the perturbed solutions, shape (batch_size, num_vars)
        if mode in ("PGB", "PGF"):
            grad = (sol1 - sol2) / h
        elif mode == "PGC":
            grad = (sol1 - sol2) / (2 * h)
        else:
            raise ValueError(f"Unknown mode: {mode}")

        # Scale each sample's gradient by its incoming per-sample gradient
        return grad_output.unsqueeze(1) * grad, None, None, None, None
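
# ---------------------------------------------------------------------------
# Reader's note: all three variants approximate the directional derivative of
# the optimal value f(c) = min_x c^T x along the true cost (assuming a
# minimization optmodel; signs flip for maximization):
#     PGF: [f(c_hat + h * c_true) - f(c_hat)] / h
#     PGB: [f(c_hat) - f(c_hat - h * c_true)] / h
#     PGC: [f(c_hat + h * c_true) - f(c_hat - h * c_true)] / (2 * h)
# Danskin's theorem then gives the gradient as the matching difference of the
# perturbed argmin solutions, which is exactly what backward() returns.
# ---------------------------------------------------------------------------
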
class PGLoss(torch.nn.Module):
    """
    A PyTorch module for PG Loss.

    Args:
        optmodel: Optimization model for solving the decision problem.
        mode (str): "PGB" (backward), "PGC" (central), or "PGF" (forward) difference.
        h (float): Perturbation step size.
    """

    def __init__(self, optmodel, mode, h=0.01):
        super(PGLoss, self).__init__()
        self.optmodel = optmodel
        self.mode = mode
        self.h = h
        self.pg = PGLossFunction()

    def forward(self, pred_cost, true_cost):
        """
        Compute the PG Loss.

        Args:
            pred_cost (torch.Tensor): Predicted cost vector (batch_size, num_vars).
            true_cost (torch.Tensor): True cost vector (batch_size, num_vars).

        Returns:
            torch.Tensor: PG loss for the batch.
        """
        loss = self.pg.apply(pred_cost, true_cost, self.mode, self.h, self.optmodel)
        return loss
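
# ---------------------------------------------------------------------------
# Usage sketch (not part of the merged file): a minimal, self-contained check
# of PGLoss. `ToyMinItemModel` is a hypothetical stand-in for a PyEPO-style
# optimization model; PGLossFunction only relies on its setObj()/solve()
# interface. The decision problem here is simply picking the single cheapest
# item, i.e. minimizing c^T x over one-hot vectors x.
# ---------------------------------------------------------------------------

class ToyMinItemModel:
    """Hypothetical optmodel: one-hot solution minimizing the cost vector."""

    def setObj(self, cost):
        # Store the current objective coefficients
        self.cost = np.asarray(cost)

    def solve(self):
        # Optimal one-hot solution and its objective value
        sol = np.zeros_like(self.cost, dtype=float)
        idx = int(np.argmin(self.cost))
        sol[idx] = 1.0
        return sol, float(self.cost[idx])


if __name__ == "__main__":
    torch.manual_seed(0)
    pgloss = PGLoss(ToyMinItemModel(), mode="PGC", h=0.1)

    predictor = torch.nn.Linear(3, 4)   # maps 3 features to 4 cost entries
    feats = torch.randn(8, 3)           # a batch of 8 samples
    true_cost = torch.rand(8, 4)        # ground-truth cost vectors

    pred_cost = predictor(feats)
    loss = pgloss(pred_cost, true_cost).mean()
    loss.backward()                     # gradient flows through PGLossFunction
    print(f"loss = {loss.item():.4f}, grad shape = {predictor.weight.grad.shape}")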