Commit 952a25a: structured output
Parent: b46da6e
3 files changed: +79 additions, -90 deletions

src/strands/models/gemini.py (19 additions, 62 deletions)
```diff
@@ -40,8 +40,6 @@ class GeminiConfig(TypedDict, total=False):
         params: Additional model parameters (e.g., temperature).
             For a complete list of supported parameters, see
             https://ai.google.dev/api/generate-content#generationconfig.
-        response_schema: TODO
-        response_mime_type: TODO
     """
 
     model_id: Required[str]
@@ -200,6 +198,7 @@ def _format_request_config(
         self,
         tool_specs: Optional[list[ToolSpec]],
         system_prompt: Optional[str],
+        params: Optional[dict[str, Any]],
     ) -> genai.types.GenerateContentConfig:
         """Format Gemini request config.
 
@@ -208,21 +207,23 @@ def _format_request_config(
         Args:
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            params: Additional model parameters (e.g., temperature).
 
         Returns:
             Gemini request config.
         """
         return genai.types.GenerateContentConfig(
             system_instruction=system_prompt,
             tools=self._format_request_tools(tool_specs),
-            **(self.config.get("params") or {}),
+            **(params or {}),
         )
 
     def _format_request(
         self,
         messages: Messages,
-        tool_specs: Optional[list[ToolSpec]] = None,
-        system_prompt: Optional[str] = None,
+        tool_specs: Optional[list[ToolSpec]],
+        system_prompt: Optional[str],
+        params: Optional[dict[str, Any]],
     ) -> dict[str, Any]:
         """Format a Gemini streaming request.
 
@@ -232,12 +233,13 @@ def _format_request(
             messages: List of message objects to be processed by the model.
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
+            params: Additional model parameters (e.g., temperature).
 
         Returns:
             A Gemini streaming request.
         """
         return {
-            "config": self._format_request_config(tool_specs, system_prompt).to_json_dict(),
+            "config": self._format_request_config(tool_specs, system_prompt, params).to_json_dict(),
             "contents": [content.to_json_dict() for content in self._format_request_content(messages)],
             "model": self.config["model_id"],
         }
@@ -356,7 +358,7 @@ async def stream(
         Raises:
             ModelThrottledException: If the request is throttled by Gemini.
         """
-        request = self._format_request(messages, tool_specs, system_prompt)
+        request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
 
         try:
             response = await self.client.aio.models.generate_content_stream(**request)
@@ -409,6 +411,8 @@ async def structured_output(
     ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
         """Get structured output from the model using Gemini's native structured output.
 
+        - Docs: https://ai.google.dev/gemini-api/docs/structured-output
+
         Args:
             output_model: The output model to use for the agent.
             prompt: The prompt messages to use for the agent.
@@ -417,59 +421,12 @@ async def structured_output(
 
         Yields:
             Model events with the last being the structured output.
-
-        Raises:
-            ValueError: If the model doesn't return valid structured output.
-
-        Gemini Structured Output: https://ai.google.dev/gemini-api/docs/structured-output
         """
-        yield {}
-
-        # schema = output_model.model_json_schema() if hasattr(output_model, "model_json_schema") else output_model
-
-        # structured_config = {
-        #     "response_mime_type": "application/json",
-        #     "response_schema": schema,
-        # }
-
-        # if "config" in kwargs:
-        #     structured_config.update(kwargs.pop("config"))
-
-        # logger.debug("Using Gemini's native structured output with schema: %s", output_model.__name__)
-
-        # structured_config.pop("tool_specs", None)
-        # kwargs.pop("tool_specs", None)
-        # async_response = self.stream(
-        #     messages=prompt, tool_specs=None, system_prompt=system_prompt, **structured_config, **kwargs
-        # )
-
-        # accumulated_text = []
-        # stop_reason = None
-
-        # async for event in async_response:
-        #     # Don't yield streaming events, only collect the final result
-        #     if "messageStop" in event and "stopReason" in event["messageStop"]:
-        #         stop_reason = event["messageStop"]["stopReason"]
-
-        #     if "contentBlockDelta" in event:
-        #         delta = event["contentBlockDelta"].get("delta", {})
-        #         if "text" in delta:
-        #             accumulated_text.append(delta["text"])
-
-        # full_response = "".join(accumulated_text)
-
-        # if not full_response.strip():
-        #     logger.error("Empty response from model when generating structured output")
-        #     raise ValueError("Empty response from model when generating structured output")
-
-        # if stop_reason != "end_turn":
-        #     logger.error("Model returned unexpected stop_reason: %s", stop_reason)
-        #     raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "end_turn"')
-
-        # try:
-        #     result = output_model.model_validate_json(full_response)
-        #     yield {"output": result}
-
-        # except Exception as e:
-        #     logger.error("Failed to create output model from JSON response: %s", str(e))
-        #     raise ValueError(f"Failed to create structured output from Gemini response: {str(e)}") from e
+        params = {
+            **(self.config.get("params") or {}),
+            "response_mime_type": "application/json",
+            "response_schema": output_model.model_json_schema(),
+        }
+        request = self._format_request(prompt, None, system_prompt, params)
+        response = await self.client.aio.models.generate_content(**request)
+        yield {"output": output_model.model_validate(response.parsed)}
```

tests/strands/models/test_gemini.py (30 additions, 0 deletions)
```diff
@@ -1,5 +1,6 @@
 import unittest.mock
 
+import pydantic
 import pytest
 from google import genai
 
@@ -38,6 +39,15 @@ def system_prompt():
     return "s1"
 
 
+@pytest.fixture
+def weather_output():
+    class Weather(pydantic.BaseModel):
+        time: str
+        weather: str
+
+    return Weather(time="12:00", weather="sunny")
+
+
 def test__init__model_configs(gemini_client, model_id):
     _ = gemini_client
 
@@ -519,3 +529,23 @@ async def test_stream_response_client_exception(gemini_client, model, messages):
 
     with pytest.raises(genai.errors.ClientError, match="INTERNAL"):
         await anext(model.stream(messages))
+
+
+@pytest.mark.asyncio
+async def test_structured_output(gemini_client, model, messages, model_id, weather_output):
+    gemini_client.aio.models.generate_content.return_value = unittest.mock.Mock(parsed=weather_output.model_dump())
+
+    tru_response = await anext(model.structured_output(type(weather_output), messages))
+    exp_response = {"output": weather_output}
+    assert tru_response == exp_response
+
+    exp_request = {
+        "config": {
+            "tools": [{"function_declarations": []}],
+            "response_mime_type": "application/json",
+            "response_schema": weather_output.model_json_schema(),
+        },
+        "contents": [{"parts": [{"text": "test"}], "role": "user"}],
+        "model": model_id,
+    }
+    gemini_client.aio.models.generate_content.assert_called_with(**exp_request)
```
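The test also pins down the calling convention: `structured_output` is an async generator whose final event holds the validated instance under the `"output"` key. A rough consumption sketch under that assumption; `fetch_weather` and the `Weather` model are illustrative, and the `model`/`messages` setup is elided:

```python
import pydantic


class Weather(pydantic.BaseModel):
    time: str
    weather: str


async def fetch_weather(model, messages):
    # The generator's last (here: only) event carries the parsed model
    # under "output", which is exactly what the unit test asserts on.
    event = await anext(model.structured_output(Weather, messages))
    return event["output"]
```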

tests_integ/models/test_model_gemini.py (30 additions, 28 deletions)
```diff
@@ -145,31 +145,33 @@ def test_agent_invoke_document_input(assistant_agent, letter_pdf):
     assert "shareholder" in text
 
 
-# def test_structured_output(tool_agent, weather):
-#     tru_weather = tool_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
-#     exp_weather = weather
-#     assert tru_weather == exp_weather
-
-
-# @pytest.mark.asyncio
-# async def test_agent_structured_output_async(tool_agent, weather):
-#     tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny")
-#     exp_weather = weather
-#     assert tru_weather == exp_weather
-
-
-# def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow_color):
-#     content = [
-#         {"text": "Is this image red, blue, or yellow?"},
-#         {
-#             "image": {
-#                 "format": "png",
-#                 "source": {
-#                     "bytes": yellow_img,
-#                 },
-#             },
-#         },
-#     ]
-#     tru_color = assistant_agent.structured_output(type(yellow_color), content)
-#     exp_color = yellow_color
-#     assert tru_color == exp_color
+def test_agent_structured_output(assistant_agent, weather):
+    tru_weather = assistant_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")
+    exp_weather = weather
+    assert tru_weather == exp_weather
+
+
+@pytest.mark.asyncio
+async def test_agent_structured_output_async(assistant_agent, weather):
+    tru_weather = await assistant_agent.structured_output_async(
+        type(weather), "The time is 12:00 and the weather is sunny"
+    )
+    exp_weather = weather
+    assert tru_weather == exp_weather
+
+
+def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow_color):
+    content = [
+        {"text": "Is this image red, blue, or yellow?"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    tru_color = assistant_agent.structured_output(type(yellow_color), content)
+    exp_color = yellow_color
+    assert tru_color == exp_color
```
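These integration tests lean on `weather`, `yellow_img`, and `yellow_color` fixtures defined elsewhere in the module. A plausible shape for the two Pydantic fixtures, inferred from the unit-test fixture above (hypothetical, not part of this diff):

```python
import pydantic
import pytest


@pytest.fixture
def weather():
    class Weather(pydantic.BaseModel):
        time: str
        weather: str

    return Weather(time="12:00", weather="sunny")


@pytest.fixture
def yellow_color():
    class Color(pydantic.BaseModel):
        color: str

    return Color(color="yellow")
```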
