Source code for bojaxns.base

from abc import abstractmethod
from typing import NamedTuple, Dict

import chex
from jax import numpy as jnp, tree_map, vmap
from jax.random import PRNGKey
from jaxns import resample, PriorModelType


class AbstractAcquisition:
    """
    Base type for every acquisition function.

    An acquisition function maps a point in the U-domain to a score that serves
    as a proxy for how valuable it would be to try that point next. Scores are
    only meaningful relative to one another.
    """

    # NOTE(review): this class does not derive from abc.ABC, so @abstractmethod
    # is not enforced at instantiation time; it documents intent only.
    @abstractmethod
    def __call__(self, u_star: jnp.ndarray):
        """Return the acquisition score at ``u_star``; subclasses implement this."""
def _assert_rank(rank: int, **kwargs):
    """
    Check that every keyword-argument array has exactly ``rank`` dimensions.

    Args:
        rank: required number of dimensions (``len(x.shape)``).
        **kwargs: name -> array pairs; the name is used in the error message.

    Raises:
        ValueError: if any array's shape has a length other than ``rank``.
    """
    for name, t in kwargs.items():
        if len(t.shape) != rank:
            raise ValueError(f"{name} should be rank {rank} got {t.shape}.")


def _assert_same_leading_dim(*args):
    """
    Check that all positional arrays share the same leading dimension.

    Calling with no arguments (or a single array) always passes.

    Raises:
        ValueError: if the arrays' ``shape[0]`` values are not all equal.
    """
    leading_dims = {arg.shape[0] for arg in args}
    if len(leading_dims) > 1:
        raise ValueError(f"Got mismatched leading dimensions: {leading_dims}")
class ConditionalPredictive:
    """
    Interface for a predictive model built from a single parameter sample.

    Implementations report their dimensionality through ``_ndims`` (exposed as
    the ``ndims`` property) and provide ``posterior``, ``marginal_likelihood``
    and prediction via ``__call__``.
    """

    @abstractmethod
    def _ndims(self):
        """Return the dimensionality; subclasses implement this."""

    @property
    def ndims(self):
        """Dimensionality as reported by ``_ndims``."""
        dims = self._ndims()
        return dims

    @abstractmethod
    def posterior(self):
        """Return the posterior; subclasses implement this."""

    @abstractmethod
    def marginal_likelihood(self):
        """Return the marginal likelihood; subclasses implement this."""

    @abstractmethod
    def __call__(self, U_star: jnp.ndarray, cov: bool = False):
        """
        Predict at the points ``U_star``.

        NOTE(review): ``cov`` presumably toggles returning a full covariance
        rather than marginal variances — confirm against implementations.
        """
class MarginalisationData(NamedTuple):
    """
    Weighted samples over which quantities are marginalised.
    """
    # Variable name -> array of samples. NOTE(review): the leading axis is
    # presumably the sample axis (it is resampled and vmapped over) — confirm.
    samples: Dict[str, chex.Array]
    # Per-sample log-weights, passed to ``jaxns.resample`` as its log-weight
    # argument by the marginalising classes in this module.
    log_dp_mean: chex.Array
class ConditionalPredictiveFactory:
    """
    Abstract factory for ``ConditionalPredictive`` objects.

    Implementations expose the prior model (``build_prior_model``), report a
    dimensionality (``ndims``) and build a predictive for one sample whose
    variables arrive as keyword arguments (``__call__``).
    """

    @abstractmethod
    def ndims(self):
        """Return the dimensionality of the predictives this factory builds."""

    @abstractmethod
    def build_prior_model(self) -> PriorModelType:
        """Return the jaxns prior model; subclasses implement this."""

    @abstractmethod
    def __call__(self, **samples) -> ConditionalPredictive:
        """Build a ``ConditionalPredictive`` for the given sample variables."""
class AcquisitionFactory:
    """
    Abstract factory turning a single sample into an acquisition function.
    """

    @abstractmethod
    def __call__(self, **sample) -> AbstractAcquisition:
        """
        Build an acquisition function for one sample.

        Args:
            **sample: the sample's variables, passed as keyword arguments.

        Returns:
            An ``AbstractAcquisition`` specific to this sample.
        """
class MarginalisedAcquisitionFunction(AbstractAcquisition):
    """
    Acquisition function marginalised over weighted samples.

    On each call, ``S`` samples are drawn with replacement from
    ``ns_results.samples`` (with ``ns_results.log_dp_mean`` passed to
    ``jaxns.resample`` as log-weights). An acquisition function is then built
    for each drawn sample via ``acquisition_factory``, and the per-sample
    values are averaged over samples, ignoring NaNs.

    Args:
        key: PRNG key used for resampling. It is stored and reused on every
            call, so repeated calls draw the same set of samples.
        ns_results: samples and their log-weights.
        acquisition_factory: builds an ``AbstractAcquisition`` from the
            keyword-argument variables of one sample.
        S: number of samples to draw; converted with ``int`` and must be positive.

    Raises:
        ValueError: if ``S`` is not a positive integer.
    """

    def __init__(self, key: PRNGKey, ns_results: MarginalisationData, acquisition_factory: AcquisitionFactory,
                 S: int):
        self._acquisition_factory = acquisition_factory
        self._key = key
        self._ns_results = ns_results
        self._S = int(S)
        # Fail early: a non-positive sample count would otherwise surface later
        # as empty resampling / an all-NaN mean.
        if self._S <= 0:
            raise ValueError(f"S must be a positive integer, got {S}.")

    def __call__(self, u_star: jnp.ndarray):
        """
        Evaluate the marginalised acquisition at ``u_star``.

        Args:
            u_star: point(s) in the U-domain to score.

        Returns:
            The acquisition output averaged over the resampled samples (as a
            pytree if the per-sample acquisition returns one).
        """

        def _eval(**sample):
            acquisition = self._acquisition_factory(**sample)
            return acquisition(u_star=u_star)

        samples = resample(self._key, self._ns_results.samples, self._ns_results.log_dp_mean, S=self._S,
                           replace=True)
        # vmap evaluates _eval once per drawn sample (mapping the leading axis of
        # every sample array); nanmean over axis 0 then averages across samples,
        # skipping NaN results.
        return tree_map(lambda marg: jnp.nanmean(marg, axis=0), vmap(_eval)(**samples))
class MarginalisedConditionalPredictive(ConditionalPredictive):
    """
    Conditional predictive marginalised over weighted samples.

    Every query (``posterior``, ``marginal_likelihood``, ``__call__``) does the
    following:
      1. Draws ``S`` samples with replacement from ``ns_results.samples``, with
         ``ns_results.log_dp_mean`` passed to ``jaxns.resample`` as log-weights.
      2. Builds a ``ConditionalPredictive`` per drawn sample via
         ``conditional_predictive_factory``.
      3. Averages the per-sample results over samples, ignoring NaNs.

    Args:
        key: PRNG key used for resampling. It is stored and reused for every
            query, so all queries use the same set of drawn samples.
        ns_results: samples and their log-weights.
        conditional_predictive_factory: builds a ``ConditionalPredictive`` from
            the keyword-argument variables of one sample; also supplies ``ndims``.
        S: number of samples to draw; converted with ``int`` and must be positive.

    Raises:
        ValueError: if ``S`` is not a positive integer.
    """

    def __init__(self, key: PRNGKey, ns_results: MarginalisationData,
                 conditional_predictive_factory: ConditionalPredictiveFactory, S: int):
        self._conditional_predictive_factory = conditional_predictive_factory
        self._key = key
        self._ns_results = ns_results
        self._S = int(S)
        # Fail early: a non-positive sample count would otherwise surface later
        # as empty resampling / an all-NaN mean.
        if self._S <= 0:
            raise ValueError(f"S must be a positive integer, got {S}.")

    def _ndims(self):
        return self._conditional_predictive_factory.ndims()

    def _marginalise(self, fn):
        """
        Apply ``fn`` to a per-sample conditional predictive and average over samples.

        Args:
            fn: callable taking a ``ConditionalPredictive`` and returning an
                array or pytree of arrays.

        Returns:
            ``fn``'s output structure with the sample axis averaged out, ignoring NaNs.
        """

        def _eval(**sample):
            return fn(self._conditional_predictive_factory(**sample))

        samples = resample(self._key, self._ns_results.samples, self._ns_results.log_dp_mean, S=self._S,
                           replace=True)
        # vmap maps the leading (sample) axis of every sample array; nanmean over
        # axis 0 collapses that axis in each output leaf.
        return tree_map(lambda marg: jnp.nanmean(marg, axis=0), vmap(_eval)(**samples))

    def posterior(self):
        """Posterior averaged over the resampled samples."""
        return self._marginalise(lambda predictive: predictive.posterior())

    def marginal_likelihood(self):
        """Marginal likelihood averaged over the resampled samples."""
        return self._marginalise(lambda predictive: predictive.marginal_likelihood())

    def __call__(self, U_star: jnp.ndarray, cov: bool = False):
        """Prediction at ``U_star`` (forwarding ``cov``), averaged over the resampled samples."""
        return self._marginalise(lambda predictive: predictive(U_star=U_star, cov=cov))