import warnings
from copy import deepcopy
from enum import Enum
from typing import Optional, Tuple, Union
import numpy as np
import torch
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions import BotorchWarning
from botorch.models.model import Model
from botorch.utils.transforms import t_batch_mode_transform, concatenate_pending_points
from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from linear_operator.utils.cholesky import psd_safe_cholesky
from .utils.cholesky_inference import GPPosteriorPredictor
[docs]
class LogDetMethod(Enum):
    """Strategies for the covariance log-determinants in :class:`~BatchedEnergyEntropyBO`.

    Members:
        SVD: Sum the logs of the singular values of the covariance matrix.
        CHOLESKY: Use a Cholesky decomposition, relying on the covariance
            matrix being positive definite. Not always numerically stable.
        TORCH: Call ``torch.logdet`` directly. Not always numerically stable.
    """

    SVD = "svd"
    CHOLESKY = "cholesky"
    TORCH = "torch"
[docs]
class AugmentedPosteriorMethod(Enum):
    """Strategies for conditioning the GP on the candidate points in
    :class:`~BatchedEnergyEntropyBO` and obtaining the augmented posterior covariance.

    Members:
        NAIVE: A copy of the original model is kept and its training data is
            replaced via ``set_train_data`` on every evaluation of the
            acquisition function. Memory safe but slow.
        CHOLESKY: A low rank update of the Cholesky factor of the training
            covariance adds the new points. Fast, but bypasses the default
            GPyTorch inference in favor of Cholesky-based predictions; the
            augmented covariance is computed by
            :class:`~beebo.utils.cholesky_inference.GPPosteriorPredictor`.
        GET_FANTASY_MODEL: The model's ``get_fantasy_model`` method produces a
            new model that includes the new points. Not memory safe when
            running with gradients enabled.
    """

    NAIVE = "naive"
    CHOLESKY = "cholesky"
    # NOTE this isn't memory safe yet
    # TODO replace with a wrapper for GET_FANTASY_STRATEGY --> skip the internal model deepcopy
    GET_FANTASY_MODEL = "get_fantasy_model"
from linear_operator.operators import (DiagLinearOperator,
LowRankRootLinearOperator)
from linear_operator.utils.cholesky import psd_safe_cholesky
[docs]
def stable_softmax(x: torch.Tensor, beta: float, f_max: float = None, eps=1e-6, alpha=0.05):
    """Softmax of ``beta * x`` over the last dimension, shifted by the row maximum.

    Args:
        x: Tensor whose last dimension holds the values to be weighted.
        beta: Inverse temperature multiplied onto ``x`` before exponentiation.
        f_max: Optional reference value. When given, an extra term
            ``min(D * (1 - alpha) / alpha, exp(beta * f_max - m))`` is added to
            the normalizer ``D`` (``m`` is the row maximum of ``beta * x``), so
            each row of weights sums to less than one but to at least ``alpha``.
        eps: Not used by the computation.
        alpha: Lower bound on the total weight per row when ``f_max`` is given.

    Returns:
        Tensor of weights with the same shape as ``x``.
    """
    logits = beta * x
    row_max = logits.max(dim=-1, keepdim=True).values
    # Subtracting the row maximum keeps every exponent <= 0.
    exp_shifted = (logits - row_max).exp()
    partition = exp_shifted.sum(dim=-1, keepdim=True)

    if f_max is None:
        return exp_shifted / partition

    # Extra normalizer mass contributed by f_max, capped so that the weights
    # always retain a total of at least alpha.
    candidates = torch.stack(
        [partition * (1-alpha)/alpha, (beta * f_max - row_max).exp()],
        dim=-1,
    )
    extra_mass = candidates.min(dim=-1).values
    return exp_shifted / (partition + extra_mass)
[docs]
def softmax_expectation_a_is_mean(mvn, softmax_beta, f_max=None):
    """Softmax-weighted expectation under ``mvn``, expanded around the MVN mean.

    Uses the simplified expressions that arise when the expansion point of the
    softmax weights is the mean of the distribution.

    Args:
        mvn: Distribution exposing ``mean`` (n x q), ``covariance_matrix``
            (n x q x q) and ``lazy_covariance_matrix`` (n x q x q).
        softmax_beta: Inverse temperature of the softmax.
        f_max: Optional reference value forwarded to ``stable_softmax``.

    Returns:
        Tensor with one expectation value per batch element (shape ``n``).
    """
    means = mvn.mean  # (n x q)
    covar = mvn.covariance_matrix  # (n x q x q)
    lazy_covar = mvn.lazy_covariance_matrix  # (n x q x q)

    w = stable_softmax(means, softmax_beta, f_max)  # (n x q)
    w_col = w.unsqueeze(-1)  # (n x q x 1)

    W = DiagLinearOperator(w) - LowRankRootLinearOperator(w_col)  # (n x q x q)

    # Build the identity in the posterior's dtype/device so that float64 models
    # do not mix precisions inside the linear operator sum.
    identity = DiagLinearOperator(
        torch.ones(covar.shape[-1], dtype=means.dtype, device=means.device)
    )
    U_inv = identity + softmax_beta**2 * lazy_covar @ W

    # avoid doing a solve for C_update.
    # this is C_update - C_update @ w.unsqueeze(-1)
    col_difference = U_inv.solve(covar - covar @ w_col)

    nu_i_matrix = softmax_beta * col_difference + means.unsqueeze(-1)  # (n x q x q)

    # (w^T @ col_difference) is (n x 1 x q); drop only that singleton row axis.
    # A bare .squeeze() would also collapse n == 1 or q == 1 and break broadcasting.
    c_i_vector = 0.5 * softmax_beta**2 * (
        torch.diagonal(col_difference, dim1=-2, dim2=-1)
        - (w_col.mT @ col_difference).squeeze(-2)
    )  # (n x q)

    K = (1 / U_inv.to_dense().det()).sqrt()
    expectation = K * (
        (w.log() + c_i_vector).exp() * torch.diagonal(nu_i_matrix, dim1=-2, dim2=-1)
    ).sum(dim=-1)
    return expectation
[docs]
def softmax_expectation(mvn: MultivariateNormal, a: torch.Tensor, softmax_beta: float, f_max: float = None):
    """Softmax-weighted expectation under ``mvn``, validated to be finite.

    Args:
        mvn: The multivariate normal to take the expectation under.
        a: Expansion point. Currently unused: the computation always uses the
            simplified expressions that arise when the expansion point is the
            mean of the MVN.
        softmax_beta: Inverse temperature of the softmax.
        f_max: Optional reference value forwarded to the softmax weights.

    Returns:
        The expectation computed by ``softmax_expectation_a_is_mean``.

    Raises:
        RuntimeError: If the result contains NaN or infinite values.
    """
    # NOTE we are using the simplified expressions that arise when the expansion point is the mean of the MVN.
    expectation = softmax_expectation_a_is_mean(mvn, softmax_beta, f_max)

    # Fail loudly instead of dropping into a debugger: a breakpoint here would
    # hang non-interactive runs and requires ipdb to be installed.
    if torch.isnan(expectation).any():
        raise RuntimeError('nan in expectation')
    if torch.isinf(expectation).any():
        raise RuntimeError('inf in expectation')
    return expectation
[docs]
class EnergyFunction(Enum):
    """Energy functions available to :class:`~BatchedEnergyEntropyBO`.

    Members:
        SOFTMAX: Implements maxBEEBO. The acquisition function focuses more on
            the point with the highest expected improvement in the batch.
            With this choice, :class:`~BatchedEnergyEntropyBO` accepts the
            additional arguments ``softmax_beta`` (a scalar, the inverse
            temperature of the softmax function) and ``f_max`` (a scalar, the
            maximum value of the function to be optimized).
        SUM: Implements meanBEEBO. The acquisition function focuses on
            improving the overall batch.
    """

    SOFTMAX = "softmax"
    SUM = "sum"
[docs]
class BatchedEnergyEntropyBO(AnalyticAcquisitionFunction):
    r"""
    The BEEBO batch acquisition function. Jointly optimizes a batch of points by minimizing
    the free energy of the batch.

    Args:
        model (GPyTorch Model): A fitted single-outcome GP model. Must be in batch mode if
            candidate sets X will be.
        temperature (float): A scalar representing the temperature.
            Higher temperature leads to more exploration.
        kernel_amplitude (float, optional): The amplitude of the kernel. Defaults to 1.0.
            The temperature is multiplied by its square root, which brings the temperature
            to a scale that is comparable to UCB's hyperparameter `beta`.
        posterior_transform (PosteriorTransform, optional): A PosteriorTransform. If using a multi-output model,
            a PosteriorTransform that transforms the multi-output posterior into a
            single-output posterior is required.
        X_pending (torch.Tensor, optional): `n x d` tensor of design points that have been
            submitted for evaluation but have not yet been evaluated. Defaults to None.
        maximize (bool, optional): If True, consider the problem a maximization problem. Defaults to True.
        logdet_method (str, optional): The method to use to compute the log determinant of the
            covariance matrix. Should be one of the members of the :class:`~LogDetMethod` enum:
            ``"svd"``, ``"cholesky"``, or ``"torch"``. Defaults to ``"svd"``.
        augment_method (str, optional): The method to use to augment the model with the new points
            and computing the posterior covariance. Should be one of the members of the
            :class:`~AugmentedPosteriorMethod` enum: ``"naive"``, ``"cholesky"`` or
            ``"get_fantasy_model"``. Defaults to ``"naive"``.
        energy_function (str, optional): The energy function to use in the BEEBO acquisition function.
            Should be a string representing one of the members of the :class:`~EnergyFunction` enum: ``"softmax"`` or ``"sum"``.
            "softmax" implements the maxBEEBO and "sum" implements the meanBEEBO. Defaults to ``"sum"``.
        **kwargs: Additional arguments to be passed to the energy function. With the softmax
            energy, ``softmax_beta`` (default 1.0) and ``f_max`` (default None) are read.
    """

    # Diagonal jitter added to BOTH covariance matrices before ``torch.logdet``.
    _TORCH_LOGDET_JITTER = 1e-06

    def __init__(
        self,
        model: Model,
        temperature: float,
        kernel_amplitude: float = 1.0,
        posterior_transform: Optional[PosteriorTransform] = None,
        X_pending: Optional[torch.Tensor] = None,
        maximize: bool = True,
        logdet_method: Union[str, LogDetMethod] = "svd",
        augment_method: Union[str, AugmentedPosteriorMethod] = "naive",
        energy_function: Union[str, EnergyFunction] = "sum",
        **kwargs,
    ) -> None:
        super().__init__(model=model, posterior_transform=posterior_transform)

        # Strings are resolved by enum member NAME (case-insensitive).
        self.logdet_method = LogDetMethod[logdet_method.upper()] if isinstance(logdet_method, str) else logdet_method
        self.augment_method = AugmentedPosteriorMethod[augment_method.upper()] if isinstance(augment_method, str) else augment_method
        self.energy_function = EnergyFunction[energy_function.upper()] if isinstance(energy_function, str) else energy_function

        if self.energy_function == EnergyFunction.SOFTMAX:
            self.softmax_beta = kwargs.get("softmax_beta", 1.0)
            self.f_max = kwargs.get("f_max", None)

        self.kernel_amplitude = kernel_amplitude
        self.temperature = temperature * (self.kernel_amplitude)**(1/2)
        self.maximize = maximize
        self.set_X_pending(X_pending)

        if self.augment_method == AugmentedPosteriorMethod.CHOLESKY:
            self.predictor = GPPosteriorPredictor(
                model.covar_module,
                model.mean_module,
                model.likelihood.noise_covar,
                train_X=model.train_inputs[0],
                train_y=model.train_targets,
            )
        elif self.augment_method == AugmentedPosteriorMethod.NAIVE:
            # for augmentation, we keep a copy of the original model
            # if we make a copy in the forward pass only, we get a memory leak
            self.augmented_model = deepcopy(model)

    @staticmethod
    def _svd_logdet(cov: torch.Tensor) -> torch.Tensor:
        """Log-determinant of ``cov`` as the sum of the logs of its singular values."""
        s = torch.linalg.svdvals(cov)
        # Exact zeros would give log(0) = -inf; substitute a tiny value instead.
        # torch.where keeps this autograd friendly (no in-place masking).
        s = torch.where(s == 0, torch.ones_like(s) * 1e-20, s)
        return torch.sum(torch.log(s), dim=-1)

    @staticmethod
    def _raise_if_not_finite(logdet: torch.Tensor) -> None:
        """Raise so that the caller's fallback path is taken for NaN/inf results."""
        if torch.isnan(logdet).any():
            raise RuntimeError('nan in logdet')
        elif torch.isinf(logdet).any():
            raise RuntimeError('inf in logdet')

    @concatenate_pending_points
    @t_batch_mode_transform()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Evaluate the free energy of the candidate set X.

        Args:
            X: A `(b1 x ... bk) x q x d`-dim batched tensor of `d`-dim design points.

        Returns:
            A `(b1 x ... bk)`-dim tensor of BEEBO acquisition values at the
            given design points `X`.
        """
        # required for gradients in augmented GPs.
        with settings.detach_test_caches(False):
            self.model.eval()

            # Entropy term.
            f_preds = self.model(X)  # this gets p(f* | x*, X, y) with x* being test points.
            posterior_cov = f_preds.covariance_matrix  # == C'_D
            posterior_means = f_preds.mean

            # Dummy targets for the augmentation points. Index feature 0 (not 1)
            # so that 1-dimensional inputs work as well.
            dummy_targets = torch.zeros_like(X[..., 0])

            if self.augment_method == AugmentedPosteriorMethod.NAIVE:
                # augment the training data with the test points
                X_train_original = self.model.train_inputs[0]
                Y_train_original = self.model.train_targets
                X_train = X_train_original.expand(X.shape[0], X_train_original.shape[0], X_train_original.shape[1])  # (n_batch, n_train, dim)
                Y_train = Y_train_original.expand(X.shape[0], Y_train_original.shape[0])  # (n_batch, n_train)
                X_train_augmented = torch.cat([X_train, X], dim=1)  # (n_batch, n_train + n_aug, dim)
                Y_train_augmented = torch.cat([Y_train, dummy_targets], dim=1)  # (n_batch, n_train + n_aug)
                self.augmented_model.set_train_data(X_train_augmented, Y_train_augmented, strict=False)
                f_preds_augmented = self.augmented_model(X)
                posterior_means_augmented = f_preds_augmented.mean
                posterior_cov_augmented = f_preds_augmented.covariance_matrix
            elif self.augment_method == AugmentedPosteriorMethod.CHOLESKY:
                posterior_cov_augmented = self.predictor.augmented_covariance(X)
                posterior_means_augmented = torch.zeros_like(posterior_means)  # not used
            elif self.augment_method == AugmentedPosteriorMethod.GET_FANTASY_MODEL:
                fantasy_model = self.model.get_fantasy_model(X, dummy_targets)
                f_preds_augmented = fantasy_model(X)
                posterior_means_augmented = f_preds_augmented.mean
                posterior_cov_augmented = f_preds_augmented.covariance_matrix

            if self.logdet_method == LogDetMethod.CHOLESKY:
                # use cholesky decomposition to compute logdet, fallback to svd if fails
                with settings.cholesky_max_tries(1):
                    try:
                        posterior_cov_logdet = f_preds.lazy_covariance_matrix.logdet()
                        # also trigger exception when any nan/inf in result
                        self._raise_if_not_finite(posterior_cov_logdet)
                    except Exception as e:  # deliberate best-effort: any failure falls back to SVD
                        warnings.warn(f'Cholesky failed: {e}', BotorchWarning)
                        posterior_cov_logdet = self._svd_logdet(posterior_cov)

                    try:
                        if self.augment_method == AugmentedPosteriorMethod.CHOLESKY:
                            # there is no lazy_covariance_matrix if we use cholesky augmentation
                            chol = psd_safe_cholesky(posterior_cov_augmented)
                            posterior_cov_augmented_logdet = chol.diagonal(dim1=-2, dim2=-1).log().sum(-1) * 2
                        else:
                            posterior_cov_augmented_logdet = f_preds_augmented.lazy_covariance_matrix.logdet()
                        # also trigger exception when any nan/inf in result
                        self._raise_if_not_finite(posterior_cov_augmented_logdet)
                    except Exception as e:  # deliberate best-effort: any failure falls back to SVD
                        warnings.warn(f'Cholesky failed: {e}', BotorchWarning)
                        posterior_cov_augmented_logdet = self._svd_logdet(posterior_cov_augmented)

            elif self.logdet_method == LogDetMethod.SVD:
                # use svd to compute logdet
                posterior_cov_augmented_logdet = self._svd_logdet(posterior_cov_augmented)
                posterior_cov_logdet = self._svd_logdet(posterior_cov)

            elif self.logdet_method == LogDetMethod.TORCH:
                # The same jitter must be applied to both matrices; otherwise the
                # difference of the log-determinants is dominated by the jitter.
                jitter = self._TORCH_LOGDET_JITTER
                posterior_cov = posterior_cov + torch.eye(
                    posterior_cov.shape[-1], dtype=posterior_cov.dtype, device=posterior_cov.device
                ) * jitter
                posterior_cov_logdet = torch.logdet(posterior_cov)
                posterior_cov_augmented = posterior_cov_augmented + torch.eye(
                    posterior_cov_augmented.shape[-1],
                    dtype=posterior_cov_augmented.dtype,
                    device=posterior_cov_augmented.device,
                ) * jitter
                posterior_cov_augmented_logdet = torch.logdet(posterior_cov_augmented)
            else:
                raise NotImplementedError(f'logdet method {self.logdet_method} not implemented')

            # Diagnostics only: non-finite values are reported but not altered.
            if torch.isinf(posterior_cov_augmented_logdet).any():
                warnings.warn('augmented cov logdet is inf', BotorchWarning)
            if torch.isinf(posterior_cov_logdet).any():
                warnings.warn('cov logdet is inf', BotorchWarning)
            if torch.isnan(posterior_cov_augmented_logdet).any():
                warnings.warn('augmented cov logdet is nan', BotorchWarning)
            if torch.isnan(posterior_cov_logdet).any():
                warnings.warn('cov logdet is nan', BotorchWarning)

            information_gain = 0.5 * (posterior_cov_logdet - posterior_cov_augmented_logdet)

            if self.energy_function == EnergyFunction.SUM:
                summary_posterior = torch.sum(posterior_means, dim=1)
            elif self.energy_function == EnergyFunction.SOFTMAX:
                summary_posterior = softmax_expectation(f_preds, a=f_preds.mean, softmax_beta=self.softmax_beta, f_max=self.f_max)
                # multiply by q to make it scale linearly with q, like logdet
                summary_posterior = summary_posterior * f_preds.mean.shape[1]

            # this is a dummy thing for memory leaks
            summary_augmented = torch.sum(posterior_means_augmented, dim=1)

            if self.maximize:
                # maximize fn value + gain
                acq_value = summary_posterior + self.temperature * information_gain
            else:
                acq_value = (-1) * summary_posterior + self.temperature * information_gain

            # Adds zero, but keeps the augmented means in the graph: this prevents memory leaks.
            acq_value = acq_value + summary_augmented * 0

            return acq_value

    # NOTE the base AnalyticAcquisitionFunction class does not support X_pending
    def set_X_pending(self, X_pending: Optional[torch.Tensor] = None) -> None:
        r"""Informs the acquisition function about pending design points.

        Args:
            X_pending: `n x d` Tensor with `n` `d`-dim design points that have
                been submitted for evaluation but have not yet been evaluated.
        """
        if X_pending is not None:
            # when doing sequential, stuff will have gradients. no point in
            # warning about it, so the points are detached silently.
            self.X_pending = X_pending.detach().clone()
        else:
            self.X_pending = X_pending

    # NOTE the methods below are only there for better readability.
    # Due to memory leaks and for numerical reasons, they are not used in the actual code.
    # We keep them here for future reference, as a minimal example of how to compute the
    # two terms of the BEE-BOSS acquisition function.
    @t_batch_mode_transform()
    def compute_energy(self, X: torch.Tensor) -> torch.Tensor:
        """Evaluate the energy of the candidate set X.

        The energy is the sum of the posterior means over the `q` points,
        regardless of the configured ``energy_function``. It is negated when
        ``maximize`` is False.

        Args:
            X: A `(b1 x ... bk) x q x d`-dim batched tensor of `d`-dim design points.

        Returns:
            A `(b1 x ... bk)`-dim tensor of energy values at the
            given design points `X`.
        """
        with settings.detach_test_caches(False):
            self.model.eval()
            # Enthalpy term.
            f_preds = self.model(X)  # this gets p(f* | x*, X, y) with x* being test points.
            posterior_means = f_preds.mean
            # ``self.summary_fn`` is never assigned; sum directly instead.
            summary = torch.sum(posterior_means, dim=1)
            if not self.maximize:
                # minimize fn value + maximize gain
                summary = (-1) * summary
            return summary

    @t_batch_mode_transform()
    def compute_entropy(self, X: torch.Tensor) -> torch.Tensor:
        """Evaluate the entropy (information gain) term of the candidate set X.

        Args:
            X: A `(b1 x ... bk) x q x d`-dim batched tensor of `d`-dim design points.

        Returns:
            A `(b1 x ... bk)`-dim tensor of information gain values at the
            given design points `X`.
        """
        with settings.detach_test_caches(False):
            self.model.eval()
            # Entropy term.
            f_preds = self.model(X)  # this gets p(f* | x*, X, y) with x* being test points.
            posterior_cov = f_preds.covariance_matrix  # == C'_D

            ## augment observations with x' and dummy y' (because gpytorch requires them)
            # Index feature 0 (not 1) so that 1-dimensional inputs work as well.
            model_augmented = self.model.get_fantasy_model(X, torch.zeros_like(X[..., 0]))
            f_preds_augmented = model_augmented(X)
            posterior_cov_augmented = f_preds_augmented.covariance_matrix  # == C'_D_D'
            posterior_cov_augmented = posterior_cov_augmented + torch.eye(posterior_cov_augmented.shape[-1], device=posterior_cov_augmented.device) * 1e-04

            # NOTE(review): unlike ``forward``, this difference is not scaled by 0.5
            # and only the augmented covariance receives jitter - confirm intent.
            information_gain = torch.logdet(posterior_cov) - torch.logdet(posterior_cov_augmented)
            return information_gain