Source code for bojaxns.service

import logging

import numpy as np
import pylab as plt
from jax import random, numpy as jnp
from jax._src.random import PRNGKey
from jaxns.framework.ops import parse_prior, transform
from matplotlib import dates as mdates

from bojaxns.common import FloatValue, IntValue, ParamValues
from bojaxns.experiment import OptimisationExperiment, NewExperimentRequest, Trial, TrialUpdate
from bojaxns.gaussian_process_formulation.bayesian_optimiser import BayesianOptimiser
from bojaxns.parameter_space import build_prior_model, ContinuousPrior, IntegerPrior, CategoricalPrior, sample_U_value
from bojaxns.utils import latin_hypercube

# Configure root logging at INFO level as a side effect of importing this module.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = [
    'InvalidTrial',
    'BayesianOptimisation'
]


class InvalidTrial(Exception):
    """Raised when a trial id does not correspond to any trial of the experiment."""
class BayesianOptimisation:
    """
    Service that manages the trials of an optimisation experiment.

    Wraps an experiment object and provides operations to create, look up, delete and update its
    trials, as well as a plot of how the trials have progressed.
    """

    def __init__(self, experiment: OptimisationExperiment):
        """
        Args:
            experiment: the experiment whose trials are managed by this service
        """
        self._experiment = experiment

    @property
    def experiment(self):
        """The experiment managed by this service."""
        return self._experiment

    @classmethod
    def create_new_experiment(cls, new_experiment: NewExperimentRequest) -> 'BayesianOptimisation':
        """
        Creates a new experiment and fills it with initial exploration trials.

        The initial trials come from a latin hypercube (fixed seed 42) with
        `new_experiment.init_explore_size` samples over the U-space of the prior model.

        Args:
            new_experiment: request holding the parameter space and the initial exploration size

        Returns:
            a service wrapping the newly created experiment
        """
        experiment = OptimisationExperiment(parameter_space=new_experiment.parameter_space)
        prior_model = build_prior_model(experiment.parameter_space)
        U_placeholder, _ = parse_prior(prior_model=prior_model)
        U_dims = U_placeholder.size
        U_seeds = latin_hypercube(
            seed=42,
            num_samples=new_experiment.init_explore_size,
            num_dim=U_dims
        )
        # Add one trial per seed point.
        for U_seed in U_seeds:
            trial = BayesianOptimisation._create_trial(
                experiment=experiment,
                U=U_seed,
                prior_model=prior_model
            )
            experiment.trials[trial.trial_id] = trial
        return cls(experiment=experiment)

    @staticmethod
    def _create_trial(experiment: OptimisationExperiment, U: jnp.ndarray, prior_model) -> Trial:
        """
        Builds a trial from a point in U-space.

        Args:
            experiment: experiment providing the parameter space
            U: U-space point that is transformed through the prior model
            prior_model: prior model built from the experiment's parameter space

        Returns:
            a trial holding the parameter values and the U-value (as a list)
        """
        prior_sample = transform(U=U, prior_model=prior_model)
        param_values = {}
        for param in experiment.parameter_space.parameters:
            val = prior_sample[param.name]
            if isinstance(param.prior, ContinuousPrior):
                param_values[param.name] = FloatValue(value=float(val))
            elif isinstance(param.prior, (IntegerPrior, CategoricalPrior)):
                # Integer and categorical parameters are both stored as integer values.
                param_values[param.name] = IntValue(value=int(val))
            # NOTE(review): parameters with any other prior type are silently left out.
        return Trial(param_values=param_values, U_value=U.tolist())

    def add_trial_from_data(self, key: PRNGKey, param_values: ParamValues) -> str:
        """
        Adds a trial for explicitly given parameter values.

        Args:
            key: PRNG key passed to `sample_U_value`
            param_values: the parameter values of the new trial

        Returns:
            the id of the new trial
        """
        # presumably yields a U-space point consistent with `param_values` -- confirm in parameter_space
        U = sample_U_value(
            key=key,
            param_space=self._experiment.parameter_space,
            param_values=param_values
        )
        trial = Trial(param_values=param_values, U_value=U)
        self._experiment.trials[trial.trial_id] = trial
        return trial.trial_id

    def create_new_trial(self, key: PRNGKey, random_explore: bool = False, beta: float = 0.5) -> str:
        """
        Returns the id of the next trial to run, creating a new trial if needed.

        If some existing trial has no updates yet, the oldest such trial (by `create_dt`) is returned
        and nothing is created. Otherwise a new trial is created, either uniformly at random in
        U-space or from the Bayesian optimiser.

        Args:
            key: PRNG key used for the random draw or passed to the optimiser
            random_explore: if True, draw the new point uniformly at random instead of optimising
            beta: passed through to `BayesianOptimiser`

        Returns:
            the id of the trial to run
        """
        # Hand out trials that are still waiting for a first measurement, oldest first.
        for trial in sorted(self._experiment.trials.values(), key=lambda t: t.create_dt):
            if not trial.trial_updates:
                return trial.trial_id

        prior_model = build_prior_model(self.experiment.parameter_space)

        if random_explore:
            U_placeholder, _ = parse_prior(prior_model=prior_model)
            U_dims = U_placeholder.size
            U = random.uniform(key, shape=(U_dims,))
        else:
            # Ask the optimiser for the next point.
            bo = BayesianOptimiser(experiment=self._experiment, beta=beta, S=128)
            U = bo.choose_next_U_toptwo(
                key=key,
                batch_size=10,
                num_search=100000
            )

        trial = BayesianOptimisation._create_trial(
            experiment=self.experiment,
            U=U,
            prior_model=prior_model
        )
        self._experiment.trials[trial.trial_id] = trial
        return trial.trial_id

    def get_trial(self, trial_id: str) -> Trial:
        """
        Looks up a trial.

        Args:
            trial_id: id of the trial

        Returns:
            the trial with the given id

        Raises:
            InvalidTrial: if no trial with that id exists
        """
        if trial_id not in self._experiment.trials:
            raise InvalidTrial(trial_id)
        return self._experiment.trials[trial_id]

    def delete_trial(self, trial_id: str):
        """
        Removes a trial from the experiment.

        Args:
            trial_id: id of the trial

        Raises:
            InvalidTrial: if no trial with that id exists
        """
        if trial_id not in self._experiment.trials:
            raise InvalidTrial(trial_id)
        del self._experiment.trials[trial_id]

    def post_measurement(self, trial_id: str, trial_update: TrialUpdate):
        """
        Records a measurement against a trial, keyed by the update's `ref_id`.

        Re-posting an update whose `ref_id` is already stored with an equal objective measurement
        leaves the stored update untouched; otherwise the update is stored (replacing any previous
        one with the same `ref_id`).

        Args:
            trial_id: id of the trial
            trial_update: the measurement to record

        Raises:
            InvalidTrial: if no trial with that id exists
        """
        trial = self.get_trial(trial_id=trial_id)
        existing = trial.trial_updates
        if (trial_update.ref_id in existing) and (
                trial_update.objective_measurement == existing[trial_update.ref_id].objective_measurement):
            return
        existing[trial_update.ref_id] = trial_update

    def trial_size(self, trial_id: str):
        """
        Counts the updates recorded for a trial.

        Args:
            trial_id: id of the trial

        Returns:
            number of updates stored on the trial

        Raises:
            InvalidTrial: if no trial with that id exists
        """
        trial = self.get_trial(trial_id=trial_id)
        return len(trial.trial_updates)

    def visualise(self, main_color="#7e97bf", grid_color="#969396", ) -> plt.Figure:
        """
        Plots the running mean of the objective over time for every trial that has updates.

        Each trial is drawn as a randomly (but deterministically) coloured line with error bars,
        and the trial whose latest running mean is largest is circled and labelled as best.

        Args:
            main_color: color of main axes
            grid_color: color of grid

        Returns:
            a pylab Figure

        Raises:
            RuntimeError: if no trial has any updates to plot
        """
        series = []
        for trial in self._experiment.trials.values():
            if not trial.trial_updates:
                continue
            x, y, n = [], [], []
            # Build cumulative sums and counts in measurement-time order.
            for trial_update in sorted(trial.trial_updates.values(), key=lambda tu: tu.measurement_dt):
                x.append(trial_update.measurement_dt)
                if len(y) == 0:
                    y.append(trial_update.objective_measurement)
                    n.append(1)
                else:
                    y.append(y[-1] + trial_update.objective_measurement)
                    n.append(n[-1] + 1)
            # Cumulative sum / count gives the running mean.
            y = [_y / _n for _y, _n in zip(y, n)]
            if len(y) < 3:
                # Too few points for a spread estimate: no error bars.
                y_std = [0.] * len(y)
            else:
                # Spread is taken as |first running mean - mean of running means|, shrunk by 1/sqrt(n).
                _mu = jnp.mean(jnp.asarray(y))
                _y_std = jnp.abs(y[0] - _mu)
                y_std = _y_std / jnp.sqrt(jnp.asarray(n))
                y_std = y_std.tolist()
            series.append((x, y, y_std))

        if len(series) == 0:
            raise RuntimeError("Nothing to visualise. Provide data to a trial first.")

        min_dt = min(min(x) for (x, y, y_std) in series)
        max_dt = max(max(x) for (x, y, y_std) in series)

        fig_width = 6
        fig_height = fig_width
        # Matplotlib expects figsize as (width, height).
        fig, ax = plt.subplots(figsize=(fig_width, fig_height), facecolor='None')
        ax.set_facecolor('none')

        # Deterministic colors from a local generator, so the caller's global numpy random state
        # is not reseeded as a side effect of plotting.
        rng = np.random.RandomState(42)
        for x, y, y_std in series:
            color = [rng.uniform(), rng.uniform(), rng.uniform(), 1.]
            ax.plot(x, y, c=color)
            # Error bars use the same color at low opacity.
            color[-1] = 0.2
            ax.errorbar(x, y, y_std, fmt='-o', color=color)

        # Latest (time, running mean, error) of every trial; circle the largest running mean.
        agg_series = [(x[-1], y[-1], y_std[-1]) for (x, y, y_std) in series]
        agg_series = np.asarray(agg_series)
        idx_max = np.argmax(agg_series[:, 1])
        ax.scatter(agg_series[idx_max, 0], agg_series[idx_max, 1], s=100, fc='none', ec='black', marker='o',
                   label=f"Best {agg_series[idx_max, 1]}")

        # Restrict x limits to the observed time range.
        ax.set_xlim(
            mdates.date2num(min_dt),
            mdates.date2num(max_dt)
        )

        # NOTE: '%-d' (day of month without zero padding) is a platform-specific strftime
        # extension and is not supported on Windows.
        date_formatter = mdates.DateFormatter('%a, %-d %b')
        ax.xaxis.set_major_formatter(date_formatter)

        ax.set_title("Trial progression", color=main_color, fontsize=8)

        # Rotate value labels
        ax.tick_params(axis='x', rotation=45, labelsize=6)

        # Add legend
        ax.legend(loc='best', prop={'size': 6}, framealpha=0.1)

        # Set xlim and ylim tight to data points
        ax.margins(0.01)

        # Invisible top, right
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # Set axis-colors
        ax.spines['left'].set_color(main_color)
        ax.spines['bottom'].set_color(main_color)
        ax.tick_params(axis='x', colors=main_color)
        ax.tick_params(axis='y', colors=main_color)

        # grid
        ax.grid(axis='y', linestyle='dashed', color=grid_color)

        # Make tight
        fig.tight_layout()

        return fig