22 changes: 10 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions python/torch_tests/test_matcher.py
@@ -178,11 +178,11 @@ def test_par_errors() -> None:
with pytest.raises(RuntimeError, match="Already borrowed"):
fill_next_token_bitmask_par(exec, [(g0, 0), (g1, 1), (g1, 2)], mask)

- with pytest.raises(TypeError, match="cannot be converted"):
+ with pytest.raises(TypeError, match="cannot be cast"):
l = [1, (g1, 0), (g1, 1)]
fill_next_token_bitmask_par(exec, l, mask) # type: ignore

- with pytest.raises(TypeError, match="cannot be converted"):
+ with pytest.raises(TypeError, match="cannot be cast"):
l = [(tokenizer(), 0)]
fill_next_token_bitmask_par(exec, l, mask) # type: ignore

2 changes: 1 addition & 1 deletion python_ext/Cargo.toml
@@ -12,7 +12,7 @@ llguidance = { workspace = true }
toktrie_hf_tokenizers = { workspace = true }
toktrie_tiktoken = { workspace = true }
bytemuck = "1.21.0"
- pyo3 = {version = "0.24.1", features = ["extension-module", "abi3-py39", "anyhow"]}
+ pyo3 = {version = "0.27.1", features = ["extension-module", "abi3-py39", "anyhow"]}
serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.138"
rayon = "1.10.0"
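The rest of this PR tracks the API renames that come with this pyo3 bump: Python::allow_threads becomes Python::detach, Python::with_gil becomes Python::attach, and downcast becomes cast. A minimal sketch of the new spellings, not taken from this repository (the function and its data are hypothetical):

use pyo3::prelude::*;
use pyo3::types::PyDict;

// Hypothetical example; only the pyo3 calls mirror what this PR migrates to.
fn summarize(obj: &Bound<'_, PyAny>) -> PyResult<u64> {
    // cast replaces downcast; on a type mismatch the error reaches Python as a
    // TypeError whose text says "cannot be cast" (see the test change above).
    let dict = obj.cast::<PyDict>()?;
    let n = dict.len() as u64;
    // detach replaces allow_threads: release the interpreter while running
    // pure-Rust work that touches no Python objects.
    Ok(obj.py().detach(|| (1..=n).sum::<u64>()))
}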
2 changes: 1 addition & 1 deletion python_ext/src/llinterpreter.rs
@@ -150,7 +150,7 @@ impl LLInterpreter {

fn compute_mask(&mut self, py: Python<'_>) -> PyResult<(Option<Cow<'_, [u8]>>, String)> {
let r = py
- .allow_threads(|| self.inner.compute_mask())
+ .detach(|| self.inner.compute_mask())
.map_err(val_error)?
.clone();
let mask = if r.is_stop() || r.unconditional_splice().is_some() {
22 changes: 11 additions & 11 deletions python_ext/src/llmatcher.rs
@@ -58,7 +58,7 @@ impl LLExecutor {

let mut mut_refs = vec![];
for ent in interpreters.iter() {
- let tupl = ent.downcast::<PyTuple>()?;
+ let tupl = ent.cast::<PyTuple>()?;
if tupl.len() != 2 {
return Err(PyValueError::new_err("Expecting (LLMatcher, int) tuple"));
}
@@ -87,7 +87,7 @@

use rayon::prelude::*;

- py.allow_threads(|| {
+ py.detach(|| {
self.pool.install(|| {
mut_refs2.into_par_iter().for_each(|(interp, idx)| {
interp.unsafe_compute_mask_ptr_inner(
@@ -115,7 +115,7 @@ impl LLExecutor {

let mut mut_refs = vec![];
for ent in interpreters.iter() {
- let tupl = ent.downcast::<PyTuple>()?;
+ let tupl = ent.cast::<PyTuple>()?;
if tupl.len() != 3 {
return Err(PyValueError::new_err(
"Expecting (LLMatcher, int, List[int]) tuple",
@@ -151,7 +151,7 @@

use rayon::prelude::*;

- py.allow_threads(|| {
+ py.detach(|| {
self.pool.install(|| {
mut_refs2
.into_par_iter()
@@ -247,7 +247,7 @@ fn new_matcher(
let logger = Logger::new(0, std::cmp::max(0, log_level) as u32);
// constructing a grammar can take on the order of 100ms
// for very large grammars, so we drop the GIL here
- let inner = py.allow_threads(|| {
+ let inner = py.detach(|| {
let mut r = fact.create_parser_from_init_ext(
GrammarInit::Serialized(grammar),
logger,
@@ -306,7 +306,7 @@ impl LLMatcher {
py: Python<'_>,
) -> String {
match extract_grammar(grammar) {
- Ok((_, grammar)) => py.allow_threads(|| {
+ Ok((_, grammar)) => py.detach(|| {
GrammarInit::Serialized(grammar)
.validate(
tokenizer.map(|t| t.factory().tok_env().clone()),
@@ … @@
py: Python<'_>,
) -> (bool, Vec<String>) {
match extract_grammar(grammar) {
- Ok((_, grammar)) => py.allow_threads(|| {
+ Ok((_, grammar)) => py.detach(|| {
GrammarInit::Serialized(grammar)
.validate(
tokenizer.map(|t| t.factory().tok_env().clone()),
@@ -428,7 +428,7 @@ impl LLMatcher {
py: Python<'_>,
) -> PyResult<()> {
self.validate_mask_ptr(trg_ptr, trg_bytes)?;
- py.allow_threads(|| self.unsafe_compute_mask_ptr_inner(trg_ptr, trg_bytes));
+ py.detach(|| self.unsafe_compute_mask_ptr_inner(trg_ptr, trg_bytes));
Ok(())
}

@@ … @@
py: Python<'_>,
) -> PyResult<()> {
self.validate_mask_ptr(trg_ptr, trg_bytes)?;
- py.allow_threads(|| {
+ py.detach(|| {
self.unsafe_compute_mask_ptr_inner_with_draft_tokens(trg_ptr, trg_bytes, draft_tokens)
});
Ok(())
}

fn compute_logit_bias(&mut self, py: Python<'_>) -> Cow<'_, [u8]> {
- py.allow_threads(|| {
+ py.detach(|| {
let m = self.compute_mask_or_eos();
let mut res = vec![0u8; m.len()];
m.iter_set_entries(|i| res[i] = 200);
@@ … @@
}

fn compute_bitmask(&mut self, py: Python<'_>) -> Cow<'_, [u8]> {
- py.allow_threads(|| {
+ py.detach(|| {
let m = self.compute_mask_or_eos();
Cow::Owned(bytemuck::cast_slice(m.as_slice()).to_vec())
})
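The parallel paths above keep the same shape after the rename: detach from the interpreter first, then let the rayon pool walk the batch. A rough sketch of that pattern with hypothetical job data (the real code hands each matcher a raw mask pointer instead):

use pyo3::prelude::*;
use rayon::prelude::*;

// Hypothetical stand-in for the per-matcher mask fill; it must be pure Rust
// (no Python objects) so it can run while the thread is detached.
fn fill_one(_job: usize) {}

fn fill_batch(py: Python<'_>, pool: &rayon::ThreadPool, jobs: Vec<usize>) {
    // Detach before entering the pool so the worker threads never wait on the
    // Python interpreter.
    py.detach(|| {
        pool.install(|| {
            jobs.into_par_iter().for_each(fill_one);
        });
    });
}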
2 changes: 1 addition & 1 deletion python_ext/src/py.rs
@@ -320,7 +320,7 @@ impl TokenizerEnv for PyTokenizer {

fn tokenize_bytes(&self, utf8bytes: &[u8]) -> Vec<TokenId> {
self.tok_trie.tokenize_with_greedy_fallback(utf8bytes, |s| {
- Python::with_gil(|py| {
+ Python::attach(|py| {
let r = self.tokenizer_fun.call1(py, (s,)).unwrap();
r.extract::<Vec<TokenId>>(py).unwrap()
})
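Python::attach is the counterpart to detach: it enters the interpreter from plain Rust code, which is what the tokenizer callback above needs. A minimal sketch with a hypothetical stored callable, not taken from the repository:

use pyo3::prelude::*;

// Hypothetical helper: invoke a stored Python callable from a Rust callback.
// Python::attach replaces the older Python::with_gil spelling.
fn call_tokenizer(tokenizer_fun: &Py<PyAny>, s: &str) -> PyResult<Vec<u32>> {
    Python::attach(|py| {
        let r = tokenizer_fun.call1(py, (s,))?;
        r.extract::<Vec<u32>>(py)
    })
}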
8 changes: 4 additions & 4 deletions python_ext/src/pyjson.rs
@@ -67,7 +67,7 @@ impl Serialize for SerializePyObject<'_> {
extract!(String);
extract!(bool);

- if let Ok(x) = self.v.downcast::<PyFloat>() {
+ if let Ok(x) = self.v.cast::<PyFloat>() {
return x.value().serialize(serializer);
}

@@ -78,7 +78,7 @@
return serializer.serialize_unit();
}

- if let Ok(x) = self.v.downcast::<PyDict>() {
+ if let Ok(x) = self.v.cast::<PyDict>() {
let mut map = serializer.serialize_map(Some(x.len()))?;
for (key, value) in x {
if let Ok(key) = key.str() {
@@ -94,15 +94,15 @@
return map.end();
}

- if let Ok(x) = self.v.downcast::<PyList>() {
+ if let Ok(x) = self.v.cast::<PyList>() {
let mut seq = serializer.serialize_seq(Some(x.len()))?;
for element in x {
seq.serialize_element(&SerializePyObject { v: element })?
}
return seq.end();
}

- if let Ok(x) = self.v.downcast::<PyTuple>() {
+ if let Ok(x) = self.v.cast::<PyTuple>() {
let mut seq = serializer.serialize_seq(Some(x.len()))?;
for element in x {
seq.serialize_element(&SerializePyObject { v: element })?
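The pyjson.rs changes are mechanical: the probe-by-type chain keeps its shape and only swaps downcast for cast. Roughly, the dispatch style looks like this (hypothetical function, trimmed to two branches; the real code feeds a serde Serializer instead of building a String):

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};

// Hypothetical sketch of cast-based type probing.
fn describe(v: &Bound<'_, PyAny>) -> String {
    if let Ok(d) = v.cast::<PyDict>() {
        format!("dict with {} entries", d.len())
    } else if let Ok(l) = v.cast::<PyList>() {
        format!("list with {} items", l.len())
    } else {
        "unhandled type".to_string()
    }
}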