Skip to content

Commit 12e3037

Browse files
committed
tests: add more tests
1 parent da58123 commit 12e3037

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

src/strands/tools/mcp/mcp_client.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from concurrent import futures
1717
from datetime import timedelta
1818
from types import TracebackType
19-
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union
19+
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast
2020

2121
from mcp import ClientSession, ListToolsResult
2222
from mcp.types import CallToolResult as MCPCallToolResult
@@ -157,7 +157,7 @@ def stop(
157157
async def _set_close_event() -> None:
158158
self._close_event.set()
159159

160-
asyncio.run_coroutine_threadsafe(_set_close_event(), self._background_thread_event_loop)
160+
self._invoke_on_background_thread(_set_close_event()).result()
161161

162162
self._log_debug_with_thread("waiting for background thread to join")
163163
self._background_thread.join()
@@ -183,8 +183,7 @@ def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedLi
183183
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
184184

185185
async def _list_tools_async() -> ListToolsResult:
186-
assert self._background_thread_session is not None
187-
return await self._background_thread_session.list_tools(cursor=pagination_token)
186+
return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token)
188187

189188
list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result()
190189
self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools))
@@ -210,8 +209,7 @@ def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromp
210209
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
211210

212211
async def _list_prompts_async() -> ListPromptsResult:
213-
assert self._background_thread_session is not None
214-
return await self._background_thread_session.list_prompts(cursor=pagination_token)
212+
return await cast(ClientSession, self._background_thread_session).list_prompts(cursor=pagination_token)
215213

216214
list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
217215
self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts))
@@ -235,8 +233,7 @@ def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResu
235233
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
236234

237235
async def _get_prompt_async() -> GetPromptResult:
238-
assert self._background_thread_session is not None
239-
return await self._background_thread_session.get_prompt(prompt_id, arguments=args)
236+
return await cast(ClientSession, self._background_thread_session).get_prompt(prompt_id, arguments=args)
240237

241238
get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
242239
self._log_debug_with_thread("received prompt from MCP server")
@@ -271,8 +268,9 @@ def call_tool_sync(
271268
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
272269

273270
async def _call_tool_async() -> MCPCallToolResult:
274-
assert self._background_thread_session is not None
275-
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
271+
return await cast(ClientSession, self._background_thread_session).call_tool(
272+
name, arguments, read_timeout_seconds
273+
)
276274

277275
try:
278276
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result()
@@ -307,8 +305,9 @@ async def call_tool_async(
307305
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
308306

309307
async def _call_tool_async() -> MCPCallToolResult:
310-
assert self._background_thread_session is not None
311-
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
308+
return await cast(ClientSession, self._background_thread_session).call_tool(
309+
name, arguments, read_timeout_seconds
310+
)
312311

313312
try:
314313
future = self._invoke_on_background_thread(_call_tool_async())

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,39 @@ def slow_transport():
485485
with pytest.raises(MCPClientInitializationError, match="background thread did not start in 1 seconds"):
486486
client.start()
487487
mock_stop.assert_called_once_with(None, None, None)
488+
489+
490+
def test_stop_with_no_background_thread():
491+
"""Test that stop() handles the case when no background thread exists."""
492+
client = MCPClient(MagicMock())
493+
494+
# Ensure no background thread exists
495+
assert client._background_thread is None
496+
497+
# Mock join to verify it's not called
498+
with patch("threading.Thread.join") as mock_join:
499+
client.stop(None, None, None)
500+
mock_join.assert_not_called()
501+
502+
# Verify cleanup occurred
503+
assert client._background_thread is None
504+
505+
506+
def test_stop_with_background_thread_but_no_event_loop():
507+
"""Test that stop() handles the case when background thread exists but event loop is None."""
508+
client = MCPClient(MagicMock())
509+
510+
# Mock a background thread without event loop
511+
mock_thread = MagicMock()
512+
mock_thread.join = MagicMock()
513+
client._background_thread = mock_thread
514+
client._background_thread_event_loop = None
515+
516+
# Should not raise any exceptions and should join the thread
517+
client.stop(None, None, None)
518+
519+
# Verify thread was joined
520+
mock_thread.join.assert_called_once()
521+
522+
# Verify cleanup occurred
523+
assert client._background_thread is None

0 commit comments

Comments
 (0)