Skip to content

Commit

Permalink
revert back to block matrix inversion with better fallback mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoAiraldi committed Feb 1, 2024
1 parent e76b59b commit 9c9b94b
Showing 1 changed file with 67 additions and 48 deletions.
115 changes: 67 additions & 48 deletions src/globopt/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,15 @@ def _cdist_and_inverse_quadratic_kernel(

@trace((torch.rand(5, 4, 3), torch.rand(5, 4, 1), torch.rand(()), torch.rand(())))
def _rbf_fit(
X: Tensor, Y: Tensor, eps: Tensor, eig_tol: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
X: Tensor, Y: Tensor, eps: Tensor, svd_tol: Tensor
) -> tuple[Tensor, Tensor]:
"""Fits the RBF regression model to the training data."""
_, M = _cdist_and_inverse_quadratic_kernel(X, X, eps)
eigvals, eigvecs = torch.linalg.eigh(M)
eigvals_thresholded = eigvals.where(eigvals.abs() > eig_tol, torch.inf)
Minv = (eigvecs / eigvals_thresholded.unsqueeze(-2)) @ eigvecs.mT
U, S, VT = torch.linalg.svd(M)
S = S.where(S >= svd_tol, torch.inf)
Minv = (VT.mT / S.unsqueeze(-2)) @ U.mT
coeffs = Minv.matmul(Y)
return eigvals, eigvecs, coeffs
return Minv, coeffs


@script( # unable to trace this one
Expand All @@ -107,40 +107,63 @@ def _rbf_fit(
torch.rand(5, 4, 1),
torch.rand(()),
torch.rand(()),
torch.rand(5, 2),
torch.rand(5, 2, 2),
torch.rand(5, 2, 1),
)
]
)
def _rbf_partial_fit(
X: Tensor,
Y: Tensor,
eps: Tensor,
eig_tol: Tensor,
eigvals: Tensor,
eigvecs: Tensor,
coeffs: Tensor,
) -> tuple[Tensor, Tensor, Tensor]:
X: Tensor, Y: Tensor, eps: Tensor, svd_tol: Tensor, Minv: Tensor, coeffs: Tensor
) -> tuple[Tensor, Tensor]:
"""Fits the given RBF regression to the new training data."""
n = coeffs.shape[-2] # index of the first new data point onwards
X_new = X[..., n:, :]
_, Phi_and_phi = _cdist_and_inverse_quadratic_kernel(X_new, X, eps)
PhiT = Phi_and_phi[..., :n]
phi = Phi_and_phi[..., n:]
prod = PhiT @ eigvecs
eigvals_mat = eigvals.expand(prod.mT.shape[:-1]).diag_embed()
M_proxy_new = torch.cat(
(torch.cat((eigvals_mat, prod.mT), -1), torch.cat((prod, phi), -1)), -2
)
eigvals_new, eigvecs_tmp = torch.linalg.eigh(M_proxy_new)
eigvecs_new = torch.cat(
(eigvecs @ eigvecs_tmp[..., :n, :], eigvecs_tmp[..., n:, :]), -2
)
eigvals_thresholded = eigvals_new.where(eigvals_new.abs() > eig_tol, torch.inf)
Minv_new = (eigvecs_new / eigvals_thresholded.unsqueeze(-2)) @ eigvecs_new.mT

# compute the new inverse of the kernel matrix via block matrix inversion and, where
# it fails, from scratch
Phi = PhiT.mT
L = Minv @ Phi
C = phi - PhiT @ L # Schur complement
C_det = torch.linalg.det(C)
mask = torch.logical_or(C_det >= 1e-17, C_det <= -1e-17) # magic number
if mask.all().item():
# all schur batches are invertible
Cinv = torch.linalg.inv(C)
B = -L @ Cinv
A = Minv - B @ L.mT
Minv_new = torch.cat((torch.cat((A, B), -1), torch.cat((B.mT, Cinv), -1)), -2)
else:
# some of the schur batches were not invertible. For those that were invertible,
# use block matrix inversion. For those that were not, compute the full inverse
Minv_new_shape = X.shape[:-2] + (X.shape[-2], X.shape[-2])
Minv_new = torch.empty(Minv_new_shape, device=Minv.device, dtype=Minv.dtype)

Cinv = torch.linalg.inv(C[mask, :, :])
L = L[mask, :, :]
B = -L @ Cinv
if Minv.ndim < Minv_new.ndim:
Minv = Minv.expand(X.shape[:-2] + (n, n))
A = Minv[mask, :, :] - B @ L.mT
Minv_new[mask, :, :] = torch.cat(
(torch.cat((A, B), -1), torch.cat((B.mT, Cinv), -1)), -2
)

not_mask = ~mask
X_old_nm = X[..., :n, :][not_mask, :, :]
_, M = _cdist_and_inverse_quadratic_kernel(X_old_nm, X_old_nm, eps)
M_new = torch.cat(
(torch.cat((M, Phi[not_mask, :, :]), -1), Phi_and_phi[not_mask, :, :]), -2
)
U_new, S_new, VT_new = torch.linalg.svd(M_new)
S_new = S_new.where(S_new >= svd_tol, torch.inf)
Minv_new[not_mask, :, :] = (VT_new.mT / S_new.unsqueeze(-2)) @ U_new.mT

# finally, compute the new coefficients
coeffs_new = Minv_new @ Y
return eigvals_new, eigvecs_new, coeffs_new
return Minv_new, coeffs_new


@trace(
Expand Down Expand Up @@ -271,7 +294,7 @@ def __init__(
train_X: Tensor,
train_Y: Tensor,
eps: Union[float, Tensor] = 1.0,
eig_tol: Union[float, Tensor] = 1e-8,
svd_tol: Union[float, Tensor] = 1e-8,
init_state: Optional[tuple[Tensor, Tensor]] = None,
) -> None:
"""Instantiates an RBF regression model for Global Optimization.
Expand All @@ -287,41 +310,37 @@ def __init__(
`train_X` points.
eps : float, optional
Distance-scaling parameter for the RBF kernel, by default `1.0`.
eig_tol : float, optional
svd_tol : float, optional
Tolerance for singular value decomposition for inversion, by default `1e-8`.
init_state : tuple of 3 Tensors, optional
Initial state of the regressor, in case of previous partial fitting, made up
of the eigendecomposition of the previous kernel distance matrix. This is a
tuple of
- `vals (b0 x b1 x ...) x m'`: the eigenvalues of the kernel matrix
- `vec (b0 x b1 x ...) x m' x m'`: the eigenvectors of the kernel matrix
of the inverse of previous kernel distance matrix and the RBF coefficients.
This is a tuple of
- `Minv (b0 x b1 x ...) x m' x m'`
- `coeffs (b0 x b1 x ...) x m' x 1`
where `m'` are the number of training points in the previous fitting.
By default `None`, in which case the model is fit anew to the training data.
"""
super().__init__(train_X, train_Y)
eps = torch.scalar_tensor(eps)
eig_tol = torch.scalar_tensor(eig_tol)
svd_tol = torch.scalar_tensor(svd_tol)
if init_state is None:
eigvals, eigvecs, coeffs = _rbf_fit(
self.train_X, self.train_Y, eps, eig_tol
)
Minv, coeffs = _rbf_fit(self.train_X, self.train_Y, eps, svd_tol)
else:
eigvals, eigvecs, coeffs = _rbf_partial_fit(
self.train_X, self.train_Y, eps, eig_tol, *init_state
Minv, coeffs = _rbf_partial_fit(
self.train_X, self.train_Y, eps, svd_tol, *init_state
)
self.register_buffer("eps", eps)
self.register_buffer("eig_tol", eig_tol)
self.register_buffer("eigvals", eigvals)
self.register_buffer("eigvecs", eigvecs)
self.register_buffer("svd_tol", svd_tol)
self.register_buffer("Minv", Minv)
self.register_buffer("coeffs", coeffs)
self.to(train_X)

@property
def state(self) -> tuple[Tensor, Tensor, Tensor]:
"""State of a fitted RBF regressor, i.e., the products of the eigendecomposition
of the kernel matrix and coefficients. Use this to partially fit a new regressor
(see `__init__`)"""
return self.eigvals, self.eigvecs, self.coeffs
def state(self) -> tuple[Tensor, Tensor]:
"""State of a fitted RBF regressor, i.e., the inverse of the kernel matrix and
coefficients. Use this to partially fit a new regressor (see `__init__`)"""
return self.Minv, self.coeffs

def forward(self, X: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Computes the RBF regression model.
Expand Down Expand Up @@ -349,4 +368,4 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **_: Any) -> "Rbf":
train_X, train_Y = self._prepare_for_fantasizing(X, Y)
Xnew = torch.cat((train_X, X), dim=-2)
Ynew = torch.cat((train_Y, Y), dim=-2)
return Rbf(Xnew, Ynew, self.eps, self.eig_tol, self.state)
return Rbf(Xnew, Ynew, self.eps, self.svd_tol, self.state)

0 comments on commit 9c9b94b

Please sign in to comment.