@@ -46,10 +46,10 @@ class FunctionCall:
         arguments: Arguments to pass to the function
     """
 
-    name: str
-    arguments: Union[str, dict]
+    name: Union[str, dict[Any, Any]]
+    arguments: Union[str, dict[Any, Any]]
 
-    def __init__(self, **kwargs: dict):
+    def __init__(self, **kwargs: dict[str, str]):
         """Initialize function call.
 
         Args:
@@ -81,7 +81,7 @@ def __init__(self, **kwargs: dict):
         """
         self.id = str(kwargs.get("id", ""))
         self.type = "function"
-        self.function = FunctionCall(**kwargs.get("function", {}))
+        self.function = FunctionCall(**kwargs.get("function", {"name": "", "arguments": ""}))
 
 
 class SageMakerAIModel(OpenAIModel):
@@ -238,7 +238,7 @@ def format_request(
             payload["tool_choice"] = "auto"
 
         # TODO: this should be a @override of @classmethod format_request_message
-        for message in payload["messages"]:
+        for message in payload["messages"]:  # type: ignore
             # Assistant message must have either content or tool_calls, but not both
             if message.get("role", "") == "assistant" and message.get("tool_calls", []) != []:
                 _ = message.pop("content")
@@ -272,7 +272,7 @@ def format_request(
 
         # Add additional args if provided
         if self.endpoint_config.get("additional_args"):
-            request.update(self.endpoint_config["additional_args"])
+            request.update(self.endpoint_config["additional_args"].__dict__)
 
         return request
 
@@ -343,6 +343,9 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
                             finish_reason = choice["finish_reason"]
                             break
 
+                        if choice.get("usage", None):
+                            yield {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}
+
             except json.JSONDecodeError:
                 # Continue accumulating content until we have valid JSON
                 continue
@@ -370,13 +373,6 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
             # Message close
             yield {"chunk_type": "message_stop", "data": finish_reason}
 
-            # Return metadata
-            try:
-                if choice.get("usage", None):
-                    yield {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}
-            except Exception:
-                pass
-
         else:
             # Not all SageMaker AI models support streaming!
             response = self.client.invoke_endpoint(**request)
@@ -420,7 +416,8 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any]
                 # Message close
                 yield {"chunk_type": "message_stop", "data": message_stop_reason}
                 # Handle usage metadata
-                yield {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json["usage"])}
+                if final_response_json.get("usage", None):
+                    yield {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))}
         except (
             self.client.exceptions.InternalFailure,
             self.client.exceptions.ServiceUnavailable,
0 commit comments