Skip to content

Commit ccf3d96

Browse files
authored
feat: implement cost sensitive labels (#55)
1 parent d947e40 commit ccf3d96

File tree

6 files changed

+160
-8
lines changed

6 files changed

+160
-8
lines changed

src/main.cpp

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "vw/core/cache.h"
55
#include "vw/core/cb.h"
66
#include "vw/core/constant.h"
7+
#include "vw/core/cost_sensitive.h"
78
#include "vw/core/example.h"
89
#include "vw/core/global_data.h"
910
#include "vw/core/guard.h"
@@ -645,6 +646,15 @@ struct py_simple_label
645646
float initial;
646647
};
647648

649+
bool is_shared(const VW::cs_label& label)
650+
{
651+
const auto& costs = label.costs;
652+
if (costs.size() != 1) { return false; }
653+
if (costs[0].class_index != 0) { return false; }
654+
if (costs[0].x != -FLT_MAX) { return false; }
655+
return true;
656+
}
657+
648658
} // namespace
649659

650660
PYBIND11_MODULE(_core, m)
@@ -790,6 +800,9 @@ PYBIND11_MODULE(_core, m)
790800
py::kw_only(), py::arg("label") = py::none(), py::arg("weight") = 1.f, py::arg("shared") = false, R"docstring(
791801
A label representing a contextual bandit problem.
792802
803+
.. note::
804+
Currently the label can only contain 1 or 0 cb costs. There is a mode in vw for CB (non-adf) that allows for multiple cb_classes per example, but it is not currently accessible via direct label access. If creating examples/labels from parsed input it should still work as expected. If you need this feature, please open an issue on the github repo.
805+
793806
Args:
794807
label (Optional[Union[Tuple[float, float], Tuple[int, float, float]]): This is (action, cost, probability). The same rules as VW apply for if the action can be left out of the tuple.
795808
weight (float): The weight of the example.
@@ -841,6 +854,63 @@ PYBIND11_MODULE(_core, m)
841854
The weight of the example.
842855
)docstring");
843856

857+
py::class_<VW::cs_label>(m, "CSLabel")
858+
.def(py::init(
859+
[](std::optional<std::vector<std::tuple<float, float>>> costs, bool is_shared)
860+
{
861+
auto label = std::make_unique<VW::cs_label>();
862+
if (is_shared)
863+
{
864+
if (costs.has_value())
865+
{
866+
throw std::invalid_argument("Shared examples cannot have action, cost, or probability.");
867+
}
868+
869+
label->costs.emplace_back(-FLT_MAX, 0, 0.f, 0.f);
870+
return label;
871+
}
872+
873+
if (costs.has_value())
874+
{
875+
for (auto& [class_index, cost] : *costs) { label->costs.emplace_back(cost, class_index, 0.f, 0.f); }
876+
}
877+
878+
return label;
879+
}),
880+
py::kw_only(), py::arg("costs") = py::none(), py::arg("shared") = false, R"docstring(
881+
A label representing a cost sensitive classification problem.
882+
883+
Args:
884+
costs (Optional[List[Tuple[int, float]]]): List of classes and costs. If there is no label, this should be None.
885+
shared (bool): Whether the example represents the shared context
886+
)docstring")
887+
.def_property_readonly(
888+
"shared", [](VW::cs_label& l) -> bool { return is_shared(l); },
889+
R"docstring(
890+
Whether the example represents the shared context.
891+
)docstring")
892+
.def_property(
893+
"costs",
894+
[](VW::cs_label& l) -> std::optional<std::vector<std::tuple<uint32_t, float>>>
895+
{
896+
if (is_shared(l)) { return std::nullopt; }
897+
898+
std::vector<std::tuple<uint32_t, float>> costs;
899+
costs.reserve(l.costs.size());
900+
for (auto& cost : l.costs) { costs.emplace_back(cost.class_index, cost.x); }
901+
return costs;
902+
},
903+
[](VW::cs_label& l, const std::vector<std::tuple<uint32_t, float>>& label)
904+
{
905+
if (is_shared(l)) { throw std::invalid_argument("Shared examples cannot have costs."); }
906+
907+
l.costs.clear();
908+
for (auto& [class_index, cost] : label) { l.costs.emplace_back(cost, class_index, 0.f, 0.f); }
909+
},
910+
R"docstring(
911+
The costs for the example. The format of the costs is (class_index, cost).
912+
)docstring");
913+
844914
py::class_<VW::example, std::shared_ptr<VW::example>>(m, "Example")
845915
.def(py::init(
846916
[]()
@@ -851,7 +921,7 @@ PYBIND11_MODULE(_core, m)
851921
.def("_is_newline", [](VW::example& ex) -> bool { return ex.is_newline; })
852922
.def("_get_label",
853923
[](VW::example& ex, VW::label_type_t label_type)
854-
-> std::variant<py_simple_label, VW::multiclass_label, VW::cb_label, std::monostate>
924+
-> std::variant<py_simple_label, VW::multiclass_label, VW::cb_label, VW::cs_label, std::monostate>
855925
{
856926
switch (label_type)
857927
{
@@ -864,6 +934,8 @@ PYBIND11_MODULE(_core, m)
864934
return ex.l.multi;
865935
case VW::label_type_t::CB:
866936
return ex.l.cb;
937+
case VW::label_type_t::CS:
938+
return ex.l.cs;
867939
case VW::label_type_t::NOLABEL:
868940
return std::monostate();
869941
default:
@@ -872,7 +944,7 @@ PYBIND11_MODULE(_core, m)
872944
})
873945
.def("_set_label",
874946
[](VW::example& ex,
875-
const std::variant<py_simple_label*, VW::multiclass_label*, VW::cb_label*, std::monostate>& label) -> void
947+
const std::variant<py_simple_label*, VW::multiclass_label*, VW::cb_label*, VW::cs_label*, std::monostate>& label) -> void
876948
{
877949
std::visit(
878950
overloaded{
@@ -887,6 +959,7 @@ PYBIND11_MODULE(_core, m)
887959
},
888960
[&](VW::multiclass_label* multiclass_label) { ex.l.multi = *multiclass_label; },
889961
[&](VW::cb_label* cb_label) { ex.l.cb = *cb_label; },
962+
[&](VW::cs_label* cs_label) { ex.l.cs = *cs_label; },
890963
},
891964
label);
892965
});

src/vowpal_wabbit_next/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .delta import ModelDelta, calculate_delta, apply_delta, merge_deltas
88
from .cli_driver import CLIError, run_cli_driver
99
from .prediction_type import PredictionType
10-
from .labels import LabelType, SimpleLabel, MulticlassLabel, CBLabel
10+
from .labels import LabelType, SimpleLabel, MulticlassLabel, CBLabel, CSLabel
1111

1212

1313
VW_COMMIT: str = _vw_commit
@@ -38,4 +38,5 @@
3838
"SimpleLabel",
3939
"MulticlassLabel",
4040
"CBLabel",
41+
"CSLabel",
4142
]

src/vowpal_wabbit_next/_core.pyi

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import typing
44

55
__all__ = [
66
"CBLabel",
7+
"CSLabel",
78
"DenseParameters",
89
"Example",
910
"LabelType",
@@ -20,6 +21,9 @@ class CBLabel():
2021
"""
2122
A label representing a contextual bandit problem.
2223
24+
.. note::
25+
Currently the label can only contain 1 or 0 cb costs. There is a mode in vw for CB (non-adf) that allows for multiple cb_classes per example, but it is not currently accessible via direct label access. If creating examples/labels from parsed input it should still work as expected. If you need this feature, please open an issue on the github repo.
26+
2327
Args:
2428
label (Optional[Union[Tuple[float, float], Tuple[int, float, float]]): This is (action, cost, probability). The same rules as VW apply for if the action can be left out of the tuple.
2529
weight (float): The weight of the example.
@@ -57,13 +61,42 @@ class CBLabel():
5761
The weight of the example.
5862
"""
5963
pass
64+
class CSLabel():
65+
def __init__(self, *, costs: typing.Optional[typing.List[typing.Tuple[float, float]]] = None, shared: bool = False) -> None:
66+
"""
67+
A label representing a cost sensitive classification problem.
68+
69+
Args:
70+
costs (Optional[List[Tuple[int, float]]]): List of classes and costs. If there is no label, this should be None.
71+
shared (bool): Whether the example represents the shared context
72+
"""
73+
@property
74+
def costs(self) -> typing.Optional[typing.List[typing.Tuple[int, float]]]:
75+
"""
76+
The costs for the example. The format of the costs is (class_index, cost).
77+
78+
:type: typing.Optional[typing.List[typing.Tuple[int, float]]]
79+
"""
80+
@costs.setter
81+
def costs(self, arg1: typing.List[typing.Tuple[int, float]]) -> None:
82+
"""
83+
The costs for the example. The format of the costs is (class_index, cost).
84+
"""
85+
@property
86+
def shared(self) -> bool:
87+
"""
88+
Whether the example represents the shared context.
89+
90+
:type: bool
91+
"""
92+
pass
6093
class DenseParameters():
6194
pass
6295
class Example():
6396
def __init__(self) -> None: ...
64-
def _get_label(self, arg0: LabelType) -> typing.Union[SimpleLabel, MulticlassLabel, CBLabel, None]: ...
97+
def _get_label(self, arg0: LabelType) -> typing.Union[SimpleLabel, MulticlassLabel, CBLabel, CSLabel, None]: ...
6598
def _is_newline(self) -> bool: ...
66-
def _set_label(self, arg0: typing.Union[SimpleLabel, MulticlassLabel, CBLabel, None]) -> None: ...
99+
def _set_label(self, arg0: typing.Union[SimpleLabel, MulticlassLabel, CBLabel, CSLabel, None]) -> None: ...
67100
pass
68101
class LabelType():
69102
def __eq__(self, other: object) -> bool: ...

src/vowpal_wabbit_next/example.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
from vowpal_wabbit_next import _core
2-
from vowpal_wabbit_next.labels import LabelType, SimpleLabel, MulticlassLabel, CBLabel
2+
from vowpal_wabbit_next.labels import (
3+
LabelType,
4+
SimpleLabel,
5+
MulticlassLabel,
6+
CBLabel,
7+
CSLabel,
8+
)
39
from typing import Optional, Union
410

511

@@ -27,7 +33,7 @@ def __init__(
2733
self._example = _core.Example()
2834
self.label_type = LabelType.NoLabel
2935

30-
def get_label(self) -> Union[SimpleLabel, MulticlassLabel, CBLabel, None]:
36+
def get_label(self) -> Union[SimpleLabel, MulticlassLabel, CBLabel, CSLabel, None]:
3137
"""Get the label of the example.
3238
3339
Returns:
@@ -36,7 +42,7 @@ def get_label(self) -> Union[SimpleLabel, MulticlassLabel, CBLabel, None]:
3642
return self._example._get_label(self.label_type)
3743

3844
def set_label(
39-
self, label: Union[SimpleLabel, MulticlassLabel, CBLabel, None]
45+
self, label: Union[SimpleLabel, MulticlassLabel, CBLabel, CSLabel, None]
4046
) -> None:
4147
"""Set the label of the example.
4248
@@ -52,6 +58,8 @@ def set_label(
5258
self.label_type = LabelType.Multiclass
5359
elif isinstance(label, CBLabel):
5460
self.label_type = LabelType.CB
61+
elif isinstance(label, CSLabel):
62+
self.label_type = LabelType.CS
5563
elif label is None:
5664
self.label_type = LabelType.NoLabel
5765
else:

src/vowpal_wabbit_next/labels.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
SimpleLabel = _core.SimpleLabel
55
MulticlassLabel = _core.MulticlassLabel
66
CBLabel = _core.CBLabel
7+
CSLabel = _core.CSLabel

tests/test_labels.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,39 @@ def test_cb_label() -> None:
4343

4444
example.set_label(None)
4545
assert example.get_label() is None
46+
47+
48+
def test_cs_label() -> None:
49+
example = vw.Example()
50+
assert example.get_label() is None
51+
example.set_label(vw.CSLabel(costs=[(1, 1.2)]))
52+
assert isinstance(example.get_label(), vw.CSLabel)
53+
assert len(example.get_label().costs) == 1
54+
assert example.get_label().costs[0] == pytest.approx((1, 1.2))
55+
56+
with pytest.raises(ValueError):
57+
example.set_label(vw.CSLabel(costs=[(1, 1.2)], shared=True))
58+
59+
example.set_label(vw.CSLabel(shared=True))
60+
assert isinstance(example.get_label(), vw.CSLabel)
61+
assert example.get_label().costs is None
62+
assert example.get_label().shared == True
63+
64+
example.set_label(None)
65+
assert example.get_label() is None
66+
67+
68+
def test_cs_label_parsed() -> None:
69+
workspace = vw.Workspace(["--csoaa_ldf=mc"])
70+
parser = vw.TextFormatParser(workspace)
71+
shared_ex = parser.parse_line("shared | a b c")
72+
labeled_ex = parser.parse_line("1:0.5 | a b c")
73+
assert isinstance(shared_ex.get_label(), vw.CSLabel)
74+
assert isinstance(labeled_ex.get_label(), vw.CSLabel)
75+
76+
assert shared_ex.get_label().costs is None
77+
assert shared_ex.get_label().shared == True
78+
79+
assert len(labeled_ex.get_label().costs) == 1
80+
assert labeled_ex.get_label().costs[0] == pytest.approx((1, 0.5))
81+
assert labeled_ex.get_label().shared == False

0 commit comments

Comments
 (0)