from functools import cached_property
from typing import NamedTuple, List, Type
from jax import numpy as jnp, tree_map
from jax._src.scipy.linalg import solve_triangular
from jaxns import PriorModelGen, Prior, Categorical
from jaxns.internals.types import float_type
from tensorflow_probability.substrates import jax as tfp
from bojaxns.base import _assert_rank, _assert_same_leading_dim, ConditionalPredictive, ConditionalPredictiveFactory, \
AbstractAcquisition, AcquisitionFactory
[docs]
# Short alias for ``tfp.distributions``, where ``tfp`` is TensorFlow Probability's JAX substrate.
tfpd = tfp.distributions
[docs]
def log_normal(x, mean, cov):
    """
    Log-density of the multivariate Normal N(mean, cov) evaluated at x.

    The covariance is factorised as cov = L L^T with a Cholesky decomposition, so cov is expected to be
    positive definite.

    Args:
        x: point at which to evaluate the density.
        mean: mean of the Gaussian.
        cov: covariance matrix of the Gaussian.

    Returns:
        log N(x | mean, cov).
    """
    chol = jnp.linalg.cholesky(cov)
    # 0.5 * log|cov| == sum(log(diag(chol)))
    half_log_det = jnp.sum(jnp.log(jnp.diag(chol)))
    # Whitened residual chol^-1 (x - mean); its squared norm is the Mahalanobis distance.
    whitened = solve_triangular(chol, x - mean, lower=True)
    mahalanobis = whitened @ whitened
    normaliser = -0.5 * x.size * jnp.log(2. * jnp.pi)
    return normaliser - half_log_det - 0.5 * mahalanobis
[docs]
def log_normal_with_mask(x, mean, cov, sigma):
    """
    Log-density of the multivariate Normal N(mean, cov + diag(sigma**2)) at x, where sigma may contain +inf.

    A component with ``sigma == +inf`` is masked out. It contributes nothing to the normalisation constant, the
    log-determinant or the Mahalanobis term. The result therefore equals the log-density of the remaining
    (finite-sigma) components. No infinite intermediate enters the log-determinant, so gradients with respect
    to ``cov`` stay finite when some components are masked.

    Args:
        x: RV value.
        mean: mean of the Gaussian.
        cov: covariance of the underlying process, excluding the observation-error covariance.
        sigma: standard deviations of the observation error; +inf encodes an outlier / masked datum.

    Returns:
        The log-density over all components whose observation-error standard deviation is finite.
    """
    # With D = diag(sigma): cov + D^2 = D C D, where C = D^-1 cov D^-1 + I. For a masked component the matching
    # row/column of C becomes that of the identity, so the Cholesky factor stays finite.
    C = cov / (sigma[:, None] * sigma[None, :]) + jnp.eye(cov.shape[0])
    L = jnp.linalg.cholesky(C)
    masked = jnp.isinf(sigma)
    # 0.5 * log|cov + D^2| = sum(log(sigma)) + sum(log(diag(L))). Masked rows have L_ii == 1, so log(L_ii) == 0
    # there. Only log(sigma) needs masking. Masking log(sigma * L_ii) instead would give inf * 0 = NaN in the
    # backward pass with respect to L (and hence cov).
    log_det = jnp.sum(jnp.where(masked, 0., jnp.log(sigma))) + jnp.sum(jnp.log(jnp.diag(L)))
    # Whitened residual L^-1 D^-1 (x - mean); masked components are divided by +inf and vanish.
    dx = solve_triangular(L, (x - mean) / sigma, lower=True)
    maha = dx @ dx
    num_observed = jnp.sum(~masked)
    log_likelihood = -0.5 * num_observed * jnp.log(2. * jnp.pi) \
                     - log_det \
                     - 0.5 * maha
    return log_likelihood
[docs]
class GaussianProcessData(NamedTuple):
    """
    Observed data for the Gaussian-process model.

    ``_ensure_gaussian_process_data`` converts every field to ``float_type``. It then checks that ``U`` has
    rank 2, that the other fields have rank 1, and that all fields share the same leading dimension. Prefer
    keyword construction (``GaussianProcessData(U=..., Y=..., Y_var=..., sample_size=...)``).
    """
    U: jnp.ndarray  # observed inputs, rank 2; the trailing dimension is the number of input dimensions
    Y: jnp.ndarray  # observed values, rank 1
    Y_var: jnp.ndarray  # per-datum variance of Y, rank 1; NaN entries fall back to the measurement-variance hyper-parameter
    sample_size: jnp.ndarray  # rank 1; presumably samples behind each Y, enters the noise as variance / sqrt(sample_size)
[docs]
class NotEnoughData(Exception):
    """Raised when the data has too few observations to form a mean and variance (fewer than two ``Y`` values)."""
def _ensure_gaussian_process_data(data: GaussianProcessData) -> GaussianProcessData:
    """
    Convert every field of ``data`` to ``float_type`` arrays and validate their shapes.

    Checks that ``U`` has rank 2, that ``sample_size``, ``Y`` and ``Y_var`` have rank 1, and that all fields share
    the same leading dimension.

    Raises:
        NotEnoughData: if ``Y`` has fewer than two elements.
    """
    converted = tree_map(lambda leaf: jnp.asarray(leaf, float_type), data)
    _assert_rank(2, U=converted.U)
    _assert_rank(1, sample_size=converted.sample_size, Y=converted.Y, Y_var=converted.Y_var)
    _assert_same_leading_dim(*converted)
    if converted.Y.size >= 2:
        return converted
    raise NotEnoughData('Need more samples to form mean and variance of data.')
[docs]
class GaussianProcessConditionalPredictive(ConditionalPredictive):
    """
    Gaussian-process predictive conditioned on observed data, for one fixed set of hyper-parameters.

    Each observation gets its own independent noise variance (see ``_data_variance``). In the masked
    formulations, an observation whose noise standard deviation is ``+inf`` is effectively removed from both the
    marginal likelihood and the posterior.
    """

    def __init__(self,
                 data: GaussianProcessData,
                 kernel: tfp.math.psd_kernels.PositiveSemidefiniteKernel,
                 variance: jnp.ndarray,
                 mean: jnp.ndarray):
        """
        Args:
            data: observed data; converted and validated with ``_ensure_gaussian_process_data``.
            kernel: covariance kernel of the GP prior.
            variance: measurement-variance hyper-parameter.
            mean: constant prior mean of the GP.

        Raises:
            NotEnoughData: if ``data.Y`` has fewer than two elements.
        """
        self._data = _ensure_gaussian_process_data(data)
        self._kernel = kernel
        self._variance = variance
        self._mean = mean

    def _ndims(self):
        """Number of input dimensions (trailing dimension of ``data.U``)."""
        return self._data.U.shape[-1]

    def _data_variance(self) -> jnp.ndarray:
        """
        Per-observation noise variance.

        Observations whose ``Y_var`` is NaN use the measurement-variance hyper-parameter in its place. Every
        observation additionally gets a ``variance / sqrt(sample_size)`` term.
        """
        sample_size_term = self._variance / jnp.sqrt(self._data.sample_size)
        return jnp.where(jnp.isnan(self._data.Y_var),
                         self._variance + sample_size_term,
                         self._data.Y_var + sample_size_term)

    def _noise_std_dev(self) -> jnp.ndarray:
        """Per-observation noise standard deviation, with the variance floored at 1e-6 before the square root."""
        return jnp.sqrt(jnp.maximum(1e-6, self._data_variance()))

    def posterior(self):
        """Posterior mean and marginal variance of the GP at the observed inputs ``data.U``."""
        return self.__call__(self._data.U)

    def _marginal_likelihood_with_mask(self):
        """Log marginal likelihood of ``data.Y``; tolerates ``+inf`` noise standard deviations."""
        Kxx = self._kernel.matrix(self._data.U, self._data.U)
        return log_normal_with_mask(x=self._data.Y, mean=self._mean, cov=Kxx, sigma=self._noise_std_dev())

    def _marginal_likelihood(self):
        """Log marginal likelihood of ``data.Y`` using the dense covariance ``Kxx + diag(noise variance)``
        (no masking, no variance floor)."""
        Kxx = self._kernel.matrix(self._data.U, self._data.U)
        data_cov = jnp.diag(self._data_variance())
        return log_normal(self._data.Y, self._mean, Kxx + data_cov)

    def marginal_likelihood(self):
        """Log marginal likelihood of the observed data (masked, whitened formulation)."""
        return self._marginal_likelihood_with_mask()

    def _posterior_with_mask(self, U_star: jnp.ndarray, cov: bool = False):
        """
        GP posterior at ``U_star`` using a whitened data covariance that tolerates ``+inf`` noise std-devs.

        Returns:
            ``(mean, covariance)`` if ``cov`` is True, else ``(mean, marginal variance)``.
        """
        Kxx = self._kernel.matrix(self._data.U, self._data.U)
        Kxs = self._kernel.matrix(self._data.U, U_star)
        Kss = self._kernel.matrix(U_star, U_star)
        std_dev = self._noise_std_dev()
        # With D = diag(std_dev): Kxx + D^2 = D C D, where C = D^-1 Kxx D^-1 + I. Rows of masked (inf std_dev)
        # data become rows of the identity, so the Cholesky factor stays finite.
        L = jnp.linalg.cholesky(Kxx / (std_dev[:, None] * std_dev[None, :]) + jnp.eye(std_dev.size))
        # J = L^-1 D^-1 Kxs, hence J^T J = Kxs^T (Kxx + D^2)^-1 Kxs.
        J = solve_triangular(L, Kxs / std_dev[:, None], lower=True)
        post_cov_s = Kss - J.T @ J
        dY = self._data.Y - self._mean
        # dX = L^-1 D^-1 (Y - mean), hence J^T dX = Kxs^T (Kxx + D^2)^-1 (Y - mean).
        dX = solve_triangular(L, dY / std_dev, lower=True)
        post_mu_s = self._mean + J.T @ dX
        if cov:
            return post_mu_s, post_cov_s
        return post_mu_s, jnp.diag(post_cov_s)

    def _posterior(self, U_star: jnp.ndarray, cov: bool = False):
        """
        GP posterior at ``U_star`` using the dense covariance ``Kxx + diag(noise variance)`` (no masking).

        Returns:
            ``(mean, covariance)`` if ``cov`` is True, else ``(mean, marginal variance)``.
        """
        Kxx = self._kernel.matrix(self._data.U, self._data.U)
        data_cov = jnp.diag(self._data_variance())
        Kxs = self._kernel.matrix(self._data.U, U_star)
        Kss = self._kernel.matrix(U_star, U_star)
        L = jnp.linalg.cholesky(Kxx + data_cov)
        # inv(L L^T) = L^-T L^-1
        J = solve_triangular(L, Kxs, lower=True)
        post_cov_s = Kss - J.T @ J
        H = solve_triangular(L.T, J, lower=False)
        dY = self._data.Y - self._mean
        post_mu_s = self._mean + H.T @ dY
        if cov:
            return post_mu_s, post_cov_s
        return post_mu_s, jnp.diag(post_cov_s)

    def __call__(self, U_star: jnp.ndarray, cov: bool = False):
        """
        Posterior of the GP at ``U_star``.

        Returns:
            ``(mean, covariance)`` if ``cov`` is True, else ``(mean, marginal variance)``.
        """
        return self._posterior_with_mask(U_star=U_star, cov=cov)
[docs]
class GaussianProcessConditionalPredictiveFactory(ConditionalPredictiveFactory):
    """
    Builds ``GaussianProcessConditionalPredictive`` objects from hyper-parameter samples. Also provides the
    jaxns prior model over those hyper-parameters.
    """

    def __init__(self, data: GaussianProcessData):
        self._data = _ensure_gaussian_process_data(data)

    def ndims(self):
        """Number of input dimensions (trailing dimension of ``data.U``)."""
        return self._data.U.shape[-1]

    def build_prior_model(self):
        """
        Return a jaxns prior-model generator over the GP hyper-parameters.

        The generator yields, in order: ``amplitude``, ``length_scale``, ``variance`` (measurement variance),
        ``mean`` and ``kernel_select`` (an index into ``psd_kernels``). It returns them as a tuple in that same
        order. Prior scales are derived from the spread of the observed data.
        """
        y = self._data.Y
        u = self._data.U
        y_std = jnp.std(y)
        y_mean = jnp.mean(y)
        u_extent = jnp.max(u, axis=0) - jnp.min(u, axis=0)
        amplitude_high = 2 * y_std
        num_kernels = len(self.psd_kernels)

        def prior_model() -> PriorModelGen:
            amplitude = yield Prior(tfpd.Uniform(high=amplitude_high), name='amplitude')
            length_scale = yield Prior(tfpd.Uniform(high=u_extent), name='length_scale')
            # Measurement variance.
            variance = yield Prior(tfpd.Uniform(high=y_std), name='variance')
            mean = yield Prior(tfpd.Normal(loc=y_mean, scale=y_std), name='mean')
            kernel_select = yield Categorical(parametrisation='gumbel_max', logits=jnp.zeros(num_kernels),
                                              name='kernel_select')
            return amplitude, length_scale, variance, mean, kernel_select

        return prior_model

    @cached_property
    def psd_kernels(self) -> List[Type[tfp.math.psd_kernels.PositiveSemidefiniteKernel]]:
        """Candidate kernel classes; the sampled ``kernel_select`` indexes into this list."""
        candidates = (
            tfp.math.psd_kernels.MaternThreeHalves,
            tfp.math.psd_kernels.ExponentiatedQuadratic,
            tfp.math.psd_kernels.MaternOneHalf,
        )
        return list(candidates)

    def __call__(self, **samples) -> GaussianProcessConditionalPredictive:
        """
        Build the conditional predictive for one hyper-parameter sample.

        Every candidate kernel is constructed, with inputs divided by ``length_scale``, and weighted by a one-hot
        constant. Only the kernel chosen by ``kernel_select`` contributes to the summed kernel.
        """
        amplitude = samples.get('amplitude')
        length_scale = samples.get('length_scale')
        selected = samples.get('kernel_select')
        kernel_classes = self.psd_kernels
        one_hot = jnp.where(jnp.arange(len(kernel_classes)) == selected, 1., 0.)

        def scale_inputs(x, _feature_ndims, _example_ndims):
            return x / length_scale

        def weighted_kernel(weight, kernel_cls):
            base = kernel_cls(amplitude=amplitude, length_scale=None)
            scaled = tfp.math.psd_kernels.FeatureTransformed(base, transformation_fn=scale_inputs)
            return tfp.math.psd_kernels.Constant(weight) * scaled

        terms = [weighted_kernel(one_hot[idx], kernel_cls) for idx, kernel_cls in enumerate(kernel_classes)]
        combined = terms[0]
        for term in terms[1:]:
            combined = combined + term

        return GaussianProcessConditionalPredictive(
            data=self._data,
            kernel=combined,
            variance=samples.get('variance'),
            mean=samples.get('mean')
        )
[docs]
class ExpectedImprovementAcquisition(AbstractAcquisition):
    """
    Expected-improvement acquisition for a Gaussian-process conditional predictive.

    The incumbent is the largest posterior mean at the observed inputs. The acquisition at ``u_star`` is the
    expected amount by which the latent function there exceeds that incumbent.
    """

    def __init__(self, conditional_predictive: GaussianProcessConditionalPredictive):
        self._conditional_predictive = conditional_predictive

    @staticmethod
    def _expected_improvement(post_mu_x_max: jnp.ndarray, post_mu_s: jnp.ndarray,
                              post_var_s: jnp.ndarray) -> jnp.ndarray:
        """
        Closed-form E[max(0, f - post_mu_x_max)] for f ~ N(post_mu_s, post_var_s).

        The variance is floored at 1e-6 before taking its square root.
        """
        sigma = jnp.sqrt(jnp.maximum(1e-6, post_var_s))
        std_normal = tfpd.Normal(loc=0., scale=1.)
        z = (post_mu_s - post_mu_x_max) / sigma
        density = std_normal.prob(z)
        cumulative = std_normal.cdf(z)
        return sigma * (density + z * cumulative)

    def __call__(self, u_star: jnp.ndarray):
        """Expected improvement at a single input ``u_star``, returned as a scalar."""
        mu_observed, _ = self._conditional_predictive.posterior()
        incumbent = jnp.max(mu_observed)
        mu_star, var_star = self._conditional_predictive(u_star[None, :])
        improvement = ExpectedImprovementAcquisition._expected_improvement(
            post_mu_x_max=incumbent,
            post_mu_s=mu_star,
            post_var_s=var_star
        )
        return jnp.reshape(improvement, ())
[docs]
class ScaledExpectedImprovementAcquisition(AbstractAcquisition):
    """
    Scaled expected-improvement acquisition: E[I] / sqrt(Var[I]).

    Here I = max(0, f(u_star) - incumbent), and the incumbent is the largest posterior mean at the observed inputs.
    """

    def __init__(self, condition_predictive: GaussianProcessConditionalPredictive):
        self._condition_predictive = condition_predictive

    @staticmethod
    def _expected_squared_improvement(post_mu_x_max: jnp.ndarray, post_mu_s: jnp.ndarray, post_var_s: jnp.ndarray):
        """
        Closed-form E[max(0, f - post_mu_x_max)**2] for f ~ N(post_mu_s, post_var_s).

        The variance is floored at 1e-6, the same floor used by
        ``ExpectedImprovementAcquisition._expected_improvement``.
        """
        # Use the floored variance both for standardising the gap and for the overall scale. E[I^2] then
        # describes the same distribution as E[I]. Scaling by the raw variance made E[I^2] - E[I]^2 wrong, even
        # negative, whenever the posterior variance fell below the floor or went slightly negative from round-off.
        post_var_s = jnp.maximum(1e-6, post_var_s)
        post_stddev_s = jnp.sqrt(post_var_s)
        std_normal = tfpd.Normal(loc=0., scale=1.)
        u = (post_mu_s - post_mu_x_max) / post_stddev_s
        return post_var_s * (u * std_normal.prob(u) + (u ** 2 + 1.) * std_normal.cdf(u))

    def __call__(self, u_star: jnp.ndarray):
        """Scaled expected improvement at a single input ``u_star``, returned as a scalar."""
        post_mu_x, _ = self._condition_predictive.posterior()
        post_mu_x_max = jnp.max(post_mu_x)
        post_mu_s, post_var_s = self._condition_predictive(u_star[None, :])
        ei2 = ScaledExpectedImprovementAcquisition._expected_squared_improvement(
            post_mu_x_max=post_mu_x_max,
            post_mu_s=post_mu_s,
            post_var_s=post_var_s
        )
        ei = ExpectedImprovementAcquisition._expected_improvement(
            post_mu_x_max=post_mu_x_max,
            post_mu_s=post_mu_s,
            post_var_s=post_var_s
        )
        # Var[I] = E[I^2] - E[I]^2, floored so we never divide by ~0.
        scaled_ei = ei / jnp.sqrt(jnp.maximum(1e-6, ei2 - ei ** 2))
        return jnp.reshape(scaled_ei, ())
[docs]
class TopTwoAcquisition(AbstractAcquisition):
    """
    Top-two acquisition: the expected improvement of the latent function at ``u_star`` over its value at a
    fixed reference point ``u1``, under their joint GP posterior.
    """

    def __init__(self, condition_predictive: GaussianProcessConditionalPredictive, u1: jnp.ndarray):
        self._condition_predictive = condition_predictive
        self._u1 = jnp.asarray(u1, float_type)

    def __call__(self, u_star: jnp.ndarray):
        """Expected improvement of f(u_star) over f(u1), returned as a scalar."""
        points = jnp.stack([u_star, self._u1], axis=0)
        mu, joint_cov = self._condition_predictive(points, cov=True)
        # Var[f(u_star) - f(u1)] from the 2x2 joint posterior covariance.
        diff_var = joint_cov[0, 0] + joint_cov[1, 1] - 2. * joint_cov[0, 1]
        improvement = ExpectedImprovementAcquisition._expected_improvement(
            post_mu_x_max=mu[1],
            post_mu_s=mu[0],
            post_var_s=diff_var
        )
        return jnp.reshape(improvement, ())
[docs]
class ScaledTopTwoAcquisition(AbstractAcquisition):
    """
    Scaled top-two acquisition: E[I] / sqrt(Var[I]), where I = max(0, f(u_star) - f(u1)) under the joint GP
    posterior at ``u_star`` and the fixed reference point ``u1``.
    """

    def __init__(self, condition_predictive: GaussianProcessConditionalPredictive, u1: jnp.ndarray):
        self._condition_predictive = condition_predictive
        self._u1 = jnp.asarray(u1, float_type)

    def __call__(self, u_star: jnp.ndarray):
        """Scaled expected improvement of f(u_star) over f(u1), returned as a scalar."""
        points = jnp.stack([u_star, self._u1], axis=0)
        mu, joint_cov = self._condition_predictive(points, cov=True)
        # Var[f(u_star) - f(u1)] from the 2x2 joint posterior covariance.
        diff_var = joint_cov[0, 0] + joint_cov[1, 1] - 2. * joint_cov[0, 1]
        moments = dict(post_mu_x_max=mu[1], post_mu_s=mu[0], post_var_s=diff_var)
        second_moment = ScaledExpectedImprovementAcquisition._expected_squared_improvement(**moments)
        first_moment = ExpectedImprovementAcquisition._expected_improvement(**moments)
        spread = jnp.sqrt(jnp.maximum(1e-6, second_moment - first_moment ** 2))
        return jnp.reshape(first_moment / spread, ())
[docs]
class ExpectedImprovementAcquisitionFactory(AcquisitionFactory):
    """Creates an ``ExpectedImprovementAcquisition`` for each hyper-parameter sample."""

    def __init__(self, conditional_predictive_factory: GaussianProcessConditionalPredictiveFactory):
        self._conditional_predictive_factory = conditional_predictive_factory

    def __call__(self, **sample) -> AbstractAcquisition:
        predictive = self._conditional_predictive_factory(**sample)
        return ExpectedImprovementAcquisition(conditional_predictive=predictive)
[docs]
class ScaledExpectedImprovementAcquisitionFactory(AcquisitionFactory):
    """Creates a ``ScaledExpectedImprovementAcquisition`` for each hyper-parameter sample."""

    def __init__(self, conditional_predictive_factory: GaussianProcessConditionalPredictiveFactory):
        self._conditional_predictive_factory = conditional_predictive_factory

    def __call__(self, **sample) -> AbstractAcquisition:
        predictive = self._conditional_predictive_factory(**sample)
        return ScaledExpectedImprovementAcquisition(condition_predictive=predictive)
[docs]
class TopTwoAcquisitionFactory(AcquisitionFactory):
    """Creates a ``TopTwoAcquisition`` against the fixed reference point ``u1`` for each hyper-parameter sample."""

    def __init__(self, conditional_predictive_factory: GaussianProcessConditionalPredictiveFactory, u1: jnp.ndarray):
        self._conditional_predictive_factory = conditional_predictive_factory
        self._u1 = u1

    def __call__(self, **sample) -> AbstractAcquisition:
        predictive = self._conditional_predictive_factory(**sample)
        return TopTwoAcquisition(condition_predictive=predictive, u1=self._u1)
[docs]
class ScaledTopTwoAcquisitionFactory(AcquisitionFactory):
    """Creates a ``ScaledTopTwoAcquisition`` against the fixed reference point ``u1`` for each hyper-parameter sample."""

    def __init__(self, conditional_predictive_factory: GaussianProcessConditionalPredictiveFactory, u1: jnp.ndarray):
        self._conditional_predictive_factory = conditional_predictive_factory
        self._u1 = u1

    def __call__(self, **sample) -> AbstractAcquisition:
        predictive = self._conditional_predictive_factory(**sample)
        return ScaledTopTwoAcquisition(condition_predictive=predictive, u1=self._u1)