Source code for bojaxns.parameter_space

from functools import partial
from typing import Union, Literal, List, Generator, Annotated

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from chex import PRNGKey
from jax import random, jit
from jax._src.lax.control_flow import while_loop
from jaxns import Categorical, Prior
from jaxns.internals.types import float_type, int_type
from pydantic import BaseModel, Field, validator, confloat

from bojaxns.common import FloatValue, IntValue, ParamValues, UValue
from bojaxns.utils import build_example

tfpd = tfp.distributions

__all__ = [
    'ContinuousPrior',
    'IntegerPrior',
    'CategoricalPrior',
    'Parameter',
    'ParameterSpace',
    'build_prior_model'
]


[docs] class ContinuousPrior(BaseModel):
[docs] type: Literal['continuous_prior'] = 'continuous_prior'
[docs] lower: float = Field( description="The greatest lower bound of interval. Inclusive.", example=0.1 )
[docs] upper: float = Field( description="The least upper bound of interval. Inclusive.", example=5.5 )
[docs] mode: float = Field( description="The mode of the prior.", example=2.5 )
[docs] uncert: confloat(gt=0.) = Field( description="The uncertainty of the prior. Set to np.inf for the uniform prior over (lower, upper).", example=2. )
[docs] class IntegerPrior(BaseModel):
[docs] type: Literal['integer_prior'] = 'integer_prior'
[docs] lower: int = Field( description="The greatest lower bound of interval. Inclusive.", example=0 )
[docs] upper: int = Field( description="The least upper bound of interval. Inclusive.", example=5 )
[docs] mode: float = Field( description="The mode of the prior. Can be a float.", example=2.5 )
[docs] uncert: confloat(gt=0.) = Field( description="The uncertainty of the prior. Set to np.inf for the uniform prior over (lower, upper). Can be a float.", example=2. )
[docs] class CategoricalPrior(BaseModel):
[docs] type: Literal['categorical_prior'] = 'categorical_prior'
[docs] probs: List[confloat(ge=0.)] = Field( description="The unnormalised probabilities of categories. Must be >= 0, need not be normalised.", example=[0.1, 0.3, 0.6] )
ParamPrior = Annotated[ Union[ContinuousPrior, IntegerPrior, CategoricalPrior], Field( description='The parameter prior, which defines the domain.', discriminator='type' ) ]
[docs] class Parameter(BaseModel):
[docs] name: str = Field( description="The name of the parameter", example='price' )
[docs] prior: ParamPrior
[docs] class ParameterSpace(BaseModel):
[docs] parameters: List[Parameter] = Field( description='The parameters of the problem.', example=[ Parameter( name='continuous', prior=build_example(ContinuousPrior) ), Parameter( name='integers', prior=build_example(IntegerPrior) ), Parameter( name='categorical', prior=build_example(CategoricalPrior) ) ] )
@validator('parameters', always=True)
[docs] def unique_parameters(cls, value): names = list(map(lambda param: param.name, value)) if len(names) != len(set(names)): raise ValueError(f"parameter names must be unique. Got {names}.") return value
@partial(jit, static_argnames=['parametrisation']) def sample_U_categorical(key: PRNGKey, logits: jnp.ndarray, target_cat: jnp.ndarray, parametrisation: Literal['cdf', 'gumbel_max']) -> jnp.ndarray: prior = Categorical(parametrisation=parametrisation, logits=logits) def _iter(state): (key, U, cat) = state key, sample_key = random.split(key, 2) U = random.uniform(sample_key, shape=(prior.base_ndims,), dtype=float_type) cat = prior.forward(U) cat = cat.reshape(()) return (key, U, cat) key, sample_key = random.split(key, 2) U = random.uniform(sample_key, shape=(prior.base_ndims,), dtype=float_type) cat = prior.forward(U) if cat.size != 1: raise ValueError(f"Category shape must be flat, got {cat.shape}") cat = cat.reshape(()) (_, U, _) = while_loop(lambda state: state[2] != target_cat, _iter, (key, U, cat)) return U.reshape((-1,)) def inverse_transform_param(key: PRNGKey, param: Parameter, param_value: Union[FloatValue, IntValue]) -> List[float]: prior = param.prior if isinstance(prior, ContinuousPrior): underlying_dist = tfpd.TruncatedNormal( loc=jnp.asarray(prior.mode, float_type), scale=jnp.asarray(prior.uncert, float_type), low=jnp.asarray(prior.lower, float_type), high=jnp.asarray(prior.upper, float_type) ) return underlying_dist.cdf(param_value.value).reshape((-1,)).tolist() elif isinstance(prior, IntegerPrior): underlying_dist = tfpd.Normal(loc=jnp.asarray(prior.mode, float_type), scale=jnp.asarray(prior.uncert, float_type)) int_options = jnp.arange(prior.lower, prior.upper + 1, dtype=float_type) logits = underlying_dist.log_prob(int_options) return sample_U_categorical(key=key, logits=logits, target_cat=jnp.asarray(param_value.value, int_type), parametrisation='cdf').tolist() elif isinstance(prior, CategoricalPrior): return sample_U_categorical(key=key, logits=jnp.log(jnp.asarray(prior.probs)), target_cat=jnp.asarray(param_value.value, int_type), parametrisation='gumbel_max').tolist() else: raise ValueError(f"Invalid prior {prior}") def sample_U_value(key: PRNGKey, param_space: ParameterSpace, param_values: ParamValues) -> UValue: U = [] for param in param_space.parameters: key, sample_key = random.split(key, 2) U.extend(inverse_transform_param(key=sample_key, param=param, param_value=param_values[param.name])) return U def translate_parameter(param: Parameter) -> Generator[Prior, jnp.ndarray, jnp.ndarray]: prior = param.prior if isinstance(prior, ContinuousPrior): underlying_dist = tfpd.TruncatedNormal( loc=jnp.asarray(prior.mode, float_type), scale=jnp.asarray(prior.uncert, float_type), low=jnp.asarray(prior.lower, float_type), high=jnp.asarray(prior.upper, float_type) ) param_value = yield Prior(dist_or_value=underlying_dist, name=param.name) return param_value elif isinstance(prior, IntegerPrior): underlying_dist = tfpd.Normal(loc=jnp.asarray(prior.mode, float_type), scale=jnp.asarray(prior.uncert, float_type)) int_options = jnp.arange(prior.lower, prior.upper + 1, dtype=float_type) logits = underlying_dist.log_prob(int_options) param_value_idx = yield Categorical(parametrisation='cdf', logits=logits) param_value = yield Prior(dist_or_value=param_value_idx + prior.lower, name=param.name) return param_value elif isinstance(prior, CategoricalPrior): probs = jnp.asarray(prior.probs) probs /= jnp.sum(probs) param_value = yield Categorical(parametrisation='gumbel_max', probs=probs, name=param.name) return param_value else: raise ValueError(f"Invalid prior {prior}")
[docs] def build_prior_model(parameter_space: ParameterSpace): """ Constructs a prior model given the parameter space. Args: parameter_space: Returns: """ def prior_model(): param_values = [] for parameter in parameter_space.parameters: x = yield from translate_parameter(param=parameter) param_values.append(x) return tuple(param_values) return prior_model