Skip to content

Commit 7d95b8b

Browse files
committed
fix: use n_gen from context
1 parent 5047919 commit 7d95b8b

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

transformers/llm/engine/tools/llm_bench.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ struct TestInstance {
154154
bool useMmap;
155155
int nPrompt;
156156
int nGenerate;
157+
std::vector<int64_t> nGenerates;
157158
std::vector<int64_t> prefillUs;
158159
std::vector<int64_t> decodeUs;
159160
std::vector<int64_t> samplesUs;
@@ -184,6 +185,14 @@ struct TestInstance {
184185
return ts;
185186
}
186187

188+
/// Per-sample throughput in tokens per second.
///
/// Pairs n_tokens[i] (tokens produced in sample i) with cost_us[i] (elapsed
/// time of sample i, in microseconds) and returns 1e6 * tokens / microseconds
/// for each pair.
///
/// @param n_tokens token count of each sample.
/// @param cost_us  elapsed microseconds of each sample.
/// @return one rate per complete (tokens, time) pair; if the inputs differ in
///         length, only the common prefix is evaluated.
///
/// Declared static because it reads no instance state; calls through an
/// object (`t.getTokensPerSecond(...)`) keep working unchanged.
static std::vector<double> getTokensPerSecond(const std::vector<int64_t>& n_tokens,
                                              const std::vector<int64_t>& cost_us) {
    // Never index past the shorter input: a length mismatch must not read
    // out of bounds.
    const std::size_t count = n_tokens.size() < cost_us.size() ? n_tokens.size() : cost_us.size();
    std::vector<double> ts(count);
    for (std::size_t i = 0; i < count; ++i) {
        // A non-positive duration carries no rate information; report 0 rather
        // than inf/NaN so the average and stdev computed from this vector stay
        // finite.
        if (cost_us[i] <= 0) {
            ts[i] = 0.0;
            continue;
        }
        ts[i] = 1e6 * static_cast<double>(n_tokens[i]) / static_cast<double>(cost_us[i]);
    }
    return ts;
}
195+
187196
double getAvgUs(std::vector<double> v) const { return ::avg(v); }
188197
double getStdevUs(std::vector<double> v) const { return ::stdev(v); }
189198
enum fieldType { STRING, BOOL, INT, FLOAT };
@@ -354,7 +363,7 @@ struct markdownPrinter : public Printer {
354363
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.getAvgUs(spd), t.getStdevUs(spd));
355364
value = buf;
356365
} else if (field == "speed(tok/s)") {
357-
auto decode_speed = t.getTokensPerSecond(t.nGenerate, t.decodeUs);
366+
auto decode_speed = t.getTokensPerSecond(t.nGenerates, t.decodeUs);
358367
auto prefill_speed = t.getTokensPerSecond(t.nPrompt, t.prefillUs);
359368
snprintf(buf, sizeof(buf), "%.2f ± %.2f<br>%.2f ± %.2f", t.getAvgUs(prefill_speed), t.getStdevUs(prefill_speed), t.getAvgUs(decode_speed), t.getStdevUs(decode_speed));
360369
value = buf;
@@ -899,6 +908,11 @@ int main(int argc, char ** argv) {
899908
if (i > 0) { // Exclude the first performance value.
900909
t.prefillUs.push_back(prefillTime);
901910
t.decodeUs.push_back(decodeTime);
911+
if (llm->stoped()) {
912+
t.nGenerates.push_back(context->gen_seq_len - 1);
913+
} else {
914+
t.nGenerates.push_back(context->gen_seq_len);
915+
}
902916
}
903917
}
904918
if (printHeader) {

0 commit comments

Comments
 (0)