Skip to content

Commit 883b4bb

Browse files
committed
Implementation complete
1 parent b728a53 commit 883b4bb

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

src/strands/models/sagemaker.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ class FunctionCall:
4646
arguments: Arguments to pass to the function
4747
"""
4848

49-
name: str
50-
arguments: Union[str, dict]
49+
name: Union[str, dict[Any, Any]]
50+
arguments: Union[str, dict[Any, Any]]
5151

52-
def __init__(self, **kwargs: dict):
52+
def __init__(self, **kwargs: dict[str, str]):
5353
"""Initialize function call.
5454
5555
Args:
@@ -81,7 +81,7 @@ def __init__(self, **kwargs: dict):
8181
"""
8282
self.id = str(kwargs.get("id", ""))
8383
self.type = "function"
84-
self.function = FunctionCall(**kwargs.get("function", {}))
84+
self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""}))
8585

8686

8787
class SageMakerAIModel(OpenAIModel):
@@ -238,7 +238,7 @@ def format_request(
238238
payload["tool_choice"] = "auto"
239239

240240
# TODO: this should be a @override of @classmethod format_request_message
241-
for message in payload["messages"]:
241+
for message in payload["messages"]: # type: ignore
242242
# Assistant message must have either content or tool_calls, but not both
243243
if message.get("role", "") == "assistant" and message.get("tool_calls", []) != []:
244244
_ = message.pop("content")
@@ -272,7 +272,7 @@ def format_request(
272272

273273
# Add additional args if provided
274274
if self.endpoint_config.get("additional_args"):
275-
request.update(self.endpoint_config["additional_args"])
275+
request.update(self.endpoint_config["additional_args"].__dict__)
276276

277277
return request
278278

@@ -343,6 +343,9 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
343343
finish_reason = choice["finish_reason"]
344344
break
345345

346+
if choice.get("usage", None):
347+
yield {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}
348+
346349
except json.JSONDecodeError:
347350
# Continue accumulating content until we have valid JSON
348351
continue
@@ -370,13 +373,6 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
370373
# Message close
371374
yield {"chunk_type": "message_stop", "data": finish_reason}
372375

373-
# Return metadata
374-
try:
375-
if choice.get("usage", None):
376-
yield {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}
377-
except Exception:
378-
pass
379-
380376
else:
381377
# Not all SageMaker AI models support streaming!
382378
response = self.client.invoke_endpoint(**request)
@@ -420,7 +416,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
420416
# Message close
421417
yield {"chunk_type": "message_stop", "data": message_stop_reason}
422418
# Handle usage metadata
423-
yield {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json["usage"])}
419+
if final_response_json.get("usage", None):
420+
yield {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))}
424421
except (
425422
self.client.exceptions.InternalFailure,
426423
self.client.exceptions.ServiceUnavailable,

0 commit comments

Comments (0)