Skip to content

Commit b4e91db

Browse files
committed
multi agent input (strands-agents#1196)
1 parent 783b988 commit b4e91db

File tree

7 files changed

+72
-21
lines changed

7 files changed

+72
-21
lines changed

src/strands/multiagent/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
from .._async import run_async
1414
from ..agent import AgentResult
15-
from ..types.content import ContentBlock
1615
from ..types.event_loop import Metrics, Usage
16+
from ..types.multiagent import MultiAgentInput
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -173,7 +173,7 @@ class MultiAgentBase(ABC):
173173

174174
@abstractmethod
175175
async def invoke_async(
176-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
176+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
177177
) -> MultiAgentResult:
178178
"""Invoke asynchronously.
179179
@@ -186,7 +186,7 @@ async def invoke_async(
186186
raise NotImplementedError("invoke_async not implemented")
187187

188188
async def stream_async(
189-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
189+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
190190
) -> AsyncIterator[dict[str, Any]]:
191191
"""Stream events during multi-agent execution.
192192
@@ -211,7 +211,7 @@ async def stream_async(
211211
yield {"result": result}
212212

213213
def __call__(
214-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
214+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
215215
) -> MultiAgentResult:
216216
"""Invoke synchronously.
217217

src/strands/multiagent/graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from ..types.content import ContentBlock, Messages
4848
from ..types.event_loop import Metrics, Usage
49+
from ..types.multiagent import MultiAgentInput
4950
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
5051

5152
logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class GraphState:
6869
"""
6970

7071
# Task (with default empty string)
71-
task: str | list[ContentBlock] = ""
72+
task: MultiAgentInput = ""
7273

7374
# Execution state
7475
status: Status = Status.PENDING
@@ -457,7 +458,7 @@ def __init__(
457458
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
458459

459460
def __call__(
460-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
461+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
461462
) -> GraphResult:
462463
"""Invoke the graph synchronously.
463464
@@ -473,7 +474,7 @@ def __call__(
473474
return run_async(lambda: self.invoke_async(task, invocation_state))
474475

475476
async def invoke_async(
476-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
477+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
477478
) -> GraphResult:
478479
"""Invoke the graph asynchronously.
479480
@@ -497,7 +498,7 @@ async def invoke_async(
497498
return cast(GraphResult, final_event["result"])
498499

499500
async def stream_async(
500-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
501+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
501502
) -> AsyncIterator[dict[str, Any]]:
502503
"""Stream events during graph execution.
503504

src/strands/multiagent/swarm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from ..types.content import ContentBlock, Messages
4848
from ..types.event_loop import Metrics, Usage
49+
from ..types.multiagent import MultiAgentInput
4950
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
5051

5152
logger = logging.getLogger(__name__)
@@ -146,7 +147,7 @@ class SwarmState:
146147
"""Current state of swarm execution."""
147148

148149
current_node: SwarmNode | None # The agent currently executing
149-
task: str | list[ContentBlock] # The original task from the user that is being executed
150+
task: MultiAgentInput # The original task from the user that is being executed
150151
completion_status: Status = Status.PENDING # Current swarm execution status
151152
shared_context: SharedContext = field(default_factory=SharedContext) # Context shared between agents
152153
node_history: list[SwarmNode] = field(default_factory=list) # Complete history of agents that have executed
@@ -278,7 +279,7 @@ def __init__(
278279
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
279280

280281
def __call__(
281-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
282+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
282283
) -> SwarmResult:
283284
"""Invoke the swarm synchronously.
284285
@@ -293,7 +294,7 @@ def __call__(
293294
return run_async(lambda: self.invoke_async(task, invocation_state))
294295

295296
async def invoke_async(
296-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
297+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
297298
) -> SwarmResult:
298299
"""Invoke the swarm asynchronously.
299300
@@ -317,7 +318,7 @@ async def invoke_async(
317318
return cast(SwarmResult, final_event["result"])
318319

319320
async def stream_async(
320-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
321+
self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any
321322
) -> AsyncIterator[dict[str, Any]]:
322323
"""Stream events during swarm execution.
323324
@@ -756,7 +757,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
756757
)
757758

758759
async def _execute_node(
759-
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
760+
self, node: SwarmNode, task: MultiAgentInput, invocation_state: dict[str, Any]
760761
) -> AsyncIterator[Any]:
761762
"""Execute swarm node and yield TypedEvent objects."""
762763
start_time = time.time()

src/strands/telemetry/tracer.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88
import logging
99
import os
1010
from datetime import date, datetime, timezone
11-
from typing import Any, Dict, Mapping, Optional
11+
from typing import Any, Dict, Mapping, Optional, cast
1212

1313
import opentelemetry.trace as trace_api
1414
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
1515
from opentelemetry.trace import Span, StatusCode
1616

1717
from ..agent.agent_result import AgentResult
1818
from ..types.content import ContentBlock, Message, Messages
19+
from ..types.interrupt import InterruptResponseContent
20+
from ..types.multiagent import MultiAgentInput
1921
from ..types.streaming import Metrics, StopReason, Usage
2022
from ..types.tools import ToolResult, ToolUse
2123
from ..types.traces import Attributes, AttributeValue
@@ -675,7 +677,7 @@ def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any]
675677

676678
def start_multiagent_span(
677679
self,
678-
task: str | list[ContentBlock],
680+
task: MultiAgentInput,
679681
instance: str,
680682
) -> Span:
681683
"""Start a new span for swarm invocation."""
@@ -789,12 +791,23 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None:
789791
{"content": serialize(message["content"])},
790792
)
791793

792-
def _map_content_blocks_to_otel_parts(self, content_blocks: list[ContentBlock]) -> list[dict[str, Any]]:
793-
"""Map ContentBlock objects to OpenTelemetry parts format."""
794+
def _map_content_blocks_to_otel_parts(
795+
self, content_blocks: list[ContentBlock] | list[InterruptResponseContent]
796+
) -> list[dict[str, Any]]:
797+
"""Map content blocks to OpenTelemetry parts format."""
794798
parts: list[dict[str, Any]] = []
795799

796-
for block in content_blocks:
797-
if "text" in block:
800+
for block in cast(list[dict[str, Any]], content_blocks):
801+
if "interruptResponse" in block:
802+
interrupt_response = block["interruptResponse"]
803+
parts.append(
804+
{
805+
"type": "interrupt_response",
806+
"id": interrupt_response["interruptId"],
807+
"response": interrupt_response["response"],
808+
},
809+
)
810+
elif "text" in block:
798811
# Standard TextPart
799812
parts.append({"type": "text", "content": block["text"]})
800813
elif "toolUse" in block:

src/strands/types/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
from typing import TypeAlias
77

88
from .content import ContentBlock, Messages
9-
from .interrupt import InterruptResponse
9+
from .interrupt import InterruptResponseContent
1010

11-
AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponse] | Messages | None
11+
AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None

src/strands/types/multiagent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Multi-agent related type definitions for the SDK."""
2+
3+
from typing import TypeAlias
4+
5+
from .content import ContentBlock
6+
7+
MultiAgentInput: TypeAlias = str | list[ContentBlock]

tests/strands/telemetry/test_tracer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize
1313
from strands.types.content import ContentBlock
14+
from strands.types.interrupt import InterruptResponseContent
1415
from strands.types.streaming import Metrics, StopReason, Usage
1516

1617

@@ -396,6 +397,34 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer):
396397
assert span is not None
397398

398399

400+
@pytest.mark.parametrize(
401+
"task, expected_parts",
402+
[
403+
([ContentBlock(text="Test message")], [{"type": "text", "content": "Test message"}]),
404+
(
405+
[InterruptResponseContent(interruptResponse={"interruptId": "test-id", "response": "approved"})],
406+
[{"type": "interrupt_response", "id": "test-id", "response": "approved"}],
407+
),
408+
],
409+
)
410+
def test_start_multiagent_span_task_part_conversion(mock_tracer, task, expected_parts, monkeypatch):
411+
monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental")
412+
413+
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):
414+
tracer = Tracer()
415+
tracer.tracer = mock_tracer
416+
417+
mock_span = mock.MagicMock()
418+
mock_tracer.start_span.return_value = mock_span
419+
420+
tracer.start_multiagent_span(task, "swarm")
421+
422+
expected_content = json.dumps([{"role": "user", "parts": expected_parts}])
423+
mock_span.add_event.assert_any_call(
424+
"gen_ai.client.inference.operation.details", attributes={"gen_ai.input.messages": expected_content}
425+
)
426+
427+
399428
def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, monkeypatch):
400429
"""Test starting a swarm call span with task as list of contentBlock with latest semantic conventions."""
401430
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):

0 commit comments

Comments
 (0)