gaussian_process_formulation

bojaxns.gaussian_process_formulation

Package Contents

class AbstractAcquisition[source]

A class that represents any acquisition function. All acquisition functions take a point in the U-domain and return a metric that serves as a proxy for how valuable it would be to try that point. Acquisition values are only meaningful relative to one another.

abstract __call__(u_star)[source]
Parameters:

u_star (jax.numpy.ndarray) –
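The interface described above can be illustrated with a minimal sketch. The classes `AbstractAcquisitionSketch` and `DistanceToCentre` are hypothetical stand-ins written for this example and are not part of bojaxns; the sketch also assumes the U-domain is the unit hypercube. It shows only the contract: a callable that maps `u_star` to a score whose value matters only in comparison with the scores of other candidates.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp


class AbstractAcquisitionSketch(ABC):
    """Hypothetical stand-in for the documented interface (not the bojaxns class)."""

    @abstractmethod
    def __call__(self, u_star: jnp.ndarray) -> jnp.ndarray:
        """Return a score for the candidate point u_star in the U-domain."""


class DistanceToCentre(AbstractAcquisitionSketch):
    """Toy acquisition (assumption): prefers points near the centre of the unit cube."""

    def __call__(self, u_star: jnp.ndarray) -> jnp.ndarray:
        # Negative squared distance to the centre, so closer points score higher.
        return -jnp.sum((u_star - 0.5) ** 2)


acquisition = DistanceToCentre()
candidates = jnp.array([[0.5, 0.5], [0.9, 0.1], [0.0, 0.0]])

# Scores are compared with each other; their absolute values carry no meaning.
scores = jnp.stack([acquisition(u) for u in candidates])
best_index = int(jnp.argmax(scores))
```

Only the ranking of `scores` is used to pick `best_index`, which matches the note that acquisition values make sense only relatively.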

class MarginalisedAcquisitionFunction(key, ns_results, acquisition_factory, S)[source]

Bases: AbstractAcquisition

Class that represents a marginalisation of an acquisition function over samples.

__call__(u_star)[source]
Parameters:

u_star (jax.numpy.ndarray) –
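As a rough illustration of marginalising an acquisition function over samples, the sketch below averages the acquisition value at `u_star` over a dictionary of parameter samples. The helper `marginalise`, the `toy_factory`, and the equal-weight mean are assumptions made for this example; the actual class may draw or weight its samples differently (for instance from `ns_results`).

```python
import jax
import jax.numpy as jnp


def marginalise(acquisition_factory, samples):
    """Build a function that averages acquisition values over parameter samples.

    `samples` is a dict of arrays whose leading axis indexes the samples.
    Equal weighting is an assumption of this sketch.
    """

    def marginalised(u_star):
        def single(sample):
            # Build one acquisition function per sample and evaluate it.
            return acquisition_factory(**sample)(u_star)

        # vmap maps `single` over the leading axis of every array in the dict.
        return jnp.mean(jax.vmap(single)(samples))

    return marginalised


def toy_factory(lengthscale):
    """Hypothetical factory: returns an acquisition that depends on one parameter."""

    def acquisition(u_star):
        return jnp.exp(-jnp.sum(u_star ** 2) / lengthscale ** 2)

    return acquisition


samples = {"lengthscale": jnp.array([0.5, 1.0, 2.0])}
marginalised_acquisition = marginalise(toy_factory, samples)
value = marginalised_acquisition(jnp.array([1.0, 0.0]))
```

The factory is called once per sample with keyword arguments, mirroring the `__call__(**sample)` signature of the acquisition factories documented on this page.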

class MarginalisationData[source]

Bases: NamedTuple

samples: Dict[str, chex.Array]
log_dp_mean: chex.Array

class OptimisationExperiment[source]

Bases: pydantic.BaseModel

experiment_id: str
parameter_space: bojaxns.parameter_space.ParameterSpace
trials: Dict[str, Trial]
ensure_parameters_match_space(value, values)[source]

class GaussianProcessData[source]

Bases: NamedTuple

U: jax.numpy.ndarray
Y: jax.numpy.ndarray
Y_var: jax.numpy.ndarray
sample_size: jax.numpy.ndarray

class GaussianProcessConditionalPredictiveFactory(data)[source]

Bases: bojaxns.base.ConditionalPredictiveFactory

Parameters:

data (GaussianProcessData) –

ndims()[source]
build_prior_model()[source]
psd_kernels()[source]
Return type:

List[Type[tensorflow_probability.substrates.jax.math.psd_kernels.PositiveSemidefiniteKernel]]

__call__(**samples)[source]
Return type:

GaussianProcessConditionalPredictive

class ExpectedImprovementAcquisitionFactory(conditional_predictive_factory)[source]

Bases: bojaxns.base.AcquisitionFactory

Parameters:

conditional_predictive_factory (GaussianProcessConditionalPredictiveFactory) –

__call__(**sample)[source]
Return type:

bojaxns.base.AbstractAcquisition
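For orientation, the standard closed-form expected improvement under a Gaussian predictive distribution can be sketched as follows. This is the textbook formula for maximisation, not necessarily what `ExpectedImprovementAcquisitionFactory` computes internally; the function name `expected_improvement` and its arguments are assumptions made for this example.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def expected_improvement(mean, stddev, best_observed):
    """Closed-form EI for maximisation, given a Gaussian predictive N(mean, stddev^2).

    EI = (mean - best) * Phi(z) + stddev * phi(z), with z = (mean - best) / stddev.
    """
    improvement = mean - best_observed
    z = improvement / stddev
    return improvement * norm.cdf(z) + stddev * norm.pdf(z)


# When the predictive mean equals the best observation, EI reduces to stddev * phi(0).
ei_at_best = expected_improvement(jnp.array(1.0), jnp.array(2.0), 1.0)

# A higher predictive mean at the same uncertainty gives a larger EI.
ei_above_best = expected_improvement(jnp.array(2.0), jnp.array(2.0), 1.0)
```

If the objective is minimised instead, the sign of `improvement` flips; which convention the library uses is not stated on this page.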

class TopTwoAcquisitionFactory(conditional_predictive_factory, u1)[source]

Bases: bojaxns.base.AcquisitionFactory

__call__(**sample)[source]
Return type:

bojaxns.base.AbstractAcquisition

run_multi_lookahead(rng_key, data, ns_results, batch_size, max_depth, num_actions, num_simulations, S)[source]
Return type:

Tuple[chex.Array, mctx.PolicyOutput[mctx.GumbelMuZeroExtraData]]

convert_tree_to_graph(tree, action_labels=None, batch_index=0)[source]

Converts a search tree into a Graphviz graph.

Parameters:
  • tree (mctx.Tree) – A Tree containing a batch of search data.

  • action_labels (Optional[Sequence[str]]) – Optional labels for edges, defaults to the action index.

  • batch_index (int) – Index of the batch element to plot.

Returns:

A Graphviz graph representation of tree.

tfpb[source]

class BayesianOptimiser(experiment, num_parallel_solvers=1, beta=0.5, S=512)[source]

posterior_solve(key)[source]
Parameters:

key (chex.PRNGKey) –

Return type:

jaxns.internals.types.NestedSamplerResults

search_U_top1(key, ns_results, batch_size, num_search)[source]
search_U_top2(key, ns_results, u1, batch_size, num_search)[source]
choose_next_U_toptwo(key, batch_size, num_search)[source]
Parameters:
  • key (chex.PRNGKey) –

  • batch_size (int) –

  • num_search (int) –

choose_next_U_multistep(key, batch_size, max_depth, num_simulations, branch_factor)[source]
Parameters:
  • key (chex.PRNGKey) –

  • batch_size (int) –

  • max_depth (int) –

  • num_simulations (int) –

  • branch_factor (int) –