Skip to content

Commit da58123

Browse files
committed
fix(mcp): auto cleanup on exceptions occurring in __enter__
1 parent 9213bc5 commit da58123

File tree

4 files changed

+84
-12
lines changed

4 files changed

+84
-12
lines changed

src/strands/models/bedrock.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
DEFAULT_READ_TIMEOUT = 120
4949

50+
5051
class BedrockModel(Model):
5152
"""AWS Bedrock model provider implementation.
5253

src/strands/tools/mcp/mcp_client.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,15 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti
8383
self._transport_callable = transport_callable
8484

8585
self._background_thread: threading.Thread | None = None
86-
self._background_thread_session: ClientSession
87-
self._background_thread_event_loop: AbstractEventLoop
86+
self._background_thread_session: ClientSession | None = None
87+
self._background_thread_event_loop: AbstractEventLoop | None = None
8888

8989
def __enter__(self) -> "MCPClient":
90-
"""Context manager entry point which initializes the MCP server connection."""
90+
"""Context manager entry point which initializes the MCP server connection.
91+
92+
TODO: Refactor to lazy initialization pattern following idiomatic Python.
93+
Heavy work in __enter__ is non-idiomatic; the connection logic should move to the first method call instead.
94+
"""
9195
return self.start()
9296

9397
def __exit__(self, exc_type: BaseException, exc_val: BaseException, exc_tb: TracebackType) -> None:
@@ -118,9 +122,15 @@ def start(self) -> "MCPClient":
118122
self._init_future.result(timeout=self._startup_timeout)
119123
self._log_debug_with_thread("the client initialization was successful")
120124
except futures.TimeoutError as e:
121-
raise MCPClientInitializationError("background thread did not start in 30 seconds") from e
125+
# Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
126+
self.stop(None, None, None)
127+
raise MCPClientInitializationError(
128+
f"background thread did not start in {self._startup_timeout} seconds"
129+
) from e
122130
except Exception as e:
123131
logger.exception("client failed to initialize")
132+
# Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
133+
self.stop(None, None, None)
124134
raise MCPClientInitializationError("the client initialization failed") from e
125135
return self
126136

@@ -129,21 +139,29 @@ def stop(
129139
) -> None:
130140
"""Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources.
131141
142+
This method is defensive and can handle partial initialization states that may occur
143+
if start() fails partway through initialization.
144+
132145
Args:
133146
exc_type: Exception type if an exception was raised in the context
134147
exc_val: Exception value if an exception was raised in the context
135148
exc_tb: Exception traceback if an exception was raised in the context
136149
"""
137150
self._log_debug_with_thread("exiting MCPClient context")
138151

139-
async def _set_close_event() -> None:
140-
self._close_event.set()
141-
142-
self._invoke_on_background_thread(_set_close_event()).result()
143-
self._log_debug_with_thread("waiting for background thread to join")
152+
# Only signal the close event and join if a background thread was started
144153
if self._background_thread is not None:
154+
# Signal close event if event loop exists
155+
if self._background_thread_event_loop is not None:
156+
157+
async def _set_close_event() -> None:
158+
self._close_event.set()
159+
160+
asyncio.run_coroutine_threadsafe(_set_close_event(), self._background_thread_event_loop)
161+
162+
self._log_debug_with_thread("waiting for background thread to join")
145163
self._background_thread.join()
146-
self._log_debug_with_thread("background thread joined, MCPClient context exited")
164+
self._log_debug_with_thread("background thread joined, MCPClient context exited")
147165

148166
# Reset fields to allow instance reuse
149167
self._init_future = futures.Future()
@@ -165,6 +183,7 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi
165183
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
166184

167185
async def _list_tools_async() -> ListToolsResult:
186+
assert self._background_thread_session is not None
168187
return await self._background_thread_session.list_tools(cursor=pagination_token)
169188

170189
list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result()
@@ -191,6 +210,7 @@ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromp
191210
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
192211

193212
async def _list_prompts_async() -> ListPromptsResult:
213+
assert self._background_thread_session is not None
194214
return await self._background_thread_session.list_prompts(cursor=pagination_token)
195215

196216
list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
@@ -215,6 +235,7 @@ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResu
215235
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
216236

217237
async def _get_prompt_async() -> GetPromptResult:
238+
assert self._background_thread_session is not None
218239
return await self._background_thread_session.get_prompt(prompt_id, arguments=args)
219240

220241
get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
@@ -250,6 +271,7 @@ def call_tool_sync(
250271
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
251272

252273
async def _call_tool_async() -> MCPCallToolResult:
274+
assert self._background_thread_session is not None
253275
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
254276

255277
try:
@@ -285,6 +307,7 @@ async def call_tool_async(
285307
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
286308

287309
async def _call_tool_async() -> MCPCallToolResult:
310+
assert self._background_thread_session is not None
288311
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
289312

290313
try:

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,12 @@ def test_enter_with_initialization_exception(mock_transport):
337337

338338
client = MCPClient(mock_transport["transport_callable"])
339339

340-
with pytest.raises(MCPClientInitializationError, match="the client initialization failed"):
341-
client.start()
340+
with patch.object(client, "stop") as mock_stop:
341+
with pytest.raises(MCPClientInitializationError, match="the client initialization failed"):
342+
client.start()
343+
344+
# Verify stop() was called for cleanup
345+
mock_stop.assert_called_once_with(None, None, None)
342346

343347

344348
def test_mcp_tool_result_type():
@@ -466,3 +470,18 @@ def test_get_prompt_sync_session_not_active():
466470

467471
with pytest.raises(MCPClientInitializationError, match="client session is not running"):
468472
client.get_prompt_sync("test_prompt_id", {})
473+
474+
475+
def test_timeout_initialization_cleanup():
476+
"""Test that timeout during initialization properly cleans up."""
477+
478+
def slow_transport():
479+
time.sleep(5)
480+
return MagicMock()
481+
482+
client = MCPClient(slow_transport, startup_timeout=1)
483+
484+
with patch.object(client, "stop") as mock_stop:
485+
with pytest.raises(MCPClientInitializationError, match="background thread did not start in 1 seconds"):
486+
client.start()
487+
mock_stop.assert_called_once_with(None, None, None)

tests_integ/test_mcp_client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from strands.tools.mcp.mcp_client import MCPClient
1616
from strands.tools.mcp.mcp_types import MCPTransport
1717
from strands.types.content import Message
18+
from strands.types.exceptions import MCPClientInitializationError
1819
from strands.types.tools import ToolUse
1920

2021

@@ -268,3 +269,31 @@ def transport_callback() -> MCPTransport:
268269

269270
def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
270271
return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block]
272+
273+
274+
def test_mcp_client_timeout_integration():
275+
"""Integration test for timeout scenario that caused hanging."""
276+
import threading
277+
278+
from mcp import StdioServerParameters, stdio_client
279+
280+
def slow_transport():
281+
time.sleep(4) # Longer than timeout
282+
return stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
283+
284+
client = MCPClient(slow_transport, startup_timeout=2)
285+
initial_threads = threading.active_count()
286+
287+
# First attempt should timeout
288+
with pytest.raises(MCPClientInitializationError, match="background thread did not start in 2 seconds"):
289+
with client:
290+
pass
291+
292+
time.sleep(1) # Allow cleanup
293+
assert threading.active_count() == initial_threads # No thread leak
294+
295+
# Should be able to recover by increasing timeout
296+
client._startup_timeout = 60
297+
with client:
298+
tools = client.list_tools_sync()
299+
assert len(tools) >= 0 # Always true; reaching this line only shows list_tools_sync() did not raise

0 commit comments

Comments
 (0)