Commit ea6046e

Updated test cases and formatting
1 parent cd9df32 commit ea6046e

3 files changed: +201 -191 lines changed


src/strands/models/sagemaker.py

Lines changed: 107 additions & 94 deletions
@@ -3,22 +3,19 @@
 - Docs: https://aws.amazon.com/sagemaker-ai/
 """
 
-import base64
 import json
 import logging
-import mimetypes
 import uuid
 from typing import Any, Iterable, Optional, TypedDict, Union
 
-from typing_extensions import Unpack, override
-
 import boto3
 from botocore.config import Config as BotocoreConfig
+from typing_extensions import Unpack, override
 
-from strands.types.content import Messages, Message, ContentBlock
+from strands.types.content import ContentBlock, Message, Messages
 from strands.types.media import DocumentContent, ImageContent
 from strands.types.models import Model
-from strands.types.streaming import StreamEvent, StopReason
+from strands.types.streaming import StopReason, StreamEvent
 from strands.types.tools import ToolSpec
 
 logger = logging.getLogger(__name__)
@@ -35,7 +32,7 @@ class SageMakerAIModel(Model):
     - Endpoint not found error handling
     - Inference component capacity error handling with automatic retries
     """
-
+
     class ModelConfig(TypedDict, total=False):
         """Configuration options for SageMaker models.
 
@@ -48,6 +45,7 @@ class ModelConfig(TypedDict, total=False):
             temperature: Controls randomness in generation (higher = more random).
             top_p: Controls diversity via nucleus sampling (alternative to temperature).
         """
+
         additional_args: Optional[dict[str, Any]]
         endpoint_name: str
         inference_component_name: Optional[str]
@@ -64,7 +62,7 @@ def __init__(
         boto_session: Optional[boto3.Session] = None,
         boto_client_config: Optional[BotocoreConfig] = None,
         region_name: Optional[str] = None,
-        **model_config: Unpack["SageMakerAIModel.ModelConfig"],
+        **model_config: Unpack[ModelConfig],
     ):
         """Initialize provider instance.
 
@@ -73,27 +71,105 @@ def __init__(
             inference_component_name: The name of the inference component to use.
             boto_session: Boto Session to use when calling the SageMaker Runtime.
             boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
-            retry_attempts: Number of retry attempts for capacity errors (default: 3).
-            retry_delay: Delay in seconds between retry attempts (default: 30).
+            region_name: AWS region name to use for the SageMaker Runtime client.
             **model_config: Model parameters for the SageMaker request payload.
         """
         self.config = SageMakerAIModel.ModelConfig(
-            endpoint_name=endpoint_name,
-            inference_component_name=inference_component_name
+            endpoint_name=endpoint_name, inference_component_name=inference_component_name
         )
         self.update_config(**model_config)
 
-        # logger.debug("endpoint=%s, config=%s | initializing", self.config["endpoint_name"], self.config)
         logger.debug("config=<%s> | initializing", self.config)
 
-        session = boto_session or boto3.Session(
-            region_name=region_name,
-        )
+        if boto_session:
+            session = boto_session
+        elif region_name:
+            session = boto3.Session(region_name=region_name)
+        else:
+            session = boto3.Session()
+
         self.client = session.client(
             service_name="sagemaker-runtime",
             config=boto_client_config,
         )
 
+    def _format_message(self, message: Message, content: ContentBlock) -> dict[str, Any]:
+        """Format a message content block for SageMaker API.
+
+        Args:
+            message: The message containing the content.
+            content: The content block to format.
+
+        Returns:
+            Formatted message for SageMaker API.
+        """
+        if "text" in content:
+            return {"role": message["role"], "content": content["text"]}
+
+        if "image" in content:
+            # Convert bytes to base64 string for JSON serialization
+            image_bytes = content["image"]["source"]["bytes"]
+            image_bytes = image_bytes.decode("utf-8") if isinstance(image_bytes, bytes) else image_bytes
+            return {"role": message["role"], "images": [image_bytes]}
+
+        if "toolUse" in content:
+            return {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "id": content["toolUse"]["toolUseId"],
+                        "type": "function",
+                        "function": {
+                            "name": content["toolUse"]["name"],
+                            "arguments": json.dumps(content["toolUse"]["input"]),
+                        },
+                    }
+                ],
+            }
+
+        if "toolResult" in content:
+            result_content: Union[str, ImageContent, DocumentContent, Any] = None
+            result_images = []
+            for toolResultContent in content["toolResult"]["content"]:
+                if "text" in toolResultContent:
+                    result_content = toolResultContent["text"]
+                elif "json" in toolResultContent:
+                    result_content = toolResultContent["json"]
+                elif "image" in toolResultContent:
+                    result_content = "see images"
+                    # Convert bytes to base64 string for JSON serialization
+                    image_bytes = toolResultContent["image"]["source"]["bytes"]
+                    image_bytes = image_bytes.decode("utf-8") if isinstance(image_bytes, bytes) else image_bytes
+                    result_images.append(image_bytes)
+                else:
+                    result_content = content["toolResult"]["content"]
+
+            return {
+                "role": "tool",
+                "name": content["toolResult"]["toolUseId"],
+                "tool_call_id": content["toolResult"]["toolUseId"],
+                "content": json.dumps(
+                    {
+                        "result": result_content,
+                        "status": content["toolResult"]["status"],
+                    }
+                ),
+                **({"images": result_images} if result_images else {}),
+            }
+
+        return {"role": message["role"], "content": json.dumps(content)}
+
+    def _format_messages(self, messages: Messages) -> list[dict[str, Any]]:
+        """Format all messages for SageMaker API.
+
+        Args:
+            messages: List of messages to format.
+
+        Returns:
+            List of formatted messages.
+        """
+        return [self._format_message(message, content) for message in messages for content in message["content"]]
+
     @override
     def update_config(self, **model_config: Unpack[ModelConfig]) -> None:  # type: ignore
         """Update the SageMaker AI Model configuration with the provided arguments.
@@ -108,7 +184,7 @@ def get_config(self) -> ModelConfig:
         """Get the SageMaker AI Model configuration.
 
         Returns:
-            The Bedrok model configuration.
+            The SageMaker model configuration.
         """
         return self.config
 
@@ -126,73 +202,7 @@ def format_request(
         Returns:
             An SageMaker AI chat streaming request.
         """
-
-        def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
-            if "text" in content:
-                return {"role": message["role"], "content": content["text"]}
-
-            if "image" in content:
-                mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
-                image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")
-                return {
-                    "image_url": {
-                        "url": f"data:{mime_type};base64,{image_data}",
-                    },
-                    "type": "image_url",
-                }
-
-            if "toolUse" in content:
-                return {
-                    "role": "assistant",
-                    "tool_calls": [
-                        {
-                            'id': content["toolUse"]["toolUseId"],
-                            'type':'function',
-                            "function": {
-                                "name": content["toolUse"]["name"],
-                                "arguments": json.dumps(content["toolUse"]["input"]),
-                            }
-                        }
-                    ],
-                }
-
-            if "toolResult" in content:
-                result_content: Union[str, ImageContent, DocumentContent, Any] = None
-                result_images = []
-                for toolResultContent in content["toolResult"]["content"]:
-                    if "text" in toolResultContent:
-                        result_content = toolResultContent["text"]
-                    elif "json" in toolResultContent:
-                        result_content = toolResultContent["json"]
-                    elif "image" in toolResultContent:
-                        result_content = "see images"
-                        # Convert bytes to base64 string for JSON serialization
-                        image_bytes = toolResultContent["image"]["source"]["bytes"]
-                        if isinstance(image_bytes, bytes):
-                            image_bytes = image_bytes.decode('utf-8')
-                        result_images.append(image_bytes)
-                    else:
-                        result_content = content["toolResult"]["content"]
-
-                return {
-                    "role": "tool",
-                    "name": content["toolResult"]["toolUseId"],
-                    "tool_call_id": content["toolResult"]["toolUseId"],
-                    "content": json.dumps(
-                        {
-                            "result": result_content,
-                            "status": content["toolResult"]["status"],
-                        }
-                    ),
-                    **({"images": result_images} if result_images else {}),
-                }
-
-            return {"role": message["role"], "content": json.dumps(content)}
-
-        def format_messages() -> list[dict[str, Any]]:
-            return [format_message(message, content) for message in messages for content in message["content"]]
-
-        formatted_messages = format_messages()
+        formatted_messages = self._format_messages(messages)
 
         payload = {
             "messages": [
@@ -227,9 +237,9 @@ def format_messages() -> list[dict[str, Any]]:
             try:
                 if message["content"] is None or message["content"] == "":
                     message["content"] = "Thinking ..."
-                if message['role'] == 'assistant' and message['content']=='Thinking...\n':
-                    continue
-            except:
+                if message["role"] == "assistant" and message["content"] == "Thinking...\n":
+                    continue
+            except KeyError:
                 pass
             messages_new.append(message)
         payload["messages"] = messages_new
@@ -241,11 +251,11 @@ def format_messages() -> list[dict[str, Any]]:
             "ContentType": "application/json",
             "Accept": "application/json",
         }
-
+
         # Add InferenceComponentName if provided
         if self.config.get("inference_component_name"):
             request["InferenceComponentName"] = self.config["inference_component_name"]
-
+
         return request
 
     @override
@@ -269,7 +279,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
             if event["data_type"] == "text":
                 return {"contentBlockStart": {"start": {}}}
             # Random string of 9 alphanumerical characters
-            tool_id = ''.join(uuid.uuid4().hex[:9])
+            tool_id = "".join(uuid.uuid4().hex[:9])
             tool_name = event["data"]["function"]["name"]
             return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_id}}}}
 
@@ -306,7 +316,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
                 },
             }
         else:
-            raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
+            raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")
 
     @override
     def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
@@ -329,10 +339,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
         # Wait until all the answer has been streamed
         final_response = ""
         for event in response["Body"]:
-            chunk_data = event['PayloadPart']['Bytes'].decode("utf-8")
+            chunk_data = event["PayloadPart"]["Bytes"].decode("utf-8")
             final_response += chunk_data
         final_response_json = json.loads(final_response)
-
+
         # send messages for tool execution
         tool_requested = False
         message = final_response_json["choices"][0]["message"]
@@ -348,5 +358,8 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
 
         # Close the message
         yield {"chunk_type": "content_stop", "data_type": "text"}
-        yield {"chunk_type": "message_stop", "data": "tool_use" if tool_requested else final_response_json["choices"][0]["finish_reason"]}
+        yield {
+            "chunk_type": "message_stop",
+            "data": "tool_use" if tool_requested else final_response_json["choices"][0]["finish_reason"],
+        }
         yield {"chunk_type": "metadata", "data": final_response_json["usage"]}

tests-integ/test_model_sagemaker.py

Lines changed: 4 additions & 7 deletions
@@ -4,16 +4,13 @@
 from strands import Agent
 from strands.models.sagemaker import SageMakerAIModel
 
-import boto3
+ENDPOINT_NAME = "mistral-small-2501-sm-js"
+REGION_NAME = "us-east-1"
 
 
 @pytest.fixture
-def model(endpoint_name: str):
-    return SageMakerAIModel(
-        endpoint_name=endpoint_name,
-        boto_session=boto3.Session(region_name="us-east-1"),
-        max_tokens=1024
-    )
+def model():
+    return SageMakerAIModel(endpoint_name=ENDPOINT_NAME, region_name=REGION_NAME, max_tokens=1024)
 
 
 @pytest.fixture
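
With the fixture no longer parameterized on endpoint_name, a test can request it directly. The next fixture in this file is truncated from the diff, so the following is only a hypothetical consumer sketched to show the intended usage with the usual strands Agent call style; it is not code from this commit:

    import pytest

    from strands import Agent


    @pytest.fixture
    def agent(model):
        # Hypothetical: wrap the SageMaker-backed model in an Agent.
        return Agent(model=model)


    def test_basic_completion(agent):
        # Hypothetical test; requires AWS credentials and the live endpoint named above.
        result = agent("Reply with exactly one word: hello")
        assert "hello" in str(result).lower()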
