Bandit Optimization

Bandit Optimization

In this notebook, we use boax to run a multi-armed bandit experiment.

import random

from jax import config
from matplotlib import pyplot as plt

config.update("jax_enable_x64", True)
plt.style.use('bmh')

from boax.experiments import bandit
# Simulated click-through probability of each variant, indexed by variant id.
CLICK_RATES = [0.042, 0.03, 0.035, 0.038, 0.045]


def objective(variant):
    """Simulate showing ``variant`` once and report whether it was clicked.

    Draws one uniform sample from ``random.random()`` and compares it against
    ``CLICK_RATES[variant]``.

    Args:
        variant: Index into ``CLICK_RATES``.

    Returns:
        ``1.0`` if the draw fell below the variant's click rate, else ``0.0``.
    """
    # Draw before the lookup so the RNG is consumed exactly once per call.
    draw = random.random()
    click_rate = CLICK_RATES[variant]
    return 1.0 if draw < click_rate else 0.0
# Set up the bandit experiment with a single parameter named 'variant' of
# type 'choice'; its values 0..4 are the indices later passed to objective().
experiment = bandit(
    parameters=[{'name': 'variant', 'type': 'choice', 'values': list(range(5))}],
)
# No previous step and no results exist before the first iteration.
step = None
results = []
for i in range(10_000):
    # Emit one progress dot per thousand iterations.
    if not i % 1_000:
        print('.', end='')

    # Hand the previous step and its results back to the experiment and
    # receive the next step along with the parameterizations to try.
    step, parameterizations = experiment.next(step, results)

    # Score every suggested parameterization with the simulated objective.
    evaluations = [objective(candidate['variant']) for candidate in parameterizations]

    # Pair each parameterization with its evaluation for the next iteration.
    results = [*zip(parameterizations, evaluations)]
..........
# Predicted best: query the experiment for its best parameterization as of the
# last step produced by the optimization loop.
# NOTE(review): the recorded output suggests this returns a
# (parameterization, estimated value) pair -- confirm against the boax docs.
experiment.best(step)
({'variant': 4}, Array(0.04176128, dtype=float32))
# Actual best: derived from CLICK_RATES instead of hard-coding index 4, so the
# comparison stays correct if the simulated rates are edited. The result is a
# ({'variant': index}, rate) pair; max() keeps the first entry on ties.
max(
    (({'variant': variant}, rate) for variant, rate in enumerate(CLICK_RATES)),
    key=lambda pair: pair[1],
)
({'variant': 4}, 0.045)