from functools import partial
from typing import Optional, Sequence, Tuple, Callable, NamedTuple
import chex
import jax
import jax.numpy as jnp
import mctx
from jax import tree_map
from mctx import PolicyOutput, GumbelMuZeroExtraData, RootFnOutput, RecurrentFn
from bojaxns.base import MarginalisedConditionalPredictive, MarginalisationData, MarginalisedAcquisitionFunction
from bojaxns.gaussian_process_formulation.distribution_math import GaussianProcessData, \
GaussianProcessConditionalPredictiveFactory, ExpectedImprovementAcquisitionFactory
[docs]
def convert_tree_to_graph(
tree: mctx.Tree,
action_labels: Optional[Sequence[str]] = None,
batch_index: int = 0
):
"""Converts a search tree into a Graphviz graph.
Args:
tree: A `Tree` containing a batch of search data.
action_labels: Optional labels for edges, defaults to the action index.
batch_index: Index of the batch element to plot.
Returns:
A Graphviz graph representation of `tree`.
"""
import pygraphviz
chex.assert_rank(tree.node_values, 2)
batch_size = tree.node_values.shape[0]
if action_labels is None:
action_labels = range(tree.num_actions)
elif len(action_labels) != tree.num_actions:
raise ValueError(
f"action_labels {action_labels} has the wrong number of actions "
f"({len(action_labels)}). "
f"Expecting {tree.num_actions}.")
def node_to_str(node_i, reward=0, discount=1):
return (f"{node_i}\n"
f"Reward: {reward:.2f}\n"
f"Discount: {discount:.2f}\n"
f"Value: {tree.node_values[batch_index, node_i]:.2f}\n"
f"Visits: {tree.node_visits[batch_index, node_i]}\n")
def edge_to_str(node_i, a_i):
node_index = jnp.full([batch_size], node_i)
probs = jax.nn.softmax(tree.children_prior_logits[batch_index, node_i])
return (f"{action_labels[a_i]}\n"
f"Q: {tree.qvalues(node_index)[batch_index, a_i]:.2f}\n" # pytype: disable=unsupported-operands # always-use-return-annotations
f"p: {probs[a_i]:.2f}\n")
graph = pygraphviz.AGraph(directed=True)
# Add root
graph.add_node(0, label=node_to_str(node_i=0), color="green")
# Add all other nodes and connect them up.
for node_i in range(tree.num_simulations):
for a_i in range(tree.num_actions):
# Index of children, or -1 if not expanded
children_i = tree.children_index[batch_index, node_i, a_i]
if children_i >= 0:
graph.add_node(
children_i,
label=node_to_str(
node_i=children_i,
reward=tree.children_rewards[batch_index, node_i, a_i],
discount=tree.children_discounts[batch_index, node_i, a_i]),
color="red")
graph.add_edge(node_i, children_i, label=edge_to_str(node_i, a_i))
return graph
[docs]
class MultiLookAheadState(NamedTuple):
[docs]
data: GaussianProcessData
[docs]
RewardFnType = Callable[[MultiLookAheadState], chex.Array]
@partial(jax.jit, static_argnames=['batch_size', 'max_depth', 'num_actions', 'num_simulations', 'S'])
[docs]
def run_multi_lookahead(rng_key: chex.PRNGKey, data: GaussianProcessData, ns_results: MarginalisationData,
batch_size: int, max_depth: int, num_actions: int, num_simulations: int, S: int) -> \
Tuple[chex.Array, PolicyOutput[GumbelMuZeroExtraData]]:
recurrent_key, search_key, action_key = jax.random.split(rng_key, 3)
U_ndims = data.U.shape[-1]
U_options = jax.random.uniform(key=action_key, shape=(num_actions, U_ndims))
root, recurrent_fn = _make_batched_env_model(
# Using batch_size=2 to test the batched search.
key=recurrent_key,
batch_size=batch_size,
max_depth=max_depth,
num_actions=num_actions,
S=S,
data=data,
ns_results=ns_results,
U_options=U_options,
)
# Running the search.
policy_output = mctx.gumbel_muzero_policy(
params=(),
rng_key=search_key,
root=root,
recurrent_fn=recurrent_fn,
num_simulations=num_simulations,
max_depth=max_depth,
max_num_considered_actions=num_actions,
)
@jax.vmap
def _extract_value(batch_index):
selected_action = policy_output.action[batch_index]
q_value = policy_output.search_tree.summary().qvalues[
batch_index, selected_action]
return selected_action, q_value
selected_actions, q_values = _extract_value(jnp.arange(batch_size, dtype=jnp.int32))
best_idx = jnp.argmax(q_values)
best_action = selected_actions[best_idx]
u_best = U_options[best_action]
return u_best, policy_output
def _make_batched_env_model(
key: chex.PRNGKey,
batch_size: int,
num_actions: int,
S: int,
max_depth: int,
data: GaussianProcessData,
ns_results: MarginalisationData,
U_options: chex.Array) -> Tuple[RootFnOutput, RecurrentFn]:
"""Returns a batched `(root, recurrent_fn)`."""
reward_key, marginalise_key = jax.random.split(key, 2)
def _duplicate_batch(x):
return jnp.repeat(x[None, ...], batch_size, axis=0)
def _extend(x, value, extra_len):
extra = jnp.full((extra_len,) + x.shape[1:], value)
return jnp.concatenate([x, extra], axis=0)
init_data_size = data.Y.size
# Need up to depth-1
init_data = GaussianProcessData(
U=_extend(data.U, 0., max_depth - 1),
Y=_extend(data.Y, 0., max_depth - 1),
Y_var=_extend(data.Y_var, jnp.inf, max_depth - 1),
sample_size=_extend(data.sample_size, jnp.inf, max_depth - 1),
)
init_state = MultiLookAheadState(
data=init_data,
depth=jnp.asarray(0)
)
init_state = tree_map(_duplicate_batch, init_state)
prior_logits = jnp.full([batch_size, num_actions], 0.)
# The approximate expected value of sum_i discount^i * R_i
# Zero if more expected improvement.
approximate_value = jnp.full([batch_size], 0.)
root = mctx.RootFnOutput(
prior_logits=prior_logits,
value=approximate_value,
embedding=init_state
)
def recurrent_fn(params, rng_key, action: chex.Array, embedding: MultiLookAheadState) -> Tuple[
mctx.RecurrentFnOutput, MultiLookAheadState]:
del params, rng_key
chex.assert_shape(action, [batch_size])
chex.assert_shape(embedding.depth, [batch_size])
def _single_batch(_action: chex.Array, _state: MultiLookAheadState) -> Tuple[
mctx.RecurrentFnOutput, MultiLookAheadState]:
conditional_predictive_factory = GaussianProcessConditionalPredictiveFactory(data=_state.data)
conditional_predictive = MarginalisedConditionalPredictive(
key=marginalise_key,
ns_results=ns_results,
conditional_predictive_factory=conditional_predictive_factory,
S=S
)
acquisition_factory = ExpectedImprovementAcquisitionFactory(
conditional_predictive_factory=conditional_predictive_factory
)
acquisition = MarginalisedAcquisitionFunction(
key=reward_key,
ns_results=ns_results,
acquisition_factory=acquisition_factory,
S=S
)
# Compute EI(f_i, f_i-1)
u_star = U_options[_action]
reward = acquisition(u_star)
post_mu_s, post_var_s = conditional_predictive(U_star=u_star[None, :]) # should be optimised away
next_Y = jnp.reshape(post_mu_s, ())
next_Y_var = jnp.reshape(post_var_s, ())
replace_idx = init_data_size + _state.depth
next_data = GaussianProcessData(
U=_state.data.U.at[replace_idx, :].set(u_star),
Y=_state.data.Y.at[replace_idx].set(next_Y),
Y_var=_state.data.Y_var.at[replace_idx].set(next_Y_var),
sample_size=_state.data.sample_size # stays the same inf
)
next_depth = _state.depth + 1
next_state = MultiLookAheadState(
data=next_data,
depth=next_depth
)
discount = jnp.asarray(1.)
prior_logits = jnp.full([num_actions], 0.)
approximate_value = reward # jnp.asarray(0.)
recurrent_fn_output = mctx.RecurrentFnOutput(
reward=reward,
discount=discount,
prior_logits=prior_logits,
value=approximate_value
)
return recurrent_fn_output, next_state
return jax.vmap(_single_batch)(action, embedding)
return root, recurrent_fn
[docs]
def static_fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val