from datetime import datetime
from typing import Dict
from uuid import uuid4
from pydantic import BaseModel, Field, validator, conint
from bojaxns.common import FloatValue, ParamValues, UValue
from bojaxns.parameter_space import ParameterSpace
from bojaxns.utils import current_utc, build_example
__all__ = [
'Trial',
'TrialUpdate',
'OptimisationExperiment',
'NewExperimentRequest',
]
[docs]
class TrialUpdate(BaseModel):
[docs]
ref_id: str = Field(
description="An identifier of the measurement, e.g. user UUID.",
example=str(uuid4())
)
[docs]
measurement_dt: datetime = Field(
default_factory=current_utc,
description='The datetime the objective_measurement was determined.',
example=current_utc()
)
[docs]
objective_measurement: float = Field(
description="The measurement of trial objective function.",
example=1.
)
[docs]
class Trial(BaseModel):
[docs]
trial_id: str = Field(
default_factory=lambda: str(uuid4()),
description='UUID for this trial.',
example=str(uuid4())
)
[docs]
create_dt: datetime = Field(
default_factory=current_utc,
description='The datetime the param_value was determined.',
example=current_utc()
)
[docs]
param_values: ParamValues = Field(
description="The parameter mapping for trial.",
example={'price': FloatValue(value=1.)}
)
[docs]
U_value: UValue = Field(
description="The U-space value of parameters.",
example=[0.2]
)
[docs]
trial_updates: Dict[str, TrialUpdate] = Field(
default_factory=dict,
description="The measurement of trial updates.",
example={"124": build_example(TrialUpdate)}
)
[docs]
class OptimisationExperiment(BaseModel):
[docs]
experiment_id: str = Field(
default_factory=lambda: str(uuid4()),
description='UUID for this experiment.',
example=str(uuid4())
)
[docs]
parameter_space: ParameterSpace = Field(
description='The parameter space that defines this experiment.',
example=build_example(ParameterSpace)
)
[docs]
trials: Dict[str, Trial] = Field(
default_factory=dict,
description="The mapping of trials that define the sequence of this experiment.",
example={'12345': build_example(Trial)}
)
@validator('trials', always=True)
[docs]
def ensure_parameters_match_space(cls, value, values):
parameter_space: ParameterSpace = values['parameter_space']
names = list(map(lambda param: param.name, parameter_space.parameters))
for trial_id in value:
trial: Trial = value[trial_id]
_names = list(trial.param_values)
if set(_names) != set(names):
raise ValueError(f"trial param_values {_names} don't match param space {names}.")
return value
[docs]
class NewExperimentRequest(BaseModel):
[docs]
parameter_space: ParameterSpace = Field(
description='The parameter space that defines this experiment.',
example=build_example(ParameterSpace)
)
[docs]
init_explore_size: conint(ge=1)