Commit fba1b77

refactor: attach logprobs to ResponseOutputText for Responses API consistency
Instead of adding a separate logprobs field to ModelResponse, attach logprobs directly to ResponseOutputText content parts. This makes the chat completions API behavior consistent with the Responses API.

- Add conversion helpers in chatcmpl_helpers.py
- Update streaming to include logprobs in delta events and accumulate them on the output text part
- Attach logprobs to text parts in non-streaming responses
- Add tests for both streaming and non-streaming logprobs
1 parent: 8c28040
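For orientation, a minimal consumer-side sketch (not part of this commit) of where token logprobs now live: they hang off the ResponseOutputText parts inside ModelResponse.output instead of off ModelResponse itself. extract_token_logprobs is a hypothetical helper; it mirrors the access pattern asserted by the new test in tests/test_openai_chatcompletions.py below.

from openai.types.responses import ResponseOutputMessage, ResponseOutputText

from agents.items import ModelResponse


def extract_token_logprobs(response: ModelResponse) -> list[str]:
    """Hypothetical helper: collect token strings from every text part's logprobs."""
    tokens: list[str] = []
    for item in response.output:
        if not isinstance(item, ResponseOutputMessage):
            continue
        for part in item.content:
            if isinstance(part, ResponseOutputText) and part.logprobs:
                tokens.extend(lp.token for lp in part.logprobs)
    return tokens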

File tree

6 files changed: +276 additions, -13 deletions

src/agents/items.py

Lines changed: 0 additions & 7 deletions
@@ -356,13 +356,6 @@ class ModelResponse:
     be passed to `Runner.run`.
     """
 
-    logprobs: list[Any] | None = None
-    """Token log probabilities from the model response.
-    Only populated when using the chat completions API with `top_logprobs` set in ModelSettings.
-    Each element corresponds to a token and contains the token string, log probability, and
-    optionally the top alternative tokens with their log probabilities.
-    """
-
     def to_input_items(self) -> list[TResponseInputItem]:
         """Convert the output into a list of input items suitable for passing to the model."""
         # We happen to know that the shape of the Pydantic output items are the same as the

src/agents/models/chatcmpl_helpers.py

Lines changed: 57 additions & 0 deletions
@@ -3,6 +3,12 @@
 from contextvars import ContextVar
 
 from openai import AsyncOpenAI
+from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
+from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob
+from openai.types.responses.response_text_delta_event import (
+    Logprob as DeltaLogprob,
+    LogprobTopLogprob as DeltaTopLogprob,
+)
 
 from ..model_settings import ModelSettings
 from ..version import __version__
@@ -41,3 +47,54 @@ def get_stream_options_param(
         )
         stream_options = {"include_usage": include_usage} if include_usage is not None else None
         return stream_options
+
+    @classmethod
+    def convert_logprobs_for_output_text(
+        cls, logprobs: list[ChatCompletionTokenLogprob] | None
+    ) -> list[Logprob] | None:
+        if not logprobs:
+            return None
+
+        converted: list[Logprob] = []
+        for token_logprob in logprobs:
+            converted.append(
+                Logprob(
+                    token=token_logprob.token,
+                    logprob=token_logprob.logprob,
+                    bytes=token_logprob.bytes or [],
+                    top_logprobs=[
+                        LogprobTopLogprob(
+                            token=top_logprob.token,
+                            logprob=top_logprob.logprob,
+                            bytes=top_logprob.bytes or [],
+                        )
+                        for top_logprob in token_logprob.top_logprobs
+                    ],
+                )
+            )
+        return converted
+
+    @classmethod
+    def convert_logprobs_for_text_delta(
+        cls, logprobs: list[ChatCompletionTokenLogprob] | None
+    ) -> list[DeltaLogprob] | None:
+        if not logprobs:
+            return None
+
+        converted: list[DeltaLogprob] = []
+        for token_logprob in logprobs:
+            converted.append(
+                DeltaLogprob(
+                    token=token_logprob.token,
+                    logprob=token_logprob.logprob,
+                    top_logprobs=[
+                        DeltaTopLogprob(
+                            token=top_logprob.token,
+                            logprob=top_logprob.logprob,
+                        )
+                        for top_logprob in token_logprob.top_logprobs
+                    ]
+                    or None,
+                )
+            )
+        return converted
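
A hedged usage sketch of the new helper (not from the diff): the input mirrors the ChatCompletionTokenLogprob objects built in the test below, and the import path assumes the package is importable as `agents`.

from openai.types.chat.chat_completion_token_logprob import (
    ChatCompletionTokenLogprob,
    TopLogprob,
)

from agents.models.chatcmpl_helpers import ChatCmplHelpers

raw = [
    ChatCompletionTokenLogprob(
        token="Hi",
        logprob=-0.25,
        bytes=[72, 105],
        top_logprobs=[TopLogprob(token="Hi", logprob=-0.25, bytes=[72, 105])],
    )
]
# Converts chat-completions logprobs into the Responses API Logprob shape.
converted = ChatCmplHelpers.convert_logprobs_for_output_text(raw)
assert converted is not None
assert converted[0].token == "Hi" and converted[0].top_logprobs[0].logprob == -0.25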

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 14 additions & 1 deletion
@@ -42,6 +42,7 @@
 from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 
 from ..items import TResponseStreamEvent
+from .chatcmpl_helpers import ChatCmplHelpers
 from .fake_id import FAKE_RESPONSES_ID
 
 
@@ -105,6 +106,7 @@ async def handle_stream(
                 continue
 
             delta = chunk.choices[0].delta
+            choice_logprobs = chunk.choices[0].logprobs
 
             # Handle thinking blocks from Anthropic (for preserving signatures)
             if hasattr(delta, "thinking_blocks") and delta.thinking_blocks:
@@ -266,6 +268,12 @@ async def handle_stream(
                     type="response.content_part.added",
                     sequence_number=sequence_number.get_and_increment(),
                 )
+                delta_logprobs = ChatCmplHelpers.convert_logprobs_for_text_delta(
+                    choice_logprobs.content if choice_logprobs else None
+                ) or []
+                output_logprobs = ChatCmplHelpers.convert_logprobs_for_output_text(
+                    choice_logprobs.content if choice_logprobs else None
+                )
                 # Emit the delta for this segment of content
                 yield ResponseTextDeltaEvent(
                     content_index=state.text_content_index_and_output[0],
@@ -275,10 +283,15 @@ async def handle_stream(
                     is not None,  # fixed 0 -> 0 or 1
                     type="response.output_text.delta",
                     sequence_number=sequence_number.get_and_increment(),
-                    logprobs=[],
+                    logprobs=delta_logprobs,
                 )
                 # Accumulate the text into the response part
                 state.text_content_index_and_output[1].text += delta.content
+                if output_logprobs:
+                    existing_logprobs = state.text_content_index_and_output[1].logprobs or []
+                    state.text_content_index_and_output[1].logprobs = (
+                        existing_logprobs + output_logprobs
+                    )
 
                 # Handle refusals (model declines to answer)
                 # This is always set by the OpenAI API, but not by others e.g. LiteLLM
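
With the change above, each response.output_text.delta event carries the logprobs for its chunk, and the accumulated ResponseOutputText ends up holding the full sequence. A minimal consumer sketch (hypothetical; `events` is assumed to be the async iterator of stream events produced by handle_stream):

from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent


async def collect_streamed_logprobs(events) -> list[float]:
    """Hypothetical helper: gather per-token logprob values from text delta events."""
    values: list[float] = []
    async for event in events:
        if isinstance(event, ResponseTextDeltaEvent):
            values.extend(lp.logprob for lp in (event.logprobs or []))
    return values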

src/agents/models/openai_chatcompletions.py

Lines changed: 26 additions & 4 deletions
@@ -9,7 +9,13 @@
 from openai.types import ChatModel
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
-from openai.types.responses import Response
+from openai.types.responses import (
+    Response,
+    ResponseOutputItem,
+    ResponseOutputMessage,
+    ResponseOutputText,
+)
+from openai.types.responses.response_output_text import Logprob
 from openai.types.responses.response_prompt_param import ResponsePromptParam
 
 from .. import _debug
@@ -119,17 +125,33 @@ async def get_response(
 
         items = Converter.message_to_output_items(message) if message is not None else []
 
-        logprobs_data = None
+        logprob_models = None
         if first_choice and first_choice.logprobs and first_choice.logprobs.content:
-            logprobs_data = [lp.model_dump() for lp in first_choice.logprobs.content]
+            logprob_models = ChatCmplHelpers.convert_logprobs_for_output_text(
+                first_choice.logprobs.content
+            )
+
+        if logprob_models:
+            self._attach_logprobs_to_output(items, logprob_models)
 
         return ModelResponse(
             output=items,
             usage=usage,
             response_id=None,
-            logprobs=logprobs_data,
         )
 
+    def _attach_logprobs_to_output(
+        self, output_items: list[ResponseOutputItem], logprobs: list[Logprob]
+    ) -> None:
+        for output_item in output_items:
+            if not isinstance(output_item, ResponseOutputMessage):
+                continue
+
+            for content in output_item.content:
+                if isinstance(content, ResponseOutputText):
+                    content.logprobs = logprobs
+                    return
+
     async def stream_response(
         self,
         system_instructions: str | None,
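
Because the converted values are regular Logprob objects attached to the text part, downstream post-processing stays simple. A small hypothetical sketch (not from this commit): turning an attached logprob list into per-token probabilities and a summed sequence log-probability, where `text_part` is a ResponseOutputText obtained as in the test below.

import math

from openai.types.responses import ResponseOutputText


def summarize_logprobs(text_part: ResponseOutputText) -> tuple[list[float], float]:
    """Hypothetical helper: per-token probabilities plus the total log-probability."""
    logprobs = text_part.logprobs or []
    token_probs = [math.exp(lp.logprob) for lp in logprobs]
    total_logprob = sum(lp.logprob for lp in logprobs)
    return token_probs, total_logprob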

tests/test_openai_chatcompletions.py

Lines changed: 64 additions & 1 deletion
@@ -6,13 +6,17 @@
 import httpx
 import pytest
 from openai import AsyncOpenAI, omit
-from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import (  # type: ignore[attr-defined]
     ChatCompletionMessageFunctionToolCall,
     Function,
 )
+from openai.types.chat.chat_completion_token_logprob import (
+    ChatCompletionTokenLogprob,
+    TopLogprob,
+)
 from openai.types.completion_usage import (
     CompletionUsage,
     PromptTokensDetails,
@@ -98,6 +102,65 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert resp.response_id is None
 
 
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_get_response_attaches_logprobs(monkeypatch) -> None:
+    msg = ChatCompletionMessage(role="assistant", content="Hi!")
+    choice = Choice(
+        index=0,
+        finish_reason="stop",
+        message=msg,
+        logprobs=ChoiceLogprobs(
+            content=[
+                ChatCompletionTokenLogprob(
+                    token="Hi",
+                    logprob=-0.5,
+                    bytes=[1],
+                    top_logprobs=[TopLogprob(token="Hi", logprob=-0.5, bytes=[1])],
+                ),
+                ChatCompletionTokenLogprob(
+                    token="!",
+                    logprob=-0.1,
+                    bytes=[2],
+                    top_logprobs=[TopLogprob(token="!", logprob=-0.1, bytes=[2])],
+                ),
+            ]
+        ),
+    )
+    chat = ChatCompletion(
+        id="resp-id",
+        created=0,
+        model="fake",
+        object="chat.completion",
+        choices=[choice],
+        usage=None,
+    )
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        return chat
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    resp: ModelResponse = await model.get_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    )
+    assert len(resp.output) == 1
+    assert isinstance(resp.output[0], ResponseOutputMessage)
+    text_part = resp.output[0].content[0]
+    assert isinstance(text_part, ResponseOutputText)
+    assert text_part.logprobs is not None
+    assert [lp.token for lp in text_part.logprobs] == ["Hi", "!"]
+
+
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_get_response_with_refusal(monkeypatch) -> None:
