# In[ ]:
import pylab as plt
from jax import random, numpy as jnp

from bojaxns.experiment import NewExperimentRequest, TrialUpdate
from bojaxns.parameter_space import ParameterSpace, Parameter, ContinuousPrior
from bojaxns.service import BayesianOptimisation

num_steps: int = 20  # number of Bayesian-optimisation trials run by example()


def objective(x):
    """Evaluate the negated Styblinski-Tang function at ``x``.

    Each element contributes ``x**4 - 16*x**2 + 5*x``. The contributions are
    summed over all elements and scaled by ``-0.5``, so the function is
    maximised (rather than minimised) near ``x_i = -2.903534``.

    Args:
        x: Array-like of coordinates.

    Returns:
        A scalar JAX array holding the objective value.
    """
    per_coordinate = x ** 4 - 16 * x ** 2 + 5 * x
    total = jnp.sum(per_coordinate)
    return -0.5 * total


def example(ndim):
    """Run a short Bayesian-optimisation demo on the negated Styblinski-Tang function.

    Builds an ``ndim``-dimensional parameter space with one continuous
    parameter per axis (``x0`` .. ``x{ndim-1}``, bounded to [-5, 5]). It then
    creates a bojaxns experiment and runs ``num_steps`` trials. For each
    trial, the suggested point is evaluated with ``objective``, the
    measurement is posted back to the experiment, and the experiment's
    visualisation is shown.

    Args:
        ndim: Number of input dimensions.
    """
    # Known optimum of the negated Styblinski-Tang function, for comparison
    # with what the optimiser finds.
    lower_bound = 39.16616 * ndim
    upper_bound = 39.16617 * ndim
    print(f"Optimal value in ({lower_bound}, {upper_bound}).")

    x_max = -2.903534
    print(f"Global optimum at {jnp.ones(ndim) * x_max}")

    # One ordered list of names is used both to build the space and to read
    # trial values back, so the vector passed to the objective follows the
    # axis order. A lexical sort of the keys would place 'x10' before 'x2'
    # once ndim > 10.
    param_names = [f'x{i}' for i in range(ndim)]

    parameter_space = ParameterSpace(
        parameters=[
            Parameter(
                name=name,
                prior=ContinuousPrior(
                    lower=-5.,
                    upper=5.,
                    mode=0.,
                    uncert=10.,
                ),
            )
            for name in param_names
        ]
    )
    new_experiment_request = NewExperimentRequest(
        parameter_space=parameter_space,
        init_explore_size=10,
    )
    bo_experiment = BayesianOptimisation.create_new_experiment(new_experiment=new_experiment_request)

    for i in range(num_steps):
        # A distinct, reproducible PRNG key per step.
        trial_id = bo_experiment.create_new_trial(
            key=random.PRNGKey(i),
            random_explore=False,
            beta=1.,
        )
        trial = bo_experiment.get_trial(trial_id=trial_id)
        params = jnp.asarray([trial.param_values[name].value for name in param_names])
        print(params)

        obj_val = float(objective(params))
        bo_experiment.post_measurement(
            trial_id=trial_id,
            trial_update=TrialUpdate(ref_id='a', objective_measurement=obj_val),
        )
        # visualise() draws into the current matplotlib state. Show it, then
        # close every figure so figures do not pile up across iterations.
        bo_experiment.visualise()
        plt.show()
        plt.close('all')


if __name__ == "__main__":
    # Run the 5-D demo only when executed as a script or notebook cell,
    # not as a side effect of importing this module.
    example(5)