Bandit Optimization
In this notebook we use boax to run a multi-armed bandit experiment.
import random
from jax import config
from matplotlib import pyplot as plt
config.update("jax_enable_x64", True)
plt.style.use('bmh')
from boax.experiments import bandit
# Simulated click-through rate of each variant; the list index is the variant id.
CLICK_RATES = [0.042, 0.03, 0.035, 0.038, 0.045]


def objective(variant):
    """Simulate one Bernoulli trial for the given variant.

    Args:
        variant: Index into ``CLICK_RATES`` selecting which rate to sample.

    Returns:
        ``1.0`` if a uniform draw from ``random.random()`` falls below the
        variant's click rate, otherwise ``0.0``.
    """
    # Draw first, then look up the rate, so the RNG is consumed in the same
    # order regardless of which variant is requested.
    draw = random.random()
    rate = CLICK_RATES[variant]
    return 1.0 if draw < rate else 0.0
# Set up a bandit experiment with a single 'choice' parameter, `variant`.
# The choices are the indices of CLICK_RATES (one arm per simulated rate);
# deriving them from the list keeps the arms and the rates in sync when a
# variant is added or removed.
experiment = bandit(
    parameters=[
        {
            'name': 'variant',
            'type': 'choice',
            'values': list(range(len(CLICK_RATES))),
        },
    ],
)
# Begin with no prior step and no observations; both are passed back into
# `experiment.next` on every round.
step, results = None, []

for i in range(10_000):
    # Emit one progress dot every thousand rounds.
    if not i % 1_000:
        print('.', end='')

    # Feed the previous round's results back in and get the next
    # parameterizations to try.
    step, parameterizations = experiment.next(step, results)

    # Score every proposed parameterization with the simulated objective.
    evaluations = [objective(p['variant']) for p in parameterizations]

    # Pair each parameterization with its evaluation for the next round.
    results = list(zip(parameterizations, evaluations))
..........
# Predicted best: ask the experiment for its best parameterization given the
# final `step` produced by the loop above.
# NOTE(review): the second element of the result looks like the estimated
# click rate for that variant — confirm against the boax documentation.
experiment.best(step)
({'variant': 4}, Array(0.04176128, dtype=float32))
# Actual best: look up the variant with the highest true click rate rather
# than hard-coding its index, so this cell stays correct if CLICK_RATES changes.
best_variant = max(range(len(CLICK_RATES)), key=CLICK_RATES.__getitem__)
{'variant': best_variant}, CLICK_RATES[best_variant]
({'variant': 4}, 0.045)