diff --git a/docs/environment.yml b/docs/environment.yml index 75dcffa..b11c3ec 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -17,4 +17,5 @@ dependencies: - pandas - ipykernel - matplotlib - - myst-nb \ No newline at end of file + - myst-nb + - vw-estimators diff --git a/docs/source/how_to_guides/cb_estimators.ipynb b/docs/source/how_to_guides/cb_estimators.ipynb new file mode 100644 index 0000000..fd8e884 --- /dev/null +++ b/docs/source/how_to_guides/cb_estimators.ipynb @@ -0,0 +1,161 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Using VW-Estimators\n", + "\n", + "`vw-estimators` is a library of off-policy estimators for various problems including contextual bandits. They can be used to evaluate target policies against a logged contextual bandit dataset. This library includes confidence bounds in addition to the estimators. In this example we process a trivial example dataset and feed the results into an IPS estimator and CressieRead confidence interval.\n", + "\n", + "`extract_label` is a function to translate how VW represents the contextual bandit label information into a more familiar form." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Optional, Tuple\n", + "import vowpal_wabbit_next as vw\n", + "from estimators.bandits import ips, cressieread\n", + "\n", + "\n", + "# VW's labels contain extra info, and are associated with each example.\n", + "# This function extracts the logical CB label from the example list.\n", + "# Assumes examples have CBLabel typed labels.\n", + "def extract_label(examples: List[vw.Example]) -> Optional[Tuple[int, float, float]]:\n", + " first_is_shared = len(examples) > 0 and examples[0].get_label().shared\n", + " for i, example in enumerate(examples):\n", + " if (label := example.get_label().label) is not None:\n", + " _, cost, prob = label\n", + " return (i - (1 if first_is_shared else 0), cost, prob)\n", + " return None" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll use the following trivial input for this example. There are two actions, each identified by a single feature. We're using a StringIO so we can treat this as if we were reading it from a file with a {py:class}`vowpal_wabbit_next.TextFormatReader`." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "\n", + "input = io.StringIO(\n", + " \"\"\"shared | s\n", + "0:1:0.5 | a=0\n", + "| a=1\n", + "\n", + "shared | s\n", + "| a=0\n", + "1:0:0.5 | a=1\n", + "\n", + "shared | s\n", + "0:1:0.5 | a=0\n", + "| a=1\n", + "\n", + "shared | s\n", + "| a=0\n", + "1:0:0.5 | a=1\"\"\"\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See comments for an explanation of the process." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Estimate: -0.2625000001862645\n", + "Lower bound: -1.0\n", + "Upper bound: 0.3219763298424875\n" + ] + } + ], + "source": [ + "workspace = vw.Workspace([\"--cb_explore_adf\"])\n", + "estimator = ips.Estimator()\n", + "interval = cressieread.Interval(empirical_r_bounds=True)\n", + "\n", + "estimates = []\n", + "lower = []\n", + "upper = []\n", + "\n", + "with vw.TextFormatReader(workspace, input) as reader:\n", + " for event in reader:\n", + " logged_label = extract_label(event)\n", + "\n", + " # 1. Check if this event is labelled, if not skip it\n", + " if logged_label is None:\n", + " continue\n", + "\n", + " # 2. Predict and learn on the event\n", + " pmf = workspace.predict_then_learn_one(event)\n", + "\n", + " # 3. Extract the logged cost and the probability of choosing it according to the logged policy\n", + " logged_action_0_based, logged_cost, logged_prob = logged_label\n", + "\n", + " # 4. Get the probability of choosing the logged action according to the target policy\n", + " prediction_prob = next(x for i, x in pmf if i == logged_action_0_based)\n", + "\n", + " # 5. Feed these values into the estimator and confidence interval\n", + " # Note: These operate with rewards so we multiply cost by -1 to convert to reward\n", + " estimator.add_example(logged_prob, logged_cost * -1, prediction_prob)\n", + " interval.add_example(logged_prob, logged_cost * -1, prediction_prob)\n", + "\n", + "print(f\"Estimate: {estimator.get()}\")\n", + "bounds = interval.get()\n", + "print(f\"Lower bound: {bounds[0]}\")\n", + "print(f\"Upper bound: {bounds[1]}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pynextdev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "3c562d20ef5aa9e83e98cc981b6703965c3968967e80b00b9de18f40ae75cc1c" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/index.rst b/docs/source/index.rst index 174bad7..2769069 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,7 @@ Installation how_to_guides/save_load_models.ipynb how_to_guides/cache_format.ipynb how_to_guides/inspect_model_weights.ipynb + how_to_guides/cb_estimators.ipynb .. toctree:: :caption: Tutorials