Skip to content

Commit 585627e

Browse files
committed
build - bidi - isolate nova provider
1 parent 2944abf commit 585627e

File tree

21 files changed

+223
-238
lines changed

21 files changed

+223
-238
lines changed

pyproject.toml

Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,18 @@ a2a = [
7070
"starlette>=0.46.2,<1.0.0",
7171
]
7272

73-
bidi = [
74-
"aws_sdk_bedrock_runtime; python_version>='3.12'",
73+
bidi-io = [
7574
"prompt_toolkit>=3.0.0,<4.0.0",
7675
"pyaudio>=0.2.13,<1.0.0",
77-
"smithy-aws-core>=0.0.1; python_version>='3.12'",
7876
]
7977
bidi-gemini = ["google-genai>=1.32.0,<2.0.0"]
78+
bidi-nova = [
79+
"aws_sdk_bedrock_runtime; python_version>='3.12'",
80+
"smithy-aws-core>=0.0.1; python_version>='3.12'",
81+
]
8082
bidi-openai = ["websockets>=15.0.0,<16.0.0"]
8183

82-
all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]
83-
bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"]
84+
all = ["strands-agents[a2a,anthropic,bidi-io,bidi-gemini,bidi-openai,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]
8485

8586
dev = [
8687
"commitizen>=4.4.0,<5.0.0",
@@ -130,7 +131,7 @@ format-fix = [
130131
]
131132
lint-check = [
132133
"ruff check",
133-
"mypy ./src"
134+
"mypy -p src"
134135
]
135136
lint-fix = [
136137
"ruff check --fix"
@@ -204,16 +205,10 @@ warn_no_return = true
204205
warn_unreachable = true
205206
follow_untyped_imports = true
206207
ignore_missing_imports = false
207-
exclude = ["src/strands/experimental/bidi"]
208-
209-
[[tool.mypy.overrides]]
210-
module = ["strands.experimental.bidi.*"]
211-
follow_imports = "skip"
212208

213209
[tool.ruff]
214210
line-length = 120
215211
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"]
216-
exclude = ["src/strands/experimental/bidi/**/*.py", "tests/strands/experimental/bidi/**/*.py", "tests_integ/bidi/**/*.py"]
217212

218213
[tool.ruff.lint]
219214
select = [
@@ -236,16 +231,14 @@ convention = "google"
236231
[tool.pytest.ini_options]
237232
testpaths = ["tests"]
238233
asyncio_default_fixture_loop_scope = "function"
239-
addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi"
240-
234+
addopts = "--ignore=tests/strands/experimental/bidi/models/test_nova_sonic.py --ignore=tests_integ/bidi"
241235

242236
[tool.coverage.run]
243237
branch = true
244238
source = ["src"]
245239
context = "thread"
246240
parallel = true
247241
concurrency = ["thread", "multiprocessing"]
248-
omit = ["src/strands/experimental/bidi/*"]
249242

250243
[tool.coverage.report]
251244
show_missing = true
@@ -275,48 +268,3 @@ style = [
275268
["text", ""],
276269
["disabled", "fg:#858585 italic"]
277270
]
278-
279-
# =========================
280-
# Bidi development configs
281-
# =========================
282-
283-
[tool.hatch.envs.bidi]
284-
dev-mode = true
285-
features = ["dev", "bidi-all"]
286-
installer = "uv"
287-
288-
[tool.hatch.envs.bidi.scripts]
289-
prepare = [
290-
"hatch run bidi-lint:format-fix",
291-
"hatch run bidi-lint:quality-fix",
292-
"hatch run bidi-lint:type-check",
293-
"hatch run bidi-test:test-cov",
294-
]
295-
296-
[tools.hatch.envs.bidi-lint]
297-
template = "bidi"
298-
299-
[tool.hatch.envs.bidi-lint.scripts]
300-
format-check = "format-fix --check"
301-
format-fix = "ruff format {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py"
302-
quality-check = "ruff check {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py"
303-
quality-fix = "quality-check --fix"
304-
type-check = "mypy {args} --python-version 3.12 ./src/strands/experimental/bidi/**/*.py"
305-
306-
[tool.hatch.envs.bidi-test]
307-
template = "bidi"
308-
309-
[tool.hatch.envs.bidi-test.scripts]
310-
test = "pytest {args} tests/strands/experimental/bidi"
311-
test-cov = """
312-
test \
313-
--cov=strands.experimental.bidi \
314-
--cov-config= \
315-
--cov-branch \
316-
--cov-report=term-missing \
317-
--cov-report=xml:build/coverage/bidi-coverage.xml \
318-
--cov-report=html:build/coverage/bidi-html
319-
"""
320-
321-
[[tool.hatch.envs.bidi-test.matrix]]
322-
python = ["3.13", "3.12"]

src/strands/experimental/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
This module implements experimental features that are subject to change in future revisions without notice.
44
"""
55

6-
from . import steering, tools
6+
from . import bidi, steering, tools
77
from .agent_config import config_to_agent
88

9-
__all__ = ["config_to_agent", "tools", "steering"]
9+
__all__ = ["bidi", "config_to_agent", "tools", "steering"]

src/strands/experimental/bidi/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
"""Bidirectional streaming package."""
22

3-
import sys
4-
5-
if sys.version_info < (3, 12):
6-
raise ImportError("bidi only supported for >= Python 3.12")
7-
83
# Main components - Primary user interface
94
# Re-export standard agent events for tool handling
105
from ...types._events import (
@@ -19,7 +14,6 @@
1914

2015
# Model interface (for custom implementations)
2116
from .models.model import BidiModel
22-
from .models.nova_sonic import BidiNovaSonicModel
2317

2418
# Built-in tools
2519
from .tools import stop_conversation
@@ -48,8 +42,6 @@
4842
"BidiAgent",
4943
# IO channels
5044
"BidiAudioIO",
51-
# Model providers
52-
"BidiNovaSonicModel",
5345
# Built-in tools
5446
"stop_conversation",
5547
# Input Event types

src/strands/experimental/bidi/_async/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
1616
funcs: Stop functions to call in sequence.
1717
1818
Raises:
19-
ExceptionGroup: If any stop function raises an exception.
19+
RuntimeError: If any stop function raises an exception.
2020
"""
2121
exceptions = []
2222
for func in funcs:
@@ -26,4 +26,8 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
2626
exceptions.append(exception)
2727

2828
if exceptions:
29-
raise ExceptionGroup("failed stop sequence", exceptions)
29+
exceptions.append(RuntimeError("failed stop sequence"))
30+
for i in range(1, len(exceptions)):
31+
exceptions[i].__cause__ = exceptions[i - 1]
32+
33+
raise exceptions[-1]

src/strands/experimental/bidi/agent/agent.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from ...tools import ToolProvider
3333
from .._async import stop_all
3434
from ..models.model import BidiModel
35-
from ..models.nova_sonic import BidiNovaSonicModel
3635
from ..types.agent import BidiAgentInput
3736
from ..types.events import (
3837
BidiAudioInputEvent,
@@ -100,13 +99,13 @@ def __init__(
10099
ValueError: If model configuration is invalid or state is invalid type.
101100
TypeError: If model type is unsupported.
102101
"""
103-
self.model = (
104-
BidiNovaSonicModel()
105-
if not model
106-
else BidiNovaSonicModel(model_id=model)
107-
if isinstance(model, str)
108-
else model
109-
)
102+
if isinstance(model, BidiModel):
103+
self.model = model
104+
else:
105+
from ..models.nova_sonic import BidiNovaSonicModel
106+
107+
self.model = BidiNovaSonicModel(model_id=model) if isinstance(model, str) else BidiNovaSonicModel()
108+
110109
self.system_prompt = system_prompt
111110
self.messages = messages or []
112111

@@ -390,9 +389,16 @@ async def run_outputs(inputs_task: asyncio.Task) -> None:
390389
for start in [*input_starts, *output_starts]:
391390
await start(self)
392391

393-
async with asyncio.TaskGroup() as task_group:
394-
inputs_task = task_group.create_task(run_inputs())
395-
task_group.create_task(run_outputs(inputs_task))
392+
inputs_task = asyncio.create_task(run_inputs())
393+
outputs_task = asyncio.create_task(run_outputs(inputs_task))
394+
395+
try:
396+
await asyncio.gather(inputs_task, outputs_task)
397+
except (Exception, asyncio.CancelledError):
398+
inputs_task.cancel()
399+
outputs_task.cancel()
400+
await asyncio.gather(inputs_task, outputs_task, return_exceptions=True)
401+
raise
396402

397403
finally:
398404
input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)]
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
"""Bidirectional model interfaces and implementations."""
22

33
from .model import BidiModel, BidiModelTimeoutError
4-
from .nova_sonic import BidiNovaSonicModel
54

65
__all__ = [
76
"BidiModel",
87
"BidiModelTimeoutError",
9-
"BidiNovaSonicModel",
108
]

src/strands/experimental/bidi/models/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""
1515

1616
import logging
17-
from typing import Any, AsyncIterable, Protocol
17+
from typing import Any, AsyncIterable, Protocol, runtime_checkable
1818

1919
from ....types._events import ToolResultEvent
2020
from ....types.content import Messages
@@ -27,6 +27,7 @@
2727
logger = logging.getLogger(__name__)
2828

2929

30+
@runtime_checkable
3031
class BidiModel(Protocol):
3132
"""Protocol for bidirectional streaming models.
3233

src/strands/experimental/bidi/models/nova_sonic.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,41 @@
1111
- Tool execution with content containers and identifier tracking
1212
- 8-minute connection limits with proper cleanup sequences
1313
- Interruption detection through stopReason events
14+
15+
Note, BidiNovaSonicModel is only supported for Python 3.12+
1416
"""
1517

16-
import asyncio
18+
import sys
19+
20+
if sys.version_info < (3, 12):
21+
raise ImportError("BidiNovaSonicModel is only supported for Python 3.12+")
22+
23+
import asyncio # type: ignore[unreachable]
1724
import base64
1825
import json
1926
import logging
2027
import uuid
2128
from typing import Any, AsyncGenerator, cast
2229

2330
import boto3
24-
from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput
25-
from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
26-
from aws_sdk_bedrock_runtime.models import (
31+
from aws_sdk_bedrock_runtime.client import ( # type: ignore[import-not-found]
32+
BedrockRuntimeClient,
33+
InvokeModelWithBidirectionalStreamOperationInput,
34+
)
35+
from aws_sdk_bedrock_runtime.config import ( # type: ignore[import-not-found]
36+
Config,
37+
HTTPAuthSchemeResolver,
38+
SigV4AuthScheme,
39+
)
40+
from aws_sdk_bedrock_runtime.models import ( # type: ignore[import-not-found]
2741
BidirectionalInputPayloadPart,
2842
InvokeModelWithBidirectionalStreamInputChunk,
2943
ModelTimeoutException,
3044
ValidationException,
3145
)
32-
from smithy_aws_core.identity.static import StaticCredentialsResolver
33-
from smithy_core.aio.eventstream import DuplexEventStream
34-
from smithy_core.shapes import ShapeID
46+
from smithy_aws_core.identity.static import StaticCredentialsResolver # type: ignore[import-not-found]
47+
from smithy_core.aio.eventstream import DuplexEventStream # type: ignore[import-not-found]
48+
from smithy_core.shapes import ShapeID # type: ignore[import-not-found]
3549

3650
from ....types._events import ToolResultEvent, ToolUseStreamEvent
3751
from ....types.content import Messages
@@ -93,6 +107,8 @@ class BidiNovaSonicModel(BidiModel):
93107
Manages Nova Sonic's complex event sequencing, audio format conversion, and
94108
tool execution patterns while providing the standard BidiModel interface.
95109
110+
Note, BidiNovaSonicModel is only supported for Python 3.12+.
111+
96112
Attributes:
97113
_stream: open bedrock stream to nova sonic.
98114
"""

src/strands/tools/_caller.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import json
1111
import random
12-
from typing import TYPE_CHECKING, Any, Callable
12+
from typing import TYPE_CHECKING, Any, Callable, cast
1313

1414
from .._async import run_async
1515
from ..tools.executors._executor import ToolExecutor
@@ -108,7 +108,7 @@ async def acall() -> ToolResult:
108108

109109
# Apply conversation management if agent supports it (traditional agents)
110110
if hasattr(self._agent, "conversation_manager"):
111-
self._agent.conversation_manager.apply_management(self._agent)
111+
self._agent.conversation_manager.apply_management(cast("Agent", self._agent))
112112

113113
return tool_result
114114

0 commit comments

Comments
 (0)