import pylab as plt
from jax import random, numpy as jnp
from bojaxns.experiment import NewExperimentRequest, TrialUpdate
from bojaxns.parameter_space import ParameterSpace, Parameter, ContinuousPrior
from bojaxns.service import BayesianOptimisation
num_steps = 20
def objective(x):
return -0.5 * jnp.sum(x ** 4 - 16 * x ** 2 + 5 * x)
def example(ndim):
lower_bound = 39.16616 * ndim
upper_bound = 39.16617 * ndim
print(f"Optimal value in ({lower_bound}, {upper_bound}).")
x_max = -2.903534
print(f"Global optimum at {jnp.ones(ndim) * x_max}")
parameter_space = ParameterSpace(
parameters=[
Parameter(
name=f'x{i}',
prior=ContinuousPrior(
lower=-5,
upper=5.,
mode=0.,
uncert=10.
)
)
for i in range(ndim)
]
)
new_experiment_request = NewExperimentRequest(
parameter_space=parameter_space,
init_explore_size=10
)
bo_experiment = BayesianOptimisation.create_new_experiment(new_experiment=new_experiment_request)
for i in range(num_steps):
trial_id = bo_experiment.create_new_trial(
key=random.PRNGKey(i),
random_explore=False,
beta=1.
)
trial = bo_experiment.get_trial(trial_id=trial_id)
params = []
for param_name in sorted(trial.param_values.keys()):
param = trial.param_values[param_name]
params.append(param.value)
params = jnp.asarray(params)
print(params)
obj_val = float(objective(params))
bo_experiment.post_measurement(
trial_id=trial_id,
trial_update=TrialUpdate(ref_id='a', objective_measurement=obj_val)
)
fig = bo_experiment.visualise()
plt.show()
plt.close('all')
example(5)