Skip to content

Commit cd1e112

Browse files
authored
fix(llama.cpp): correctly set grammar triggers (#6432)
* fix(llama.cpp): correctly set grammar triggers

  Signed-off-by: Ettore Di Giacinto <[email protected]>

* Do not enable lazy by default

  Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 81b31b4 commit cd1e112

File tree

3 files changed

+68
-17
lines changed

3 files changed

+68
-17
lines changed

backend/cpp/llama-cpp/grpc-server.cpp

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ static void start_llama_server(server_context& ctx_server) {
9292
ctx_server.queue_tasks.start_loop();
9393
}
9494

95-
json parse_options(bool streaming, const backend::PredictOptions* predict)
95+
json parse_options(bool streaming, const backend::PredictOptions* predict, const server_context& ctx_server)
9696
{
9797

9898
// Create now a json data from the prediction options instead
@@ -147,6 +147,28 @@ json parse_options(bool streaming, const backend::PredictOptions* predict)
147147
// data["n_probs"] = predict->nprobs();
148148
//TODO: images,
149149

150+
// Serialize grammar triggers from server context to JSON array
151+
if (!ctx_server.params_base.sampling.grammar_triggers.empty()) {
152+
json grammar_triggers = json::array();
153+
for (const auto& trigger : ctx_server.params_base.sampling.grammar_triggers) {
154+
json trigger_json;
155+
trigger_json["value"] = trigger.value;
156+
// Always serialize as WORD type since upstream converts WORD to TOKEN internally
157+
trigger_json["type"] = static_cast<int>(COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
158+
grammar_triggers.push_back(trigger_json);
159+
}
160+
data["grammar_triggers"] = grammar_triggers;
161+
}
162+
163+
// Serialize preserved tokens from server context to JSON array
164+
if (!ctx_server.params_base.sampling.preserved_tokens.empty()) {
165+
json preserved_tokens = json::array();
166+
for (const auto& token : ctx_server.params_base.sampling.preserved_tokens) {
167+
preserved_tokens.push_back(common_token_to_piece(ctx_server.ctx, token));
168+
}
169+
data["preserved_tokens"] = preserved_tokens;
170+
}
171+
150172
return data;
151173
}
152174

@@ -207,7 +229,7 @@ static void add_rpc_devices(std::string servers) {
207229
}
208230
}
209231

210-
static void params_parse(const backend::ModelOptions* request,
232+
static void params_parse(server_context& ctx_server, const backend::ModelOptions* request,
211233
common_params & params) {
212234

213235
// this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
@@ -347,14 +369,14 @@ static void params_parse(const backend::ModelOptions* request,
347369
}
348370

349371
if (request->grammartriggers_size() > 0) {
350-
params.sampling.grammar_lazy = true;
372+
//params.sampling.grammar_lazy = true;
373+
// Store grammar trigger words for processing after model is loaded
351374
for (int i = 0; i < request->grammartriggers_size(); i++) {
375+
const auto & word = request->grammartriggers(i).word();
352376
common_grammar_trigger trigger;
353-
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
354-
trigger.value = request->grammartriggers(i).word();
355-
// trigger.at_start = request->grammartriggers(i).at_start();
356-
params.sampling.grammar_triggers.push_back(trigger);
357-
377+
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
378+
trigger.value = word;
379+
params.sampling.grammar_triggers.push_back(std::move(trigger));
358380
}
359381
}
360382
}
@@ -377,7 +399,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
377399
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
378400
// Implement LoadModel RPC
379401
common_params params;
380-
params_parse(request, params);
402+
params_parse(ctx_server, request, params);
381403

382404
common_init();
383405

@@ -396,6 +418,39 @@ class BackendServiceImpl final : public backend::Backend::Service {
396418
return Status::CANCELLED;
397419
}
398420

421+
// Process grammar triggers now that vocab is available
422+
if (!params.sampling.grammar_triggers.empty()) {
423+
std::vector<common_grammar_trigger> processed_triggers;
424+
for (const auto& trigger : params.sampling.grammar_triggers) {
425+
if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
426+
auto ids = common_tokenize(ctx_server.vocab, trigger.value, /* add_special= */ false, /* parse_special= */ true);
427+
if (ids.size() == 1) {
428+
auto token = ids[0];
429+
// Add the token to preserved_tokens if not already present
430+
if (params.sampling.preserved_tokens.find(token) == params.sampling.preserved_tokens.end()) {
431+
params.sampling.preserved_tokens.insert(token);
432+
LOG_INF("Added grammar trigger token to preserved tokens: %d (`%s`)\n", token, trigger.value.c_str());
433+
}
434+
LOG_INF("Grammar trigger token: %d (`%s`)\n", token, trigger.value.c_str());
435+
common_grammar_trigger processed_trigger;
436+
processed_trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
437+
processed_trigger.value = trigger.value;
438+
processed_trigger.token = token;
439+
processed_triggers.push_back(std::move(processed_trigger));
440+
} else {
441+
LOG_INF("Grammar trigger word: `%s`\n", trigger.value.c_str());
442+
processed_triggers.push_back(trigger);
443+
}
444+
} else {
445+
processed_triggers.push_back(trigger);
446+
}
447+
}
448+
// Update the grammar triggers in params_base
449+
ctx_server.params_base.sampling.grammar_triggers = std::move(processed_triggers);
450+
// Also update preserved_tokens in params_base
451+
ctx_server.params_base.sampling.preserved_tokens = params.sampling.preserved_tokens;
452+
}
453+
399454
//ctx_server.init();
400455
result->set_message("Loading succeeded");
401456
result->set_success(true);
@@ -406,7 +461,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
406461
}
407462

408463
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
409-
json data = parse_options(true, request);
464+
json data = parse_options(true, request, ctx_server);
410465

411466

412467
//Raise error if embeddings is set to true
@@ -556,7 +611,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
556611
}
557612

558613
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
559-
json data = parse_options(true, request);
614+
json data = parse_options(true, request, ctx_server);
560615

561616
data["stream"] = false;
562617
//Raise error if embeddings is set to true
@@ -691,7 +746,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
691746

692747
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) {
693748

694-
json body = parse_options(false, request);
749+
json body = parse_options(false, request, ctx_server);
695750

696751
body["stream"] = false;
697752

@@ -872,7 +927,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
872927
}
873928

874929
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
875-
json body = parse_options(false, request);
930+
json body = parse_options(false, request, ctx_server);
876931
body["stream"] = false;
877932

878933
json tokens_response = json::array();

core/backend/options.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
129129
triggers = append(triggers, &pb.GrammarTrigger{
130130
Word: t.Word,
131131
})
132-
133132
}
134133

135134
return &pb.ModelOptions{

core/http/views/index.html

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,6 @@ <h2 class="text-3xl md:text-4xl font-bold text-[#E5E7EB] mb-2">
147147
<div class="flex-1 min-w-0">
148148
<div class="flex items-center justify-between">
149149
<h3 class="font-bold text-xl text-[#E5E7EB] truncate group-hover:text-[#38BDF8] transition-colors">{{.Name}}</h3>
150-
<a href="browse?term={{.Name}}" class="text-[#94A3B8] hover:text-[#38BDF8] transition-colors p-1 rounded-lg hover:bg-[#38BDF8]/10" title="Search for similar models">
151-
<i class="fas fa-search text-sm"></i>
152-
</a>
153150
</div>
154151

155152
<div class="mt-2 flex flex-wrap gap-2">

0 commit comments

Comments (0)