
Commit 8c28040

Parent: 9fcc68f
feat: preserve logprobs from chat completions API in ModelResponse
The SDK already accepts `top_logprobs` in ModelSettings and passes it to the API, but the logprobs returned in the response were discarded during conversion. This change:

1. Adds an optional `logprobs` field to the ModelResponse dataclass
2. Extracts logprobs from `choice.logprobs.content` in the chat completions model and includes them in the ModelResponse

This enables use cases like RLHF training, confidence scoring, and uncertainty estimation that require access to token-level log probabilities.
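For orientation, a minimal sketch (not part of the commit) of how a caller could consume the new field. It assumes the SDK's public `Agent`/`Runner`/`ModelSettings` exports and the `OpenAIChatCompletionsModel` class in the file changed below; the model name and prompt are purely illustrative.

    # Sketch of a consumer, assuming openai-agents' public API; "gpt-4o" and
    # the prompt are illustrative, not taken from the commit.
    import asyncio

    from openai import AsyncOpenAI
    from agents import Agent, ModelSettings, Runner
    from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel

    async def main() -> None:
        agent = Agent(
            name="assistant",
            # Logprobs are only extracted on the chat completions path, so use
            # the chat completions model explicitly rather than the default
            # Responses API path (which leaves the field as None).
            model=OpenAIChatCompletionsModel(model="gpt-4o", openai_client=AsyncOpenAI()),
            model_settings=ModelSettings(top_logprobs=2),  # ask the API for logprobs
        )
        result = await Runner.run(agent, "Say hello.")
        for response in result.raw_responses:  # one ModelResponse per model call
            if response.logprobs is not None:
                for entry in response.logprobs:  # one dict per generated token
                    print(entry["token"], entry["logprob"])

    asyncio.run(main())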

2 files changed: +12 −0

src/agents/items.py — 7 additions & 0 deletions

@@ -356,6 +356,13 @@ class ModelResponse:
     be passed to `Runner.run`.
     """
 
+    logprobs: list[Any] | None = None
+    """Token log probabilities from the model response.
+    Only populated when using the chat completions API with `top_logprobs` set in ModelSettings.
+    Each element corresponds to a token and contains the token string, log probability, and
+    optionally the top alternative tokens with their log probabilities.
+    """
+
     def to_input_items(self) -> list[TResponseInputItem]:
         """Convert the output into a list of input items suitable for passing to the model."""
         # We happen to know that the shape of the Pydantic output items are the same as the
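For reference, each element stored in `logprobs` is the `model_dump()` of an OpenAI `ChatCompletionTokenLogprob` (see the extraction in the next file), so a single entry is a plain dict shaped roughly like this — the values here are invented for illustration:

    # Illustrative shape of one ModelResponse.logprobs element (values made up).
    entry = {
        "token": "Hello",
        "logprob": -0.0261,                  # natural-log probability of this token
        "bytes": [72, 101, 108, 108, 111],   # UTF-8 bytes of the token, when available
        "top_logprobs": [                    # up to `top_logprobs` alternatives per token
            {"token": "Hello", "logprob": -0.0261, "bytes": [72, 101, 108, 108, 111]},
            {"token": "Hi", "logprob": -3.7, "bytes": [72, 105]},
        ],
    }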

src/agents/models/openai_chatcompletions.py — 5 additions & 0 deletions

@@ -119,10 +119,15 @@ async def get_response(
 
         items = Converter.message_to_output_items(message) if message is not None else []
 
+        logprobs_data = None
+        if first_choice and first_choice.logprobs and first_choice.logprobs.content:
+            logprobs_data = [lp.model_dump() for lp in first_choice.logprobs.content]
+
         return ModelResponse(
             output=items,
             usage=usage,
             response_id=None,
+            logprobs=logprobs_data,
         )
 
     async def stream_response(
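Tying back to the confidence-scoring use case named in the commit message, a hypothetical helper (not part of this commit) that reduces the per-token entries to sequence-level scores:

    import math

    def sequence_confidence(logprobs: list[dict]) -> tuple[float, float]:
        """Return (mean token logprob, perplexity) for one response.

        Hypothetical helper, not part of the commit; expects the list stored
        in ModelResponse.logprobs.
        """
        values = [entry["logprob"] for entry in logprobs]
        mean_logprob = sum(values) / len(values)
        return mean_logprob, math.exp(-mean_logprob)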
