Source code for bojaxns.gaussian_process_formulation.bayesian_optimiser

import os.path

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from chex import PRNGKey
from jax import random, numpy as jnp, vmap
from jax._src.lax.control_flow import scan
from jaxns import Model, DefaultNestedSampler
from jaxns.internals.types import float_type, NestedSamplerResults

from bojaxns.base import AbstractAcquisition, MarginalisedAcquisitionFunction, MarginalisationData
from bojaxns.experiment import OptimisationExperiment
from bojaxns.gaussian_process_formulation.distribution_math import GaussianProcessData, \
    GaussianProcessConditionalPredictiveFactory, ExpectedImprovementAcquisitionFactory, TopTwoAcquisitionFactory
from bojaxns.gaussian_process_formulation.multi_step_lookahead import run_multi_lookahead, convert_tree_to_graph

# Short alias for the TensorFlow Probability (JAX substrate) bijectors namespace.
tfpb = tfp.bijectors
class BayesianOptimiser:
    """
    Proposes the next U-domain point to evaluate for an `OptimisationExperiment`.

    The trials of the experiment are condensed into a `GaussianProcessData` set at construction time.
    Each `choose_next_U_*` method first runs a nested-sampling posterior solve over the Gaussian process
    hyper-parameters and then searches the U-domain for the next point.
    """

    def __init__(self, experiment: OptimisationExperiment, num_parallel_solvers: int = 1, beta: float = 0.5,
                 S: int = 512):
        """
        Args:
            experiment: experiment whose trials supply the observations.
            num_parallel_solvers: stored on the instance; no method of this class reads it.
            beta: `choose_next_U_toptwo` runs the top-two search when a uniform draw is below this value.
            S: forwarded as `S` to the marginalised acquisition and the look-ahead search (coerced to int).
        """
        self._experiment = experiment
        self._num_parallel_solvers = num_parallel_solvers
        self._beta = beta
        self._S = int(S)
        self._data = self._prepare_data()

    def _prepare_data(self) -> GaussianProcessData:
        """
        Condense the experiment's trials into per-trial statistics.

        For every trial with at least one update this records the trial's `U_value`, the mean of its
        measurements, their variance (NaN when fewer than two measurements exist) and the measurement count.
        Non-finite measurements are replaced by a stand-in value before the statistics are computed.

        Returns:
            the assembled `GaussianProcessData`, with all fields cast to `float_type`.
        """
        trials = self._experiment.trials

        # Find the finite measurement range; non-finite measurements are "illegal" and get a
        # stand-in value placed 10% of the range below the smallest finite measurement.
        min_val, max_val = np.inf, -np.inf
        for trial in trials.values():
            for trial_update in trial.trial_updates.values():
                measurement = trial_update.objective_measurement
                if not np.isfinite(measurement):
                    continue
                min_val = min(measurement, min_val)
                max_val = max(measurement, max_val)
        illegal_value = min_val - 0.1 * (max_val - min_val)
        if not np.isfinite(illegal_value):
            # No finite measurement was seen (or the expression overflowed): fall back to zero.
            illegal_value = 0.

        U = []
        Y = []
        Y_var = []
        sample_size = []
        for trial in trials.values():
            if len(trial.trial_updates) == 0:
                continue
            samples = [
                trial_update.objective_measurement
                if np.isfinite(trial_update.objective_measurement) else illegal_value
                for trial_update in trial.trial_updates.values()
            ]
            U.append(trial.U_value)
            Y.append(np.mean(samples))
            # A variance cannot be estimated from a single measurement, so it is marked as NaN.
            Y_var.append(np.var(samples) if len(samples) >= 2 else np.nan)
            sample_size.append(len(samples))

        return GaussianProcessData(
            U=jnp.asarray(U, float_type),
            Y=jnp.asarray(Y, float_type),
            Y_var=jnp.asarray(Y_var, float_type),
            sample_size=jnp.asarray(sample_size, float_type)
        )

    def posterior_solve(self, key: PRNGKey) -> NestedSamplerResults:
        """
        Run nested sampling over the Gaussian process hyper-parameters.

        Besides returning the results this prints a summary and produces the diagnostic and corner plots.

        Args:
            key: PRNG key passed to the nested sampler.

        Returns:
            the nested sampler results.
        """
        print("Performing posterior solve")
        conditional_predictive_factory = GaussianProcessConditionalPredictiveFactory(data=self._data)
        prior_model = conditional_predictive_factory.build_prior_model()

        def log_likelihood(amplitude, length_scale, variance, mean, kernel_select):
            """
            Marginal likelihood term of the conditional predictive built from these hyper-parameters.
            """
            conditional_predictive = conditional_predictive_factory(
                amplitude=amplitude,
                length_scale=length_scale,
                variance=variance,
                mean=mean,
                kernel_select=kernel_select
            )
            return conditional_predictive.marginal_likelihood()

        model = Model(
            prior_model=prior_model,
            log_likelihood=log_likelihood
        )

        ns = DefaultNestedSampler(
            model=model,
            parameter_estimation=True,
            max_samples=1e5
        )

        termination_reason, state = jax.jit(ns)(key=key)
        results = ns.to_results(termination_reason, state)
        ns.summary(results)
        ns.plot_diagnostics(results)
        ns.plot_cornerplot(results)
        return results

    @staticmethod
    def _random_search(search_key: PRNGKey, acquisition_function: AbstractAcquisition, ndims: int,
                       batch_size: int, num_search: int):
        """
        Maximise an acquisition function by uniform random search, evaluated in batches.

        Args:
            search_key: PRNG key, split into one key per batch.
            acquisition_function: callable applied (through `vmap`) to each candidate of shape `(ndims,)`.
            ndims: number of dimensions of a candidate point.
            batch_size: number of candidates drawn and evaluated per batch.
            num_search: requested number of candidates; `num_search // batch_size` batches are run,
                with a minimum of one batch.

        Returns:
            `((acquisition_best, u_best), (acquisition, U_star))`: the best acquisition value and its point,
            followed by every evaluated acquisition value (flattened) and the matching candidates
            (one row per candidate).
        """
        vmap_acquisition_function = vmap(acquisition_function)

        def body(carry, key):
            acquisition_best, u_best = carry
            U_star = random.uniform(key, shape=(batch_size, ndims), dtype=float_type)
            acquisition = vmap_acquisition_function(U_star)

            idx_max = jnp.argmax(acquisition)
            acquisition_max = acquisition[idx_max]
            u_star_max = U_star[idx_max]

            # Keep the running best across batches.
            better = acquisition_max > acquisition_best
            acquisition_best = jnp.where(better, acquisition_max, acquisition_best)
            u_best = jnp.where(better, u_star_max, u_best)
            # NOTE: a gradient-based refinement of u_best (BFGS through a sigmoid bijector) was sketched
            # here at one point but is not enabled.
            return (acquisition_best, u_best), (acquisition, U_star)

        # Always run at least one batch; with num_search < batch_size the floor division would give
        # zero batches and no candidate would be evaluated at all.
        num_batches = max(num_search // batch_size, 1)
        (acquisition_best, u_best), (acquisition, U_star) = jax.lax.scan(
            body,
            (-jnp.inf, jnp.zeros(ndims, float_type)),
            random.split(search_key, num_batches)
        )

        # The scan stacks the per-batch outputs; merge the batch axis into the candidate axis.
        acquisition = jnp.reshape(acquisition, (-1,))
        U_star = jnp.reshape(U_star, (acquisition.size, -1))
        return (acquisition_best, u_best), (acquisition, U_star)

    @staticmethod
    def _multistep_lookahead_search(key: PRNGKey, data: GaussianProcessData, ns_results: MarginalisationData,
                                    batch_size: int, max_depth: int, num_simulations: int, branch_factor: int,
                                    S: int):
        """
        Choose the next point with `run_multi_lookahead`.

        As a side effect the search tree is drawn to 'search_tree.png' in the working directory, but only
        when that file does not exist yet, so an existing diagram is never overwritten.

        Args:
            key: PRNG key for the search.
            data: observations to condition on.
            ns_results: hyper-parameter samples to marginalise over.
            batch_size: forwarded to `run_multi_lookahead`.
            max_depth: forwarded to `run_multi_lookahead`.
            num_simulations: forwarded to `run_multi_lookahead`.
            branch_factor: forwarded to `run_multi_lookahead` as `num_actions`.
            S: forwarded to `run_multi_lookahead`.

        Returns:
            the point selected by the search.
        """
        print("Starting search.")
        u_best, policy_output = run_multi_lookahead(
            rng_key=key,
            data=data,
            ns_results=ns_results,
            batch_size=batch_size,
            max_depth=max_depth,
            num_actions=branch_factor,
            num_simulations=num_simulations,
            S=S
        )
        output_file = 'search_tree.png'
        if not os.path.exists(output_file):
            print("Saving tree diagram to:", output_file)
            graph = convert_tree_to_graph(policy_output.search_tree)
            graph.draw(output_file, prog="dot")
        return u_best

    def _marginalised_random_search(self, key: PRNGKey, ns_results: MarginalisationData, acquisition_factory,
                                    ndims: int, batch_size: int, num_search: int):
        """Wrap `acquisition_factory` in a `MarginalisedAcquisitionFunction` and random-search it."""
        search_key, marginalise_key = random.split(key)
        marginalised_acquisition = MarginalisedAcquisitionFunction(
            key=marginalise_key,
            ns_results=ns_results,
            acquisition_factory=acquisition_factory,
            S=self._S
        )
        return BayesianOptimiser._random_search(
            search_key=search_key,
            acquisition_function=marginalised_acquisition,
            ndims=ndims,
            batch_size=batch_size,
            num_search=num_search
        )

    @staticmethod
    def _plot_search(acquisition, U_star, title: str):
        """Scatter the first two U-dimensions of the candidates, coloured by acquisition value."""
        import pylab as plt

        sc = plt.scatter(U_star[:, 0], U_star[:, 1], c=acquisition, s=1, cmap='jet')
        plt.colorbar(sc)
        plt.title(title)
        plt.show()

    def search_U_top1(self, key: PRNGKey, ns_results: MarginalisationData, batch_size: int, num_search: int):
        """
        Random search of the marginalised acquisition built by `ExpectedImprovementAcquisitionFactory`.

        Args:
            key: PRNG key, split between the search and the marginalisation.
            ns_results: hyper-parameter samples to marginalise over.
            batch_size: candidates evaluated per batch.
            num_search: requested total number of candidates.

        Returns:
            `((acquisition_best, u_best), (acquisition, U_star))` as produced by `_random_search`.
        """
        conditional_predictive_factory = GaussianProcessConditionalPredictiveFactory(data=self._data)
        acquisition_factory = ExpectedImprovementAcquisitionFactory(
            conditional_predictive_factory=conditional_predictive_factory
        )
        return self._marginalised_random_search(
            key=key,
            ns_results=ns_results,
            acquisition_factory=acquisition_factory,
            ndims=conditional_predictive_factory.ndims(),
            batch_size=batch_size,
            num_search=num_search
        )

    def search_U_top2(self, key: PRNGKey, ns_results: MarginalisationData, u1: jnp.ndarray, batch_size: int,
                      num_search: int):
        """
        Random search of the marginalised acquisition built by `TopTwoAcquisitionFactory` around `u1`.

        Args:
            key: PRNG key, split between the search and the marginalisation.
            ns_results: hyper-parameter samples to marginalise over.
            u1: point passed to the top-two acquisition factory (the top-1 choice in `choose_next_U_toptwo`).
            batch_size: candidates evaluated per batch.
            num_search: requested total number of candidates.

        Returns:
            `((acquisition_best, u_best), (acquisition, U_star))` as produced by `_random_search`.
        """
        conditional_predictive_factory = GaussianProcessConditionalPredictiveFactory(data=self._data)
        acquisition_factory = TopTwoAcquisitionFactory(
            conditional_predictive_factory=conditional_predictive_factory,
            u1=u1
        )
        return self._marginalised_random_search(
            key=key,
            ns_results=ns_results,
            acquisition_factory=acquisition_factory,
            ndims=conditional_predictive_factory.ndims(),
            batch_size=batch_size,
            num_search=num_search
        )

    def choose_next_U_toptwo(self, key: PRNGKey, batch_size: int, num_search: int):
        """
        Choose the next point with the top-two strategy.

        The top-1 search is always run. When a uniform draw falls below `beta`, a top-2 search is run
        around the top-1 point and its result is returned instead. Each search is plotted when the
        search space has at least two dimensions.

        Args:
            key: PRNG key for the posterior solve, both searches and the top-2 draw.
            batch_size: candidates evaluated per batch.
            num_search: requested total number of candidates per search.

        Returns:
            the chosen U-domain point.
        """
        ns_key, search_top1_key, search_top2_key, do_top2_key = random.split(key, 4)
        do_top2 = random.uniform(do_top2_key) < self._beta

        ns_results = self.posterior_solve(key=ns_key)
        ns_results = MarginalisationData(
            samples=ns_results.samples,
            log_dp_mean=ns_results.log_dp_mean
        )

        # search over U-domain space for top1
        (_, next_u), (acquisition, U_star) = self.search_U_top1(
            key=search_top1_key,
            ns_results=ns_results,
            batch_size=batch_size,
            num_search=num_search
        )
        # U_star is always 2-D, so the rank says nothing about the search space: the scatter plot
        # needs two columns, hence the check on the number of columns.
        if U_star.shape[-1] > 1:
            BayesianOptimiser._plot_search(acquisition, U_star, "Search top1 plot")

        if do_top2:
            u1 = next_u
            (_, next_u), (acquisition, U_star) = self.search_U_top2(
                key=search_top2_key,
                ns_results=ns_results,
                u1=u1,
                batch_size=batch_size,
                num_search=num_search
            )
            if U_star.shape[-1] > 1:
                BayesianOptimiser._plot_search(acquisition, U_star, "Search top2 plot")

        return next_u

    def choose_next_U_multistep(self, key: PRNGKey, batch_size: int, max_depth: int, num_simulations: int,
                                branch_factor: int):
        """
        Choose the next point with the multi-step look-ahead search.

        Args:
            key: PRNG key, split between the posterior solve and the search.
            batch_size: forwarded to the look-ahead search.
            max_depth: forwarded to the look-ahead search.
            num_simulations: forwarded to the look-ahead search.
            branch_factor: forwarded to the look-ahead search.

        Returns:
            the chosen U-domain point.
        """
        ns_key, search_key = random.split(key, 2)

        ns_results = self.posterior_solve(key=ns_key)
        ns_results = MarginalisationData(
            samples=ns_results.samples,
            log_dp_mean=ns_results.log_dp_mean
        )

        return BayesianOptimiser._multistep_lookahead_search(
            key=search_key,
            data=self._data,
            ns_results=ns_results,
            batch_size=batch_size,
            max_depth=max_depth,
            num_simulations=num_simulations,
            branch_factor=branch_factor,
            S=self._S
        )