Source code for bojaxns.experiment

from datetime import datetime
from typing import Dict
from uuid import uuid4

from pydantic import BaseModel, Field, validator, conint

from bojaxns.common import FloatValue, ParamValues, UValue
from bojaxns.parameter_space import ParameterSpace
from bojaxns.utils import current_utc, build_example

# Names exported by ``from bojaxns.experiment import *``.
__all__ = [
    "Trial",
    "TrialUpdate",
    "OptimisationExperiment",
    "NewExperimentRequest",
]


class TrialUpdate(BaseModel):
    """One objective measurement recorded against a trial.

    Fields:
        ref_id: identifier of the measurement (required).
        measurement_dt: when the measurement was determined; defaults to
            ``current_utc()`` at construction time.
        objective_measurement: the observed value of the objective function.
    """

    # Required: no default is supplied to Field.
    ref_id: str = Field(
        example=str(uuid4()),
        description="An identifier of the measurement, e.g. user UUID.",
    )

    # default_factory is evaluated per instance, so each update gets its own timestamp.
    measurement_dt: datetime = Field(
        example=current_utc(),
        description='The datetime the objective_measurement was determined.',
        default_factory=current_utc,
    )

    # Required: no default is supplied to Field.
    objective_measurement: float = Field(
        example=1.0,
        description="The measurement of trial objective function.",
    )
class Trial(BaseModel):
    """A point in parameter space together with the measurements taken at it.

    Fields:
        trial_id: unique id; a fresh UUID4 string by default.
        create_dt: creation timestamp; defaults to ``current_utc()``.
        param_values: the parameter mapping for this trial (required).
        U_value: the U-space representation of the parameters (required).
        trial_updates: measurements keyed by string id; empty by default.
    """

    # A new UUID string is generated per instance when none is given.
    trial_id: str = Field(
        example=str(uuid4()),
        description='UUID for this trial.',
        default_factory=lambda: str(uuid4()),
    )

    create_dt: datetime = Field(
        example=current_utc(),
        description='The datetime the param_value was determined.',
        default_factory=current_utc,
    )

    # Required: no default is supplied to Field.
    param_values: ParamValues = Field(
        example={'price': FloatValue(value=1.0)},
        description="The parameter mapping for trial.",
    )

    # Required: no default is supplied to Field.
    U_value: UValue = Field(
        example=[0.2],
        description="The U-space value of parameters.",
    )

    # default_factory=dict gives each instance its own empty mapping.
    trial_updates: Dict[str, TrialUpdate] = Field(
        example={"124": build_example(TrialUpdate)},
        description="The measurement of trial updates.",
        default_factory=dict,
    )
class OptimisationExperiment(BaseModel):
    """An optimisation experiment: a parameter space and the trials run in it.

    Fields:
        experiment_id: unique id; a fresh UUID4 string by default.
        parameter_space: the space every trial's parameters must come from.
        trials: trials keyed by string id; empty by default.

    Validation:
        Every trial's ``param_values`` keys must be exactly the set of
        parameter names declared in ``parameter_space``.
    """

    experiment_id: str = Field(
        default_factory=lambda: str(uuid4()),
        description='UUID for this experiment.',
        example=str(uuid4())
    )
    parameter_space: ParameterSpace = Field(
        description='The parameter space that defines this experiment.',
        example=build_example(ParameterSpace)
    )
    trials: Dict[str, Trial] = Field(
        default_factory=dict,
        description="The mapping of trials that define the sequence of this experiment.",
        example={'12345': build_example(Trial)}
    )

    @validator('trials', always=True)
    def ensure_parameters_match_space(cls, value, values):
        """Check that each trial uses exactly the parameters of the space.

        Args:
            value: the ``trials`` mapping being validated.
            values: previously validated fields of this model.

        Returns:
            ``value`` unchanged.

        Raises:
            ValueError: if a trial's parameter names differ from the
                parameter space's names.
        """
        parameter_space = values.get('parameter_space')
        if parameter_space is None:
            # parameter_space failed its own validation, so pydantic has already
            # recorded that error and omitted the field from ``values``. Indexing
            # it would raise KeyError, which pydantic v1 does not convert into a
            # ValidationError, so skip the cross-check instead.
            return value
        names = [param.name for param in parameter_space.parameters]
        # Build the comparison set once rather than once per trial.
        expected_names = set(names)
        for trial in value.values():
            _names = list(trial.param_values)
            if set(_names) != expected_names:
                raise ValueError(f"trial param_values {_names} don't match param space {names}.")
        return value
class NewExperimentRequest(BaseModel):
    """A request for a new experiment: its parameter space and initial exploration size."""

    # Required: no default is supplied to Field.
    parameter_space: ParameterSpace = Field(
        example=build_example(ParameterSpace),
        description='The parameter space that defines this experiment.',
    )

    # Required integer, constrained by conint(ge=1) to be at least 1.
    # Presumably the number of initial exploratory trials — TODO confirm.
    init_explore_size: conint(ge=1)