add optional clf-cbf-qp denoising
shaoanlu committed May 7, 2024
1 parent 78a6718 commit 77e1216
Showing 7 changed files with 249 additions and 5 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -33,6 +33,9 @@ The policy takes 1) the latest N step of observation $o_t$ (position and velocit

### Deviation from the original implementation
- Add a linear layer before the Mish activation in the condition encoder of `ConditionalResidualBlock1D`; this prevents the activation from truncating large negative values in the normalized observation.
- An optional CLF-CBF-QP controller can refine the noisy actions at each step of the policy's denoising process. It is disabled by default (a minimal usage sketch follows the figure below).

<img src="assets/df_clf_cbf_comp.jpg" alt="drawing" width="600"/>
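A minimal sketch of wiring the optional guidance in, assuming the constructor signatures introduced in this commit; the `model`, `noise_scheduler`, `normalizer`, `config`, and `obs_dict` objects come from the surrounding training/inference setup and are not constructed here:

```python
from core.controllers.quadrotor_clf_cbf_qp import QuadrotorCLFCBFController
from core.controllers.quadrotor_diffusion_policy import QuadrotorDiffusionPolicy

# Pass a controller instance to enable guided denoising, or None (the default) to disable it.
clf_cbf_controller = QuadrotorCLFCBFController(config)

policy = QuadrotorDiffusionPolicy(
    model=model,
    noise_scheduler=noise_scheduler,
    normalizer=normalizer,
    clf_cbf_controller=clf_cbf_controller,  # or clf_cbf_controller=None
    config=config,
)
action = policy.predict_action(obs_dict)  # obs_dict carries the state history and obstacle info
```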


## References
Binary file added assets/df_clf_cbf_comp.jpg
9 changes: 9 additions & 0 deletions config/config.yaml
@@ -5,6 +5,8 @@ obs_horizon: 2
action_horizon: 10

controller:
common:
sampling_time: 0.1 # sec
networks:
obs_dim: 6
action_dim: 6
@@ -27,6 +29,13 @@ controller:
prediction_type: "epsilon"
use_karras_sigmas: true

cbf_clf_controller:
denoising_guidance_step: 20
cbf_alpha: 10.0
clf_gamma: 0.03
penalty_slack_cbf: 1.0e+3
penalty_slack_clf: 1.0

trainer:
use_ema: true
batch_size: 256
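A short sketch of reading the new keys, assuming PyYAML is available. Note that `QuadrotorCLFCBFController.set_config` below indexes `config["cbf_clf_controller"]` and `config["simulator"]` directly, so the exact nesting of the new block in the YAML (shown here under `controller:`) and how the dict is sliced before being passed in are assumptions:

```python
import yaml

with open("config/config.yaml") as f:
    config = yaml.safe_load(f)

# Confirmed by QuadrotorDiffusionPolicy.set_config in this commit:
dt = config["controller"]["common"]["sampling_time"]  # 0.1 s, used to integrate the refined velocities

# Assumed nesting of the guidance parameters:
guidance = config["controller"]["cbf_clf_controller"]
print(
    guidance["denoising_guidance_step"],  # apply the filter when the scheduler timestep k is below this value
    guidance["cbf_alpha"],                # class-K gain on the barrier (safety) constraint
    guidance["clf_gamma"],                # convergence rate on the Lyapunov (goal-reaching) constraint
    guidance["penalty_slack_cbf"],        # QP penalty on the CBF slack variable
    guidance["penalty_slack_clf"],        # QP penalty on the CLF slack variable
)
```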
184 changes: 184 additions & 0 deletions core/controllers/quadrotor_clf_cbf_qp.py
@@ -0,0 +1,184 @@
from typing import Dict, List
import numpy as np
import osqp
from scipy import sparse

from core.controllers.base_controller import BaseController


class QuadrotorCLFCBFController(BaseController):
"""
A CLF-CBF safety filter assuming simple velocity-controlled dynamics:
y_dot = u1
z_dot = u2
The barrier function h is defined from the distance to each obstacle.
"""

def __init__(self, config: Dict, device: str = "cuda"):
super().__init__(device)
self.obstacle_info = {"center": [], "radius": []}
self.set_config(config)

def predict_action(self, obs_dict: Dict[str, List], control: np.ndarray, target_position: np.ndarray) -> np.ndarray:
for center, radius in zip(obs_dict["obstacle_info"]["center"], obs_dict["obstacle_info"]["radius"]):
self.set_obstacle(center, radius)

safe_command = self.clf_cbf_control(
state=obs_dict["state"],
control=control,
obs_center=self.obstacle_info["center"],
obs_radius=self.obstacle_info["radius"],
cbf_alpha=self.cbf_alpha,
clf_gamma=self.clf_gamma,
penalty_slack_cbf=self.penalty_slack_cbf,
penalty_slack_clf=self.penalty_slack_clf,
target_position=target_position,
)
return safe_command

def set_obstacle(self, center: tuple, radius: float):
self.obstacle_info = {"center": [], "radius": []}
self.obstacle_info["center"].append(center)
self.obstacle_info["radius"].append(radius)

def set_config(self, config: Dict):
self.cbf_alpha = config["cbf_clf_controller"]["cbf_alpha"]
self.clf_gamma = config["cbf_clf_controller"]["clf_gamma"]
self.penalty_slack_cbf = config["cbf_clf_controller"]["penalty_slack_cbf"]
self.penalty_slack_clf = config["cbf_clf_controller"]["penalty_slack_clf"]
self.denoising_guidance_step = config["cbf_clf_controller"]["denoising_guidance_step"]
self.quadrotor_params = config["simulator"]

@staticmethod
def _barrier_func(y, z, obs_y, obs_z, obs_r) -> float:
return (y - obs_y) ** 2 + (z - obs_z) ** 2 - (obs_r) ** 2

@staticmethod
def _barrier_func_dot(y, z, obs_y, obs_z) -> list:
return [2 * (y - obs_y), 2 * (z - obs_z)]

@staticmethod
def _lyapunov_func(y, z, des_y, des_z) -> float:
return (y - des_y) ** 2 + (z - des_z) ** 2

@staticmethod
def _lyapunov_func_dot(y, z, des_y, des_z) -> list:
return [2 * (y - des_y), 2 * (z - des_z)]

@staticmethod
def _define_QP_problem_data(
u1: float,
u2: float,
cbf_alpha: float,
clf_gamma: float,
penalty_slack_cbf: float,
penalty_slack_clf: float,
h: list,
coeffs_dhdx: list,
v: list,
coeffs_dvdx: list,
vmin=-15.0,
vmax=15.0,
):

P = sparse.csc_matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, penalty_slack_cbf, 0], [0, 0, 0, penalty_slack_clf]])
q = np.array([-u1, -u2, 0, 0])
A = sparse.csc_matrix(
[c for c in coeffs_dhdx]
+ [c for c in coeffs_dvdx]
+ [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
)
lb = np.array([-cbf_alpha * h_ for h_ in h] + [-np.inf for _ in v] + [vmin, vmin, 0, 0])
ub = np.array([np.inf for _ in h] + [-clf_gamma * v_ for v_ in v] + [vmax, vmax, np.inf, np.inf])
return P, q, A, lb, ub

@staticmethod
def _get_quadrotor_state(state):
y, y_dot, z, z_dot, phi, phi_dot = state
return y, y_dot, z, z_dot, phi, phi_dot

def _calculate_cbf_coeffs(self, state: np.ndarray, obs_center: List, obs_radius: List, minimal_distance: float):
"""
Let the barrier function be h and the system state x; the CBF constraint is
h_dot(x) >= -alpha * h(x) - δ
"""
h = [] # barrier values: squared distance to each obstacle minus the inflated radius squared
coeffs_dhdx = [] # dhdt = dhdx * dxdt = dhdx * u
for center, radius in zip(obs_center, obs_radius):
y, _, z, _, _, _ = self._get_quadrotor_state(state)
h.append(self._barrier_func(y, z, center[0], center[1], radius + minimal_distance))
# Additional [1, 0] incorporates the CBF slack variable into the constraint
coeffs_dhdx.append(self._barrier_func_dot(y, z, center[0], center[1]) + [1, 0])
return h, coeffs_dhdx

def _calculate_clf_coeffs(self, state: np.ndarray, target_y: float, target_z: float):
"""
Let the Lyapunov function be v and the system state x; the CLF constraint is
v_dot(x) - δ <= -gamma * v(x)
"""
y, _, z, _, _, _ = self._get_quadrotor_state(state)
v = [self._lyapunov_func(y, z, target_y, target_z)]
# Additional [0, -1] incorporates the CLF slack variable into the constraint
coeffs_dvdx = [self._lyapunov_func_dot(y, z, target_y, target_z) + [0, -1]]
return v, coeffs_dvdx

def clf_cbf_control(
self,
state: np.ndarray,
control: np.ndarray,
obs_center: List,
obs_radius: List,
cbf_alpha: float = 15.0,
clf_gamma: float = 0.01,
penalty_slack_cbf: float = 1e2,
penalty_slack_clf: float = 1.0,
target_position: tuple = (5.0, 5.0),
):
"""
Calculate the safe command by solving the following optimization problem
minimize || u - u_nom ||^2 + k * δ^2
u, δ
s.t.
h'(x) ≥ -𝛼 * h(x) - δ1
v'(x) ≤ -γ * v(x) + δ2
u_min ≤ u ≤ u_max
0 ≤ δ1,δ2 ≤ inf
where
u = [u1, u2] is the control input along the y and z axes respectively.
δ is the slack variable
h(x) is the control barrier function and h'(x) its derivative
v(x) is the Lyapunov function and v'(x) its derivative
The problem above can be formulated as a QP (ref: https://osqp.org/docs/solver/index.html)
minimize 1/2 * x^T * Px + q^T x
x
s.t.
l ≤ Ax ≤ u
where
x = [u1, u2, δ1, δ2]
"""
u1, u2 = control
target_y, target_z = target_position

# Barrier function values and the coefficients of h_dot w.r.t. the control input
h, coeffs_dhdx = self._calculate_cbf_coeffs(state, obs_center, obs_radius, self.quadrotor_params["l_q"])
# Lyapunov function value and the coefficients of v_dot w.r.t. the control input
v, coeffs_dvdx = self._calculate_clf_coeffs(state, target_y, target_z)

# Define problem
P, q, A, lb, ub = self._define_QP_problem_data(
u1, u2, cbf_alpha, clf_gamma, penalty_slack_cbf, penalty_slack_clf, h, coeffs_dhdx, v, coeffs_dvdx
)

# Set up and solve the QP
prob = osqp.OSQP()
prob.setup(P, q, A, lb, ub, verbose=False, time_limit=0)
res = prob.solve()

safe_u1, safe_u2, _, _ = res.x
return np.array([safe_u1, safe_u2])
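To make the QP structure above concrete, here is a self-contained sketch of the problem `_define_QP_problem_data` assembles for a single obstacle, with decision vector x = [u1, u2, δ1, δ2]; the state, obstacle, and gain values are made up for illustration:

```python
import numpy as np
import osqp
from scipy import sparse

y, z = 0.0, 0.0                      # current position
u1, u2 = 1.0, 1.0                    # nominal (noisy) y/z velocity command
obs_y, obs_z, obs_r = 1.5, 1.5, 1.0  # obstacle centre and inflated radius
des_y, des_z = 5.0, 5.0              # myopic CLF target
alpha, gamma = 10.0, 0.03            # cbf_alpha, clf_gamma
k_cbf, k_clf = 1e3, 1.0              # slack penalties
vmin, vmax = -15.0, 15.0             # control bounds

h = (y - obs_y) ** 2 + (z - obs_z) ** 2 - obs_r ** 2   # barrier value
dhdx = [2 * (y - obs_y), 2 * (z - obs_z), 1, 0]        # CBF row (+ slack δ1)
v = (y - des_y) ** 2 + (z - des_z) ** 2                # Lyapunov value
dvdx = [2 * (y - des_y), 2 * (z - des_z), 0, -1]       # CLF row (- slack δ2)

P = sparse.csc_matrix(np.diag([1.0, 1.0, k_cbf, k_clf]))
q = np.array([-u1, -u2, 0.0, 0.0])
A = sparse.csc_matrix([dhdx, dvdx] + np.eye(4).tolist())
lb = np.array([-alpha * h, -np.inf, vmin, vmin, 0.0, 0.0])
ub = np.array([np.inf, -gamma * v, vmax, vmax, np.inf, np.inf])

prob = osqp.OSQP()
prob.setup(P, q, A, lb, ub, verbose=False)
res = prob.solve()
safe_u1, safe_u2, slack_cbf, slack_clf = res.x
print(safe_u1, safe_u2)              # filtered velocity command
```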
48 changes: 47 additions & 1 deletion core/controllers/quadrotor_diffusion_policy.py
@@ -8,6 +8,8 @@
DPMSolverMultistepScheduler,
)

from core.controllers.base_controller import BaseController
from core.controllers.quadrotor_clf_cbf_qp import QuadrotorCLFCBFController
from core.networks.conditional_unet1d import ConditionalUnet1D
from utils.normalizers import BaseNormalizer

@@ -47,12 +49,13 @@ def build_noise_scheduler_from_config(config: Dict):
raise NotImplementedError


class QuadrotorDiffusionPolicy:
class QuadrotorDiffusionPolicy(BaseController):
def __init__(
self,
model: ConditionalUnet1D,
noise_scheduler: DDPMScheduler,
normalizer: BaseNormalizer,
clf_cbf_controller: QuadrotorCLFCBFController,
config: Dict,
device: str = "cuda",
):
@@ -64,6 +67,9 @@ def __init__(
self.set_config(config)
self.net.to(self.device)

self.clf_cbf_controller = clf_cbf_controller
self.use_clf_cbf_guidance = clf_cbf_controller is not None

def predict_action(self, obs_dict: Dict[str, List]) -> np.ndarray:
# stack the observations
obs_seq = np.stack(obs_dict["state"])
@@ -94,6 +100,25 @@ def predict_action(self, obs_dict: Dict[str, List]) -> np.ndarray:
# inverse diffusion step (remove noise)
naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample

if self.use_clf_cbf_guidance:
diffusing_action = self.normalizer.unnormalize_data(
naction.detach().to("cpu").numpy().squeeze(), stats=self.norm_stats["act"]
) # (pred_horizon, 6)
if k < self.clf_cbf_controller.denoising_guidance_step:
refined_action = diffusing_action.copy()
for idx, act in enumerate(diffusing_action):
clf_cbf_obs, pred_control, target_position = self._preprocess_cbf_clf_input(
obs_dict, act, diffusing_action
)
safe_yz_velocity = self.clf_cbf_controller.predict_action(
obs_dict=clf_cbf_obs,
control=pred_control,
target_position=target_position,
)
refined_action[idx, ...] = self._calculate_refined_action_step(act, safe_yz_velocity)
naction = self.normalizer.normalize_data(np.array(refined_action), stats=self.norm_stats["act"])
naction = torch.from_numpy(naction).to(self.device, dtype=torch.float32).unsqueeze(0)

# unnormalize action
naction = naction.detach().to("cpu").numpy()
# (1, pred_horizon, action_dim)
@@ -116,6 +141,7 @@ def set_config(self, config: Dict):
self.action_horizon = config["action_horizon"]
self.pred_horizon = config["pred_horizon"]
self.action_dim = config["controller"]["networks"]["action_dim"]
self.sampling_time = config["controller"]["common"]["sampling_time"]
self.norm_stats = {
"act": config["normalizer"]["action"],
"obs": config["normalizer"]["observation"],
@@ -136,3 +162,23 @@ def calculate_force_command(self, state: np.ndarray, ref_state: np.ndarray) -> n
zr_ddot = (zr_dot - z_dot) / dt
phir_ddot = (phir_dot - phi_dot) / dt
return np.array([m_q * (g + zr_ddot), I_xx * phir_ddot])

def _preprocess_cbf_clf_input(
self, obs_dict: Dict[str, List], pred_action: np.ndarray, diffusing_action: np.ndarray
):
pred_state = pred_action
pred_control = pred_action[[1, 3]]
target_position_y, target_position_z = diffusing_action[-1, [0, 2]] # myopic CLF target: last waypoint of the current denoised plan
target_position = (target_position_y, target_position_z)
obstacle_info = {"center": obs_dict["obs_center"], "radius": obs_dict["obs_radius"]}
return {"state": pred_state, "obstacle_info": obstacle_info}, pred_control, target_position

def _calculate_refined_action_step(self, pred_act, safe_yz_velocity):
refined_step_action = pred_act.copy()
refined_step_action[0] += safe_yz_velocity[0] * self.sampling_time
refined_step_action[2] += safe_yz_velocity[1] * self.sampling_time
refined_step_action[1] = safe_yz_velocity[0]
refined_step_action[3] = safe_yz_velocity[1]
refined_step_action[4] = -np.arctan(safe_yz_velocity[0] / safe_yz_velocity[1]) # roll consistent with the commanded y/z velocity (assumes non-zero z velocity)
# refined_step_action[5] = (refined_step_action[4] - pred_act[4]) / self.sampling_time
return refined_step_action
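A small numeric illustration of `_calculate_refined_action_step`, assuming the per-step action layout [y, y_dot, z, z_dot, phi, phi_dot] inferred from `_get_quadrotor_state` and `_preprocess_cbf_clf_input`: the safe y/z velocity returned by the QP overwrites the velocity entries, is integrated over one sampling period into the position entries, and the roll angle is recomputed from the commanded velocity:

```python
import numpy as np

dt = 0.1                                               # sampling_time from the config
pred_act = np.array([1.0, 0.5, 2.0, 0.5, 0.0, 0.0])    # [y, y_dot, z, z_dot, phi, phi_dot]
safe_yz_velocity = np.array([0.2, 0.8])                 # QP output [u1, u2]

refined = pred_act.copy()
refined[0] += safe_yz_velocity[0] * dt                  # y <- y + u1 * dt
refined[2] += safe_yz_velocity[1] * dt                  # z <- z + u2 * dt
refined[1], refined[3] = safe_yz_velocity               # y_dot, z_dot <- u1, u2
refined[4] = -np.arctan(safe_yz_velocity[0] / safe_yz_velocity[1])  # roll from the commanded velocity
print(refined)                                          # [1.02, 0.2, 2.08, 0.8, -0.2450, 0.0]
```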
6 changes: 5 additions & 1 deletion demo.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions tests/test_datasets.py
@@ -35,9 +35,7 @@ def test_iter(self):

# batch contents match the expected shapes
batch = next(iter(dataloader))
self.assertEqual(
batch["obs"].shape, (batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim)
)
self.assertEqual(batch["obs"].shape, (batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim))
self.assertEqual(batch["action"].shape, (batch_size, self.pred_horizon, self.action_dim))


