Skip to content

Commit 325bae1

Browse files
authored
Merge pull request strands-agents#31 from mkmeral/fix-scripts
Fix tool calling
2 parents 30e6b1e + f4f7e4d commit 325bae1

File tree

6 files changed

+107
-71
lines changed

6 files changed

+107
-71
lines changed

src/strands/experimental/bidirectional_streaming/agent/agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Any, AsyncIterable, Callable
1919

2020
from .... import _identifier
21+
from ....hooks.registry import HookRegistry
2122
from ....telemetry.metrics import EventLoopMetrics
2223
from ....tools.caller import _ToolCaller
2324
from ....tools.executors import ConcurrentToolExecutor
@@ -122,6 +123,7 @@ def __init__(
122123
# Initialize other components
123124
self.event_loop_metrics = EventLoopMetrics()
124125
self._tool_caller = _ToolCaller(self)
126+
self.hooks = HookRegistry()
125127

126128
# connection management
127129
self._agent_loop: "BidirectionalConnection" | None = None

src/strands/experimental/bidirectional_streaming/scripts/test_bidi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent))
88

9-
from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent
9+
from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent
1010
from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel
1111
from strands.experimental.bidirectional_streaming.io.audio import AudioIO
1212
from strands_tools import calculator
@@ -20,8 +20,8 @@ async def main():
2020
adapter = AudioIO()
2121
model = BidiNovaSonicModel(region="us-east-1")
2222

23-
async with BidirectionalAgent(model=model, tools=[calculator]) as agent:
24-
print("New BidirectionalAgent Experience")
23+
async with BidiAgent(model=model, tools=[calculator]) as agent:
24+
print("New BidiAgent Experience")
2525
print("Try asking: 'What is 25 times 8?' or 'Calculate the square root of 144'")
2626
await agent.run(io_channels=[adapter])
2727

src/strands/experimental/bidirectional_streaming/scripts/test_bidi_novasonic.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pyaudio
1818
from strands_tools import calculator
1919

20-
from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent
20+
from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent
2121
from strands.experimental.bidirectional_streaming.models.novasonic import BidiNovaSonicModel
2222

2323

@@ -130,23 +130,22 @@ async def receive(agent, context):
130130
"""Receive and process events from agent."""
131131
try:
132132
async for event in agent.receive():
133-
# Get event type
134133
event_type = event.get("type", "unknown")
135134

136-
# Handle audio stream events (bidirectional_audio_stream)
137-
if event_type == "bidirectional_audio_stream":
135+
# Handle audio stream events (bidi_audio_stream)
136+
if event_type == "bidi_audio_stream":
138137
if not context.get("interrupted", False):
139138
# Decode base64 audio string to bytes for playback
140139
audio_b64 = event["audio"]
141140
audio_data = base64.b64decode(audio_b64)
142141
context["audio_out"].put_nowait(audio_data)
143142

144-
# Handle interruption events (bidirectional_interruption)
145-
elif event_type == "bidirectional_interruption":
143+
# Handle interruption events (bidi_interruption)
144+
elif event_type == "bidi_interruption":
146145
context["interrupted"] = True
147146

148-
# Handle transcript events (bidirectional_transcript_stream)
149-
elif event_type == "bidirectional_transcript_stream":
147+
# Handle transcript events (bidi_transcript_stream)
148+
elif event_type == "bidi_transcript_stream":
150149
text_content = event.get("text", "")
151150
role = event.get("role", "unknown")
152151

@@ -156,10 +155,29 @@ async def receive(agent, context):
156155
elif role == "assistant":
157156
print(f"Assistant: {text_content}")
158157

159-
# Handle turn complete events (bidirectional_turn_complete)
160-
elif event_type == "bidirectional_turn_complete":
158+
# Handle response complete events (bidi_response_complete)
159+
elif event_type == "bidi_response_complete":
161160
# Reset interrupted state since the response is complete
162161
context["interrupted"] = False
162+
163+
# Handle tool use events (tool_use_stream)
164+
elif event_type == "tool_use_stream":
165+
tool_use = event.get("current_tool_use", {})
166+
tool_name = tool_use.get("name", "unknown")
167+
tool_input = tool_use.get("input", {})
168+
print(f"🔧 Tool called: {tool_name} with input: {tool_input}")
169+
170+
# Handle tool result events (tool_result)
171+
elif event_type == "tool_result":
172+
tool_result = event.get("tool_result", {})
173+
tool_name = tool_result.get("name", "unknown")
174+
result_content = tool_result.get("content", [])
175+
result_text = ""
176+
for block in result_content:
177+
if isinstance(block, dict) and block.get("type") == "text":
178+
result_text = block.get("text", "")
179+
break
180+
print(f"✅ Tool result from {tool_name}: {result_text}")
163181

164182
except asyncio.CancelledError:
165183
pass
@@ -199,7 +217,7 @@ async def main(duration=180):
199217

200218
# Initialize model and agent
201219
model = BidiNovaSonicModel(region="us-east-1")
202-
agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.")
220+
agent = BidiAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.")
203221

204222
await agent.start()
205223

@@ -208,7 +226,7 @@ async def main(duration=180):
208226
"active": True,
209227
"audio_in": asyncio.Queue(),
210228
"audio_out": asyncio.Queue(),
211-
"connection": agent._session,
229+
"connection": agent._agent_loop,
212230
"duration": duration,
213231
"start_time": time.time(),
214232
"interrupted": False,

src/strands/experimental/bidirectional_streaming/scripts/test_bidi_openai.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import pyaudio
1515
from strands_tools import calculator
1616

17-
from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent
17+
from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent
1818
from strands.experimental.bidirectional_streaming.models.openai import BidiOpenAIRealtimeModel
1919

2020

@@ -122,18 +122,18 @@ async def receive(agent, context):
122122
# Get event type
123123
event_type = event.get("type", "unknown")
124124

125-
# Handle audio stream events (bidirectional_audio_stream)
126-
if event_type == "bidirectional_audio_stream":
125+
# Handle audio stream events (bidi_audio_stream)
126+
if event_type == "bidi_audio_stream":
127127
# Decode base64 audio string to bytes for playback
128128
audio_b64 = event["audio"]
129129
audio_data = base64.b64decode(audio_b64)
130130

131131
if not context.get("interrupted", False):
132132
await context["audio_out"].put(audio_data)
133133

134-
# Handle transcript events (bidirectional_transcript_stream)
135-
elif event_type == "bidirectional_transcript_stream":
136-
source = event.get("source", "assistant")
134+
# Handle transcript events (bidi_transcript_stream)
135+
elif event_type == "bidi_transcript_stream":
136+
source = event.get("role", "assistant")
137137
text = event.get("text", "").strip()
138138

139139
if text:
@@ -142,25 +142,44 @@ async def receive(agent, context):
142142
elif source == "assistant":
143143
print(f"🔊 Assistant: {text}")
144144

145-
# Handle interruption events (bidirectional_interruption)
146-
elif event_type == "bidirectional_interruption":
145+
# Handle interruption events (bidi_interruption)
146+
elif event_type == "bidi_interruption":
147147
context["interrupted"] = True
148148
print("⚠️ Interruption detected")
149149

150-
# Handle session start events (bidirectional_session_start)
151-
elif event_type == "bidirectional_session_start":
150+
# Handle connection start events (bidi_connection_start)
151+
elif event_type == "bidi_connection_start":
152152
print(f"✓ Session started: {event.get('model', 'unknown')}")
153153

154-
# Handle session end events (bidirectional_session_end)
155-
elif event_type == "bidirectional_session_end":
154+
# Handle connection close events (bidi_connection_close)
155+
elif event_type == "bidi_connection_close":
156156
print(f"✓ Session ended: {event.get('reason', 'unknown')}")
157157
context["active"] = False
158158
break
159159

160-
# Handle turn complete events (bidirectional_turn_complete)
161-
elif event_type == "bidirectional_turn_complete":
160+
# Handle response complete events (bidi_response_complete)
161+
elif event_type == "bidi_response_complete":
162162
# Reset interrupted state since the response is complete
163163
context["interrupted"] = False
164+
165+
# Handle tool use events (tool_use_stream)
166+
elif event_type == "tool_use_stream":
167+
tool_use = event.get("current_tool_use", {})
168+
tool_name = tool_use.get("name", "unknown")
169+
tool_input = tool_use.get("input", {})
170+
print(f"🔧 Tool called: {tool_name} with input: {tool_input}")
171+
172+
# Handle tool result events (tool_result)
173+
elif event_type == "tool_result":
174+
tool_result = event.get("tool_result", {})
175+
tool_name = tool_result.get("name", "unknown")
176+
result_content = tool_result.get("content", [])
177+
result_text = ""
178+
for block in result_content:
179+
if isinstance(block, dict) and block.get("type") == "text":
180+
result_text = block.get("text", "")
181+
break
182+
print(f"✅ Tool result from {tool_name}: {result_text}")
164183

165184
except asyncio.CancelledError:
166185
pass
@@ -246,7 +265,7 @@ async def main():
246265
)
247266

248267
# Create agent
249-
agent = BidirectionalAgent(
268+
agent = BidiAgent(
250269
model=model,
251270
tools=[calculator],
252271
system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect."

src/strands/experimental/bidirectional_streaming/scripts/test_gemini_live.py

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,11 @@
3737
import pyaudio
3838
from strands_tools import calculator
3939

40-
from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent
40+
from strands.experimental.bidirectional_streaming.agent.agent import BidiAgent
4141
from strands.experimental.bidirectional_streaming.models.gemini_live import BidiGeminiLiveModel
4242

43-
# Configure logging - debug only for Gemini Live, info for everything else
43+
# Configure logging
4444
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
45-
gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live')
46-
gemini_logger.setLevel(logging.DEBUG)
4745
logger = logging.getLogger(__name__)
4846

4947

@@ -145,58 +143,57 @@ async def receive(agent, context):
145143
"""Receive and process events from agent."""
146144
try:
147145
async for event in agent.receive():
148-
# Debug: Log event type and keys
149146
event_type = event.get("type", "unknown")
150-
event_keys = list(event.keys())
151-
logger.debug(f"Received event type: {event_type}, keys: {event_keys}")
152147

153-
# Handle audio stream events (bidirectional_audio_stream)
154-
if event_type == "bidirectional_audio_stream":
148+
# Handle audio stream events (bidi_audio_stream)
149+
if event_type == "bidi_audio_stream":
155150
if not context.get("interrupted", False):
156151
# Decode base64 audio string to bytes for playback
157152
audio_b64 = event["audio"]
158153
audio_data = base64.b64decode(audio_b64)
159154
context["audio_out"].put_nowait(audio_data)
160-
logger.info(f"🔊 Audio queued for playback: {len(audio_data)} bytes")
161155

162-
# Handle interruption events (bidirectional_interruption)
163-
elif event_type == "bidirectional_interruption":
156+
# Handle interruption events (bidi_interruption)
157+
elif event_type == "bidi_interruption":
164158
context["interrupted"] = True
165-
logger.info("Interruption detected")
159+
print("⚠️ Interruption detected")
166160

167-
# Handle transcript events (bidirectional_transcript_stream)
168-
elif event_type == "bidirectional_transcript_stream":
161+
# Handle transcript events (bidi_transcript_stream)
162+
elif event_type == "bidi_transcript_stream":
169163
transcript_text = event.get("text", "")
170-
transcript_source = event.get("source", "unknown")
164+
transcript_role = event.get("role", "unknown")
171165
is_final = event.get("is_final", False)
172166

173167
# Print transcripts with special formatting
174-
if transcript_source == "user":
168+
if transcript_role == "user":
175169
print(f"🎤 User: {transcript_text}")
176-
elif transcript_source == "assistant":
170+
elif transcript_role == "assistant":
177171
print(f"🔊 Assistant: {transcript_text}")
178172

179-
# Handle turn complete events (bidirectional_turn_complete)
180-
elif event_type == "bidirectional_turn_complete":
181-
logger.debug("Turn complete - model ready for next input")
182-
# Reset interrupted state since the turn is complete
173+
# Handle response complete events (bidi_response_complete)
174+
elif event_type == "bidi_response_complete":
175+
# Reset interrupted state since the response is complete
183176
context["interrupted"] = False
184177

185-
# Handle session start events (bidirectional_session_start)
186-
elif event_type == "bidirectional_session_start":
187-
logger.info(f"Session started: {event.get('model', 'unknown')}")
178+
# Handle tool use events (tool_use_stream)
179+
elif event_type == "tool_use_stream":
180+
tool_use = event.get("current_tool_use", {})
181+
tool_name = tool_use.get("name", "unknown")
182+
tool_input = tool_use.get("input", {})
183+
print(f"🔧 Tool called: {tool_name} with input: {tool_input}")
188184

189-
# Handle session end events (bidirectional_session_end)
190-
elif event_type == "bidirectional_session_end":
191-
logger.info(f"Session ended: {event.get('reason', 'unknown')}")
192-
193-
# Handle error events (bidirectional_error)
194-
elif event_type == "bidirectional_error":
195-
logger.error(f"Error: {event.get('error_message', 'unknown')}")
196-
197-
# Handle turn start events (bidirectional_turn_start)
198-
elif event_type == "bidirectional_turn_start":
199-
logger.debug(f"Turn started: {event.get('response_id', 'unknown')}")
185+
# Handle tool result events (tool_result)
186+
elif event_type == "tool_result":
187+
tool_result = event.get("tool_result", {})
188+
tool_name = tool_result.get("name", "unknown")
189+
result_content = tool_result.get("content", [])
190+
# Extract text from content blocks
191+
result_text = ""
192+
for block in result_content:
193+
if isinstance(block, dict) and block.get("type") == "text":
194+
result_text = block.get("text", "")
195+
break
196+
print(f"✅ Tool result from {tool_name}: {result_text}")
200197

201198
except asyncio.CancelledError:
202199
pass
@@ -325,7 +322,7 @@ async def main(duration=180):
325322
logger.info("Gemini Live model initialized successfully")
326323
print("Using Gemini Live model")
327324

328-
agent = BidirectionalAgent(
325+
agent = BidiAgent(
329326
model=model,
330327
tools=[calculator],
331328
system_prompt="You are a helpful assistant."
@@ -338,7 +335,7 @@ async def main(duration=180):
338335
"active": True,
339336
"audio_in": asyncio.Queue(),
340337
"audio_out": asyncio.Queue(),
341-
"connection": agent._session,
338+
"connection": agent._agent_loop,
342339
"duration": duration,
343340
"start_time": time.time(),
344341
"interrupted": False,

src/strands/types/_events.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class ToolUseStreamEvent(ModelStreamEvent):
145145

146146
def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None:
147147
"""Initialize with delta and current tool use state."""
148-
super().__init__({"delta": delta, "current_tool_use": current_tool_use})
148+
super().__init__({"type": "tool_use_stream", "delta": delta, "current_tool_use": current_tool_use})
149149

150150

151151
class TextStreamEvent(ModelStreamEvent):
@@ -281,7 +281,7 @@ def __init__(self, tool_result: ToolResult) -> None:
281281
Args:
282282
tool_result: Final result from the tool execution
283283
"""
284-
super().__init__({"tool_result": tool_result})
284+
super().__init__({"type": "tool_result", "tool_result": tool_result})
285285

286286
@property
287287
def tool_use_id(self) -> str:
@@ -309,7 +309,7 @@ def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None:
309309
tool_use: The tool invocation producing the stream
310310
tool_stream_data: The yielded event from the tool execution
311311
"""
312-
super().__init__({"tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}})
312+
super().__init__({"type": "tool_stream", "tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}})
313313

314314
@property
315315
def tool_use_id(self) -> str:

0 commit comments

Comments
 (0)