"""
Cholesky-based GP predictions, with support for low-rank updates.
GPyTorch does not support low-rank updates with Cholesky, as far as I can tell.
"""
import torch
# TODO heteroskedastic GP - make sure we add noise, and noise is called correctly.
# TODO can't handle multi task GP yet - catch exception on train_X
class GPPosteriorPredictor:
"""
A convenience class for computing posterior covariances of a GP.
This avoids GPyTorch's default forward pass so that we can
do Cholesky-based predictions and low-rank updates.
"""
def __init__(
self,
covar_module,
mean_module,
noise_module,
train_X,
train_y,
) -> None:
self.covar_module = covar_module
self.mean_module = mean_module
self.noise_module = noise_module
self.train_X = train_X
self.train_y = train_y
# prepare the cache for the posterior covariance - M^{-1}, with M = K(X, X) + sigma^2 I.
# We don't actually invert M; we cache its Cholesky decomposition instead.
noise = noise_module(train_X)
# train_train_covar has the noise sigma^2 added to the diagonal.
self.train_train_covar = self.covar_module(train_X).to_dense() + noise.to_dense()
self.train_train_covar_chol = torch.linalg.cholesky(self.train_train_covar)
# prepare the cache for the posterior mean - M^{-1} (y - m(X))
train_mean = mean_module(train_X).squeeze(-1)
train_labels_offset = (self.train_y.squeeze(-1) - train_mean).unsqueeze(-1)
self.mean_cache = torch.cholesky_solve(train_labels_offset, self.train_train_covar_chol).squeeze(-1)
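# Summary of the cached quantities: with M = K(X, X) + sigma^2 I,
# train_train_covar_chol is L such that L L^T = M, and mean_cache is M^{-1} (y - m(X));
# every prediction below reuses them via cholesky_solve instead of refactorizing M.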
# I see no reason why we should not compile this:
# the shape of X won't change in repeated calls when optimizing X.
# @torch.compile  # needs a downgrade to Python 3.10
def predict_covar(self, X, test_train_covar=None):
"""
Basic code taken from exact_predictive_covar in GPyTorch.
NOTE: this supports both batch mode (b,q,d) and single mode (q,d).
"""
test_test_covar = self.covar_module(X).to_dense()
if test_train_covar is None:
test_train_covar = self.covar_module(X, self.train_X).to_dense()
train_test_covar = test_train_covar.transpose(-1, -2)
covar_correction_rhs = torch.cholesky_solve(train_test_covar, self.train_train_covar_chol)
posterior = test_test_covar + test_train_covar @ covar_correction_rhs.mul(-1)
return posterior
def forward(self, X):
"""
Get mean and covariance of the GP at X.
"""
test_train_covar = self.covar_module(X, self.train_X).to_dense()
return self.predict_mean(X, test_train_covar), self.predict_covar(X, test_train_covar)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def predict_mean(self, X, test_train_covar=None):
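"""
Get the posterior mean of the GP at X:
m(X) + K(X, X_train) M^{-1} (y - m(X_train)),
using the cached mean_cache = M^{-1} (y - m(X_train)).
"""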
test_mean = self.mean_module(X)
if test_train_covar is None:
test_train_covar = self.covar_module(X, self.train_X).to_dense()
res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
res = res + test_mean
return res
def augmented_covariance(self, new_X):
"""
Add new_X to the training set, then compute posterior covariance of new_X.
Use a low-rank update to avoid recomputing the entire Cholesky decomposition.
"""
# find the q-batch dimension.
q_dim = 0 if len(new_X.shape) == 2 else 1
if len(new_X.shape) > 2:
# need batched train_X
train_X = self.train_X.unsqueeze(0).expand(new_X.shape[0], -1, -1)
else:
train_X = self.train_X
# TODO maybe use some indexing to avoid too many kernel calls.
test_test_covar = self.covar_module(new_X).to_dense()
test_train_aug_covar = self.covar_module(new_X, torch.cat([train_X, new_X], axis=q_dim)).to_dense()
train_aug_test_covar = test_train_aug_covar.transpose(-1, -2)
train_train_covar_chol_aug = self.update_chol(
self.train_train_covar_chol,
self.covar_module(train_X, new_X).to_dense(), # old-new covar
test_test_covar + self.noise_module(new_X).to_dense() # new-new covar with noise
)
covar_correction_rhs = torch.cholesky_solve(train_aug_test_covar, train_train_covar_chol_aug)
posterior = test_test_covar + test_train_aug_covar @ covar_correction_rhs.mul(-1)
return posterior
@staticmethod
def update_chol(L, B, C):
"""Update cholesky decomposition of M to M_aug.
Args:
L (np.ndarray): Cholesky decomposition of M (n, n) / (b, n, n)
B (np.ndarray): old-new covar (n, q) / (b, n, q)
C (np.ndarray): new-new covar (q, q) / (b, q, q)
NOTE: C needs to include the noise on the diagonal.
Returns:
L_aug: Cholesky decomposition of M_aug (n+q, n+q) / (b, n+q, n+q)
"""
if len(B.shape) > 2:
# ensure B and C are both batch mode.
assert B.shape[0] == C.shape[0]
X = torch.linalg.solve_triangular(L, B, upper=False).transpose(-1, -2)
# Schur complement S = C - X X^T
S = C - torch.matmul(X, X.transpose(-1, -2))
# Cholesky factor of the Schur complement
Y = torch.linalg.cholesky(S, upper=False)
# make L with a batch dim and repeat.
if len(B.shape) > 2:
L_broadcasted = L.unsqueeze(0)
L_broadcasted = L_broadcasted.expand(B.shape[0], -1, -1)
else:
L_broadcasted = L
# Combine as [[L, X], [0, Y]]
L_aug = torch.cat([
torch.cat([L_broadcasted, torch.zeros_like(B)], axis=-1),
torch.cat([X, Y], axis=-1)
], axis=-2)
return L_aug
def update_covar_one_point(
covar: torch.Tensor,
x_train: torch.Tensor, # shape (N,d)
x: torch.Tensor, # shape (q,d) - the batch points
x_augmented: torch.Tensor, # shape (N+Q,d)
new_x: torch.Tensor, # shape (1,d)
new_x_idx: int, # index of new_x in x_augmented
kernel: torch.nn.Module,
):
# e_a
kronecker_delta = torch.zeros(x_augmented.shape[0] - x_train.shape[0]) # shape (Q,)
kronecker_delta[new_x_idx - x_train.shape[0]] = 1
kronecker_delta_augmented = torch.cat([torch.zeros(x_train.shape[0]), kronecker_delta], axis=0) # shape (N+Q)
x_aug_replaced = x_augmented.clone()
x_aug_replaced[new_x_idx] = new_x
# compute delta vectors.
# TODO shape check - maybe need some transposes/dummy dimensions.
delta_k_aa = kernel(new_x) - kernel(x_augmented[new_x_idx])
delta_k = kernel(x_train, new_x) - kernel(x_train, x_augmented[new_x_idx]) - 0.5 * delta_k_aa * kronecker_delta
delta_k_A = kernel(x_aug_replaced, new_x) - kernel(x_augmented, x_augmented[new_x_idx]) - 0.5 * delta_k_aa * kronecker_delta_augmented
# placeholder correction terms, currently unused
delta_m_A = 0
delta_k_D = 0
# x being the batch points.
# TODO clean up signature. need x_train, x, and new_x+idx
delta_k_tilde = delta_k - kernel(x, x_train)
covar_updated = covar + delta_k_tilde @ kronecker_delta.T + kronecker_delta @ delta_k_tilde.T
return covar_updated
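# A minimal self-check for update_chol (a sketch, not used by the class): build a random
# PSD matrix, split it into old/new blocks, and compare the low-rank update against a
# full re-factorization.
if __name__ == "__main__":
    torch.manual_seed(0)
    n, q = 5, 2
    A = torch.randn(n + q, n + q)
    M_aug = A @ A.T + 1e-3 * torch.eye(n + q)  # well-conditioned PSD matrix
    M, B, C = M_aug[:n, :n], M_aug[:n, n:], M_aug[n:, n:]
    L = torch.linalg.cholesky(M)
    L_aug = GPPosteriorPredictor.update_chol(L, B, C)
    print(torch.allclose(L_aug, torch.linalg.cholesky(M_aug), atol=1e-5))  # expected: True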