33- Docs: https://aws.amazon.com/sagemaker-ai/
44"""
55
6- import base64
76import json
87import logging
9- import mimetypes
108import uuid
119from typing import Any , Iterable , Optional , TypedDict , Union
1210
13- from typing_extensions import Unpack , override
14-
1511import boto3
1612from botocore .config import Config as BotocoreConfig
13+ from typing_extensions import Unpack , override
1714
18- from strands .types .content import Messages , Message , ContentBlock
15+ from strands .types .content import ContentBlock , Message , Messages
1916from strands .types .media import DocumentContent , ImageContent
2017from strands .types .models import Model
21- from strands .types .streaming import StreamEvent , StopReason
18+ from strands .types .streaming import StopReason , StreamEvent
2219from strands .types .tools import ToolSpec
2320
2421logger = logging .getLogger (__name__ )
@@ -35,7 +32,7 @@ class SageMakerAIModel(Model):
3532 - Endpoint not found error handling
3633 - Inference component capacity error handling with automatic retries
3734 """
38-
35+
3936 class ModelConfig (TypedDict , total = False ):
4037 """Configuration options for SageMaker models.
4138
@@ -48,6 +45,7 @@ class ModelConfig(TypedDict, total=False):
4845 temperature: Controls randomness in generation (higher = more random).
4946 top_p: Controls diversity via nucleus sampling (alternative to temperature).
5047 """
48+
5149 additional_args : Optional [dict [str , Any ]]
5250 endpoint_name : str
5351 inference_component_name : Optional [str ]
@@ -64,7 +62,7 @@ def __init__(
6462 boto_session : Optional [boto3 .Session ] = None ,
6563 boto_client_config : Optional [BotocoreConfig ] = None ,
6664 region_name : Optional [str ] = None ,
67- ** model_config : Unpack ["SageMakerAIModel. ModelConfig" ],
65+ ** model_config : Unpack [ModelConfig ],
6866 ):
6967 """Initialize provider instance.
7068
@@ -73,27 +71,105 @@ def __init__(
7371 inference_component_name: The name of the inference component to use.
7472 boto_session: Boto Session to use when calling the SageMaker Runtime.
7573 boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
76- retry_attempts: Number of retry attempts for capacity errors (default: 3).
77- retry_delay: Delay in seconds between retry attempts (default: 30).
74+ region_name: AWS region name to use for the SageMaker Runtime client.
7875 **model_config: Model parameters for the SageMaker request payload.
7976 """
8077 self .config = SageMakerAIModel .ModelConfig (
81- endpoint_name = endpoint_name ,
82- inference_component_name = inference_component_name
78+ endpoint_name = endpoint_name , inference_component_name = inference_component_name
8379 )
8480 self .update_config (** model_config )
8581
86- # logger.debug("endpoint=%s, config=%s | initializing", self.config["endpoint_name"], self.config)
8782 logger .debug ("config=<%s> | initializing" , self .config )
8883
89- session = boto_session or boto3 .Session (
90- region_name = region_name ,
91- )
84+ if boto_session :
85+ session = boto_session
86+ elif region_name :
87+ session = boto3 .Session (region_name = region_name )
88+ else :
89+ session = boto3 .Session ()
90+
9291 self .client = session .client (
9392 service_name = "sagemaker-runtime" ,
9493 config = boto_client_config ,
9594 )
9695
def _format_message(self, message: Message, content: ContentBlock) -> dict[str, Any]:
    """Format a single content block of a message for the SageMaker API.

    The block is dispatched on the first recognised key, checked in this
    order: ``text``, ``image``, ``toolUse``, ``toolResult``. Any other block
    is JSON-serialised as-is into the ``content`` field.

    Args:
        message: The message that owns the content block (its ``role`` is reused).
        content: The content block to format.

    Returns:
        A dict shaped like one chat message of the request payload.

    Raises:
        KeyError: If a recognised block is missing one of the keys read below.
        TypeError: If a value handed to ``json.dumps`` is not JSON serialisable.
    """
    # Imported locally so this method does not depend on a module-level
    # ``import base64`` being present.
    import base64

    def _image_payload(raw: Any) -> Any:
        """Return image data in a form that can be placed in a JSON payload.

        Non-bytes values are returned unchanged. Bytes that are valid UTF-8
        (for example data that is already base64 text) are decoded to ``str``.
        Any other bytes are base64-encoded; previously such input raised
        ``UnicodeDecodeError``, because raw image data is not valid UTF-8.
        """
        if not isinstance(raw, bytes):
            return raw
        try:
            return raw.decode("utf-8")
        except UnicodeDecodeError:
            return base64.b64encode(raw).decode("ascii")

    if "text" in content:
        return {"role": message["role"], "content": content["text"]}

    if "image" in content:
        return {
            "role": message["role"],
            "images": [_image_payload(content["image"]["source"]["bytes"])],
        }

    if "toolUse" in content:
        tool_use = content["toolUse"]
        # Tool calls are always attributed to the assistant, regardless of message["role"].
        return {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": tool_use["toolUseId"],
                    "type": "function",
                    "function": {
                        "name": tool_use["name"],
                        "arguments": json.dumps(tool_use["input"]),
                    },
                }
            ],
        }

    if "toolResult" in content:
        tool_result = content["toolResult"]
        result_content: Union[str, ImageContent, DocumentContent, Any] = None
        result_images = []
        # NOTE(review): result_content is overwritten on every iteration, so only
        # the last item of a multi-item tool result reaches the payload. Kept as-is
        # to preserve the existing payload shape; confirm whether this is intended.
        for item in tool_result["content"]:
            if "text" in item:
                result_content = item["text"]
            elif "json" in item:
                result_content = item["json"]
            elif "image" in item:
                # Image data travels in the separate "images" list; the textual
                # result only points at it.
                result_content = "see images"
                result_images.append(_image_payload(item["image"]["source"]["bytes"]))
            else:
                # Unrecognised item type: pass the whole content list through.
                result_content = tool_result["content"]

        return {
            "role": "tool",
            "name": tool_result["toolUseId"],
            "tool_call_id": tool_result["toolUseId"],
            "content": json.dumps(
                {
                    "result": result_content,
                    "status": tool_result["status"],
                }
            ),
            # Only attach "images" when the tool result actually contained images.
            **({"images": result_images} if result_images else {}),
        }

    # Fallback for block types not handled above.
    return {"role": message["role"], "content": json.dumps(content)}
161+
def _format_messages(self, messages: Messages) -> list[dict[str, Any]]:
    """Flatten a conversation into SageMaker API chat messages.

    Every content block of every message becomes its own formatted entry,
    so a message with several blocks yields several entries. Order follows
    the messages first, then the blocks within each message.

    Args:
        messages: The messages to format.

    Returns:
        One formatted entry per content block, in conversation order.
    """
    formatted: list[dict[str, Any]] = []
    for msg in messages:
        for block in msg["content"]:
            formatted.append(self._format_message(msg, block))
    return formatted
172+
97173 @override
98174 def update_config (self , ** model_config : Unpack [ModelConfig ]) -> None : # type: ignore
99175 """Update the SageMaker AI Model configuration with the provided arguments.
@@ -108,7 +184,7 @@ def get_config(self) -> ModelConfig:
108184 """Get the SageMaker AI Model configuration.
109185
110186 Returns:
111- The Bedrok model configuration.
187+ The SageMaker model configuration.
112188 """
113189 return self .config
114190
@@ -126,73 +202,7 @@ def format_request(
126202 Returns:
127203 An SageMaker AI chat streaming request.
128204 """
129-
130- def format_message (message : Message , content : ContentBlock ) -> dict [str , Any ]:
131- if "text" in content :
132- return {"role" : message ["role" ], "content" : content ["text" ]}
133-
134- if "image" in content :
135- mime_type = mimetypes .types_map .get (f".{ content ['image' ]['format' ]} " , "application/octet-stream" )
136- image_data = base64 .b64encode (content ["image" ]["source" ]["bytes" ]).decode ("utf-8" )
137- return {
138- "image_url" : {
139- "url" : f"data:{ mime_type } ;base64,{ image_data } " ,
140- },
141- "type" : "image_url" ,
142- }
143-
144- if "toolUse" in content :
145- return {
146- "role" : "assistant" ,
147- "tool_calls" : [
148- {
149- 'id' : content ["toolUse" ]["toolUseId" ],
150- 'type' :'function' ,
151- "function" : {
152- "name" : content ["toolUse" ]["name" ],
153- "arguments" : json .dumps (content ["toolUse" ]["input" ]),
154- }
155- }
156- ],
157- }
158-
159- if "toolResult" in content :
160- result_content : Union [str , ImageContent , DocumentContent , Any ] = None
161- result_images = []
162- for toolResultContent in content ["toolResult" ]["content" ]:
163- if "text" in toolResultContent :
164- result_content = toolResultContent ["text" ]
165- elif "json" in toolResultContent :
166- result_content = toolResultContent ["json" ]
167- elif "image" in toolResultContent :
168- result_content = "see images"
169- # Convert bytes to base64 string for JSON serialization
170- image_bytes = toolResultContent ["image" ]["source" ]["bytes" ]
171- if isinstance (image_bytes , bytes ):
172- image_bytes = image_bytes .decode ('utf-8' )
173- result_images .append (image_bytes )
174- else :
175- result_content = content ["toolResult" ]["content" ]
176-
177- return {
178- "role" : "tool" ,
179- "name" : content ["toolResult" ]["toolUseId" ],
180- "tool_call_id" : content ["toolResult" ]["toolUseId" ],
181- "content" : json .dumps (
182- {
183- "result" : result_content ,
184- "status" : content ["toolResult" ]["status" ],
185- }
186- ),
187- ** ({"images" : result_images } if result_images else {}),
188- }
189-
190- return {"role" : message ["role" ], "content" : json .dumps (content )}
191-
192- def format_messages () -> list [dict [str , Any ]]:
193- return [format_message (message , content ) for message in messages for content in message ["content" ]]
194-
195- formatted_messages = format_messages ()
205+ formatted_messages = self ._format_messages (messages )
196206
197207 payload = {
198208 "messages" : [
@@ -227,9 +237,9 @@ def format_messages() -> list[dict[str, Any]]:
227237 try :
228238 if message ["content" ] is None or message ["content" ] == "" :
229239 message ["content" ] = "Thinking ..."
230- if message [' role' ] == ' assistant' and message [' content' ] == ' Thinking...\n ' :
231- continue
232- except :
240+ if message [" role" ] == " assistant" and message [" content" ] == " Thinking...\n " :
241+ continue
242+ except KeyError :
233243 pass
234244 messages_new .append (message )
235245 payload ["messages" ] = messages_new
@@ -241,11 +251,11 @@ def format_messages() -> list[dict[str, Any]]:
241251 "ContentType" : "application/json" ,
242252 "Accept" : "application/json" ,
243253 }
244-
254+
245255 # Add InferenceComponentName if provided
246256 if self .config .get ("inference_component_name" ):
247257 request ["InferenceComponentName" ] = self .config ["inference_component_name" ]
248-
258+
249259 return request
250260
251261 @override
@@ -269,7 +279,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
269279 if event ["data_type" ] == "text" :
270280 return {"contentBlockStart" : {"start" : {}}}
271281 # Random string of 9 alphanumerical characters
272- tool_id = '' .join (uuid .uuid4 ().hex [:9 ])
282+ tool_id = "" .join (uuid .uuid4 ().hex [:9 ])
273283 tool_name = event ["data" ]["function" ]["name" ]
274284 return {"contentBlockStart" : {"start" : {"toolUse" : {"name" : tool_name , "toolUseId" : tool_id }}}}
275285
@@ -306,7 +316,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
306316 },
307317 }
308318 else :
309- raise RuntimeError (f"chunk_type=<{ event ['chunk_type' ]} | unknown type" )
319+ raise RuntimeError (f"chunk_type=<{ event ['chunk_type' ]} > | unknown type" )
310320
311321 @override
312322 def stream (self , request : dict [str , Any ]) -> Iterable [dict [str , Any ]]:
@@ -329,10 +339,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
329339 # Wait until all the answer has been streamed
330340 final_response = ""
331341 for event in response ["Body" ]:
332- chunk_data = event [' PayloadPart' ][ ' Bytes' ].decode ("utf-8" )
342+ chunk_data = event [" PayloadPart" ][ " Bytes" ].decode ("utf-8" )
333343 final_response += chunk_data
334344 final_response_json = json .loads (final_response )
335-
345+
336346 # send messages for tool execution
337347 tool_requested = False
338348 message = final_response_json ["choices" ][0 ]["message" ]
@@ -348,5 +358,8 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
348358
349359 # Close the message
350360 yield {"chunk_type" : "content_stop" , "data_type" : "text" }
351- yield {"chunk_type" : "message_stop" , "data" : "tool_use" if tool_requested else final_response_json ["choices" ][0 ]["finish_reason" ]}
361+ yield {
362+ "chunk_type" : "message_stop" ,
363+ "data" : "tool_use" if tool_requested else final_response_json ["choices" ][0 ]["finish_reason" ],
364+ }
352365 yield {"chunk_type" : "metadata" , "data" : final_response_json ["usage" ]}
0 commit comments