Skip to content

Commit 60bd291

Browse files
authored
bidi - remove python 3.11+ features (#1302)
1 parent a64a851 commit 60bd291

File tree

7 files changed

+136
-13
lines changed

7 files changed

+136
-13
lines changed

src/strands/experimental/bidi/_async/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from typing import Awaitable, Callable
44

5+
from ._task_group import _TaskGroup
56
from ._task_pool import _TaskPool
67

7-
__all__ = ["_TaskPool"]
8+
__all__ = ["_TaskGroup", "_TaskPool"]
89

910

1011
async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
@@ -16,14 +17,14 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
1617
funcs: Stop functions to call in sequence.
1718
1819
Raises:
19-
ExceptionGroup: If any stop function raises an exception.
20+
RuntimeError: If any stop function raises an exception.
2021
"""
2122
exceptions = []
2223
for func in funcs:
2324
try:
2425
await func()
2526
except Exception as exception:
26-
exceptions.append(exception)
27+
exceptions.append({"func_name": func.__name__, "exception": repr(exception)})
2728

2829
if exceptions:
29-
raise ExceptionGroup("failed stop sequence", exceptions)
30+
raise RuntimeError(f"exceptions={exceptions} | failed stop sequence")
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Manage a group of async tasks.
2+
3+
This is intended to mimic the behaviors of asyncio.TaskGroup released in Python 3.11.
4+
5+
- Docs: https://docs.python.org/3/library/asyncio-task.html#task-groups
6+
"""
7+
8+
import asyncio
9+
from typing import Any, Coroutine
10+
11+
12+
class _TaskGroup:
13+
"""Shim of asyncio.TaskGroup for use in Python 3.10.
14+
15+
Attributes:
16+
_tasks: List of tasks in group.
17+
"""
18+
19+
_tasks: list[asyncio.Task]
20+
21+
def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
22+
"""Create an async task and add to group.
23+
24+
Returns:
25+
The created task.
26+
"""
27+
task = asyncio.create_task(coro)
28+
self._tasks.append(task)
29+
return task
30+
31+
async def __aenter__(self) -> "_TaskGroup":
32+
"""Setup self managed task group context."""
33+
self._tasks = []
34+
return self
35+
36+
async def __aexit__(self, *_: Any) -> None:
37+
"""Execute tasks in group.
38+
39+
The following execution rules are enforced:
40+
- The context stops executing all tasks if at least one task raises an Exception or the context is cancelled.
41+
- The context re-raises Exceptions to the caller.
42+
- The context re-raises CancelledErrors to the caller only if the context itself was cancelled.
43+
"""
44+
try:
45+
await asyncio.gather(*self._tasks)
46+
47+
except (Exception, asyncio.CancelledError) as error:
48+
for task in self._tasks:
49+
task.cancel()
50+
51+
await asyncio.gather(*self._tasks, return_exceptions=True)
52+
53+
if not isinstance(error, asyncio.CancelledError):
54+
raise
55+
56+
context_task = asyncio.current_task()
57+
if context_task and context_task.cancelling() > 0: # context itself was cancelled
58+
raise
59+
60+
finally:
61+
self._tasks = []

src/strands/experimental/bidi/agent/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ....types.tools import AgentTool
3131
from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent
3232
from ...tools import ToolProvider
33-
from .._async import stop_all
33+
from .._async import _TaskGroup, stop_all
3434
from ..models.model import BidiModel
3535
from ..models.nova_sonic import BidiNovaSonicModel
3636
from ..types.agent import BidiAgentInput
@@ -390,7 +390,7 @@ async def run_outputs(inputs_task: asyncio.Task) -> None:
390390
for start in [*input_starts, *output_starts]:
391391
await start(self)
392392

393-
async with asyncio.TaskGroup() as task_group:
393+
async with _TaskGroup() as task_group:
394394
inputs_task = task_group.create_task(run_inputs())
395395
task_group.create_task(run_outputs(inputs_task))
396396

tests/strands/experimental/bidi/_async/test__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,19 @@ async def test_stop_exception():
1010
func1 = AsyncMock()
1111
func2 = AsyncMock(side_effect=ValueError("stop 2 failed"))
1212
func3 = AsyncMock()
13+
func4 = AsyncMock(side_effect=ValueError("stop 4 failed"))
1314

14-
with pytest.raises(ExceptionGroup) as exc_info:
15-
await stop_all(func1, func2, func3)
15+
with pytest.raises(Exception, match=r"failed stop sequence") as exc_info:
16+
await stop_all(func1, func2, func3, func4)
1617

1718
func1.assert_called_once()
1819
func2.assert_called_once()
1920
func3.assert_called_once()
21+
func4.assert_called_once()
2022

21-
assert len(exc_info.value.exceptions) == 1
22-
with pytest.raises(ValueError, match=r"stop 2 failed"):
23-
raise exc_info.value.exceptions[0]
23+
tru_message = str(exc_info.value)
24+
assert "ValueError('stop 2 failed')" in tru_message
25+
assert "ValueError('stop 4 failed')" in tru_message
2426

2527

2628
@pytest.mark.asyncio
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import asyncio
2+
import unittest.mock
3+
4+
import pytest
5+
6+
from strands.experimental.bidi._async._task_group import _TaskGroup
7+
8+
9+
@pytest.mark.asyncio
10+
async def test_task_group__aexit__():
11+
coro = unittest.mock.AsyncMock()
12+
13+
async with _TaskGroup() as task_group:
14+
task_group.create_task(coro())
15+
16+
coro.assert_called_once()
17+
18+
19+
@pytest.mark.asyncio
20+
async def test_task_group__aexit__exception():
21+
wait_event = asyncio.Event()
22+
async def wait():
23+
await wait_event.wait()
24+
25+
async def fail():
26+
raise ValueError("test error")
27+
28+
with pytest.raises(ValueError, match=r"test error"):
29+
async with _TaskGroup() as task_group:
30+
wait_task = task_group.create_task(wait())
31+
fail_task = task_group.create_task(fail())
32+
33+
assert wait_task.cancelled()
34+
assert not fail_task.cancelled()
35+
36+
37+
@pytest.mark.asyncio
38+
async def test_task_group__aexit__cancelled():
39+
wait_event = asyncio.Event()
40+
async def wait():
41+
await wait_event.wait()
42+
43+
tasks = []
44+
45+
run_event = asyncio.Event()
46+
async def run():
47+
async with _TaskGroup() as task_group:
48+
tasks.append(task_group.create_task(wait()))
49+
run_event.set()
50+
51+
run_task = asyncio.create_task(run())
52+
await run_event.wait()
53+
run_task.cancel()
54+
55+
with pytest.raises(asyncio.CancelledError):
56+
await run_task
57+
58+
wait_task = tasks[0]
59+
assert wait_task.cancelled()

tests/strands/experimental/bidi/models/test_gemini_live.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id):
185185
model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key})
186186
await model4.start()
187187
mock_live_session_cm.__aexit__.side_effect = Exception("Close failed")
188-
with pytest.raises(ExceptionGroup):
188+
with pytest.raises(Exception, match=r"failed stop sequence"):
189189
await model4.stop()
190190

191191

tests/strands/experimental/bidi/models/test_openai_realtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ async def async_connect(*args, **kwargs):
353353
model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key})
354354
await model4.start()
355355
mock_ws.close.side_effect = Exception("Close failed")
356-
with pytest.raises(ExceptionGroup):
356+
with pytest.raises(Exception, match=r"failed stop sequence"):
357357
await model4.stop()
358358

359359

0 commit comments

Comments
 (0)