diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index abef0b732..d47550742 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -20,7 +20,12 @@ from deprecated import deprecated from .itransport import ITransport -from .protocol import ToolSchema +from .mcp_transport import ( + McpHttpTransportV20241105, + McpHttpTransportV20250326, + McpHttpTransportV20250618, +) +from .protocol import Protocol, ToolSchema from .tool import ToolboxTool from .toolbox_transport import ToolboxTransport from .utils import identify_auth_requirements, resolve_value @@ -44,6 +49,7 @@ def __init__( client_headers: Optional[ Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]] ] = None, + protocol: Protocol = Protocol.TOOLBOX, ): """ Initializes the ToolboxClient. @@ -54,8 +60,21 @@ def __init__( If None (default), a new session is created internally. Note that if a session is provided, its lifecycle (including closing) should typically be managed externally. - client_headers: Headers to include in each request sent through this client. + client_headers: Headers to include in each request sent through this + client. + protocol: The communication protocol to use. """ + if protocol == Protocol.TOOLBOX: + self.__transport = ToolboxTransport(url, session) + elif protocol in Protocol.get_supported_mcp_versions(): + if protocol == Protocol.MCP_v20250618: + self.__transport = McpHttpTransportV20250618(url, session, protocol) + elif protocol == Protocol.MCP_v20250326: + self.__transport = McpHttpTransportV20250326(url, session, protocol) + elif protocol == Protocol.MCP_v20241105: + self.__transport = McpHttpTransportV20241105(url, session, protocol) + else: + raise ValueError(f"Unsupported MCP protocol version: {protocol}") self.__transport = ToolboxTransport(url, session) self.__client_headers = client_headers if client_headers is not None else {} diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py new file mode 100644 index 000000000..8813dc523 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .v20241105.mcp import McpHttpTransportV20241105 +from .v20250326.mcp import McpHttpTransportV20250326 +from .v20250618.mcp import McpHttpTransportV20250618 + +__all__ = [ + "McpHttpTransportV20241105", + "McpHttpTransportV20250326", + "McpHttpTransportV20250618", +] diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/transport_base.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/transport_base.py new file mode 100644 index 000000000..d6e2ca67c --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/transport_base.py @@ -0,0 +1,102 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from abc import ABC, abstractmethod +from typing import Optional + +from aiohttp import ClientSession + +from ..itransport import ITransport +from ..protocol import ( + AdditionalPropertiesSchema, + ParameterSchema, + Protocol, + ToolSchema, +) + + +class _McpHttpTransportBase(ITransport, ABC): + """Base transport for MCP protocols.""" + + def __init__( + self, + base_url: str, + session: Optional[ClientSession] = None, + protocol: Protocol = Protocol.MCP, + ): + self._mcp_base_url = f"{base_url}/mcp/" + self._protocol_version = protocol.value + self._server_version: Optional[str] = None + + self._manage_session = session is None + self._session = session or ClientSession() + self._init_lock = asyncio.Lock() + self._init_task: Optional[asyncio.Task] = None + + async def _ensure_initialized(self): + """Ensures the session is initialized before making requests.""" + async with self._init_lock: + if self._init_task is None: + self._init_task = asyncio.create_task(self._initialize_session()) + await self._init_task + + @property + def base_url(self) -> str: + return self._mcp_base_url + + def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: + """Converts a raw MCP tool dictionary into the Toolbox ToolSchema.""" + parameters = [] + input_schema = tool_data.get("inputSchema", {}) + properties = input_schema.get("properties", {}) + required = input_schema.get("required", []) + + for name, schema in properties.items(): + additional_props = schema.get("additionalProperties") + if isinstance(additional_props, dict): + additional_props = AdditionalPropertiesSchema( + type=additional_props["type"] + ) + else: + additional_props = True + parameters.append( + ParameterSchema( + name=name, + type=schema["type"], + description=schema.get("description", ""), + required=name in required, + additionalProperties=additional_props, + ) + ) + + return ToolSchema( + description=tool_data.get("description") or "", parameters=parameters + ) + + async def close(self): + async with self._init_lock: + if self._init_task: + try: + await self._init_task + except Exception: + # If initialization failed, we can still try to close. + pass + if self._manage_session and self._session and not self._session.closed: + await self._session.close() + + @abstractmethod + async def _initialize_session(self): + """Initializes the MCP session.""" + pass diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/mcp.py new file mode 100644 index 000000000..9e609a1a3 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/mcp.py @@ -0,0 +1,168 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Mapping, Optional, TypeVar + +from pydantic import BaseModel + +from ... import version +from ...protocol import ManifestSchema +from ..transport_base import _McpHttpTransportBase +from . import types + +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) + + +class McpHttpTransportV20241105(_McpHttpTransportBase): + """Transport for the MCP v2024-11-05 protocol.""" + + async def _send_request( + self, + url: str, + request: types.MCPRequest[ReceiveResultT] | types.MCPNotification, + headers: Optional[Mapping[str, str]] = None, + ) -> ReceiveResultT | None: + """Sends a JSON-RPC request to the MCP server.""" + params = ( + request.params.model_dump(mode="json", exclude_none=True) + if isinstance(request.params, BaseModel) + else request.params + ) + rpc_msg: BaseModel + if isinstance(request, types.MCPNotification): + rpc_msg = types.JSONRPCNotification(method=request.method, params=params) + else: + rpc_msg = types.JSONRPCRequest(method=request.method, params=params) + + payload = rpc_msg.model_dump(mode="json", exclude_none=True) + + async with self._session.post( + url, json=payload, headers=dict(headers or {}) + ) as response: + if not response.ok: + error_text = await response.text() + raise RuntimeError( + f"API request failed with status {response.status} " + f"({response.reason}). Server response: {error_text}" + ) + + if response.status == 204 or response.content.at_eof(): + return None + + json_resp = await response.json() + + # Check for JSON-RPC Error + if "error" in json_resp: + try: + err = types.JSONRPCError.model_validate(json_resp).error + raise RuntimeError( + f"MCP request failed with code {err.code}: {err.message}" + ) + except Exception: + raise RuntimeError(f"MCP request failed: {json_resp.get('error')}") + + # Parse Result + if isinstance(request, types.MCPRequest): + try: + rpc_resp = types.JSONRPCResponse.model_validate(json_resp) + return request.get_result_model().model_validate(rpc_resp.result) + except Exception as e: + raise RuntimeError(f"Failed to parse JSON-RPC response: {e}") + return None + + async def _initialize_session(self): + """Initializes the MCP session.""" + params = types.InitializeRequestParams( + protocolVersion=self._protocol_version, + capabilities=types.ClientCapabilities(), + clientInfo=types.Implementation( + name="toolbox-python-sdk", version=version.__version__ + ), + ) + + result = await self._send_request( + url=self._mcp_base_url, request=types.InitializeRequest(params=params) + ) + + self._server_version = result.serverInfo.version + if result.protocolVersion != self._protocol_version: + raise RuntimeError( + f"MCP version mismatch: client does not support server version {result.protocolVersion}" + ) + if not result.capabilities.tools: + if self._manage_session: + await self.close() + raise RuntimeError("Server does not support the 'tools' capability.") + + await self._send_request( + url=self._mcp_base_url, request=types.InitializedNotification() + ) + + async def tools_list( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> ManifestSchema: + """Lists available tools from the server using the MCP protocol.""" + await self._ensure_initialized() + + url = self._mcp_base_url + (toolset_name if toolset_name else "") + result = await self._send_request( + url=url, request=types.ListToolsRequest(), headers=headers + ) + if result is None: + raise RuntimeError("Failed to list tools: No response from server.") + + tools_map = { + t.name: self._convert_tool_schema(t.model_dump(mode="json", by_alias=True)) + for t in result.tools + } + if self._server_version is None: + raise RuntimeError("Server version not available.") + + return ManifestSchema(serverVersion=self._server_version, tools=tools_map) + + async def tool_get( + self, tool_name: str, headers: Optional[Mapping[str, str]] = None + ) -> ManifestSchema: + """Gets a single tool from the server by listing all and filtering.""" + manifest = await self.tools_list(headers=headers) + + if tool_name not in manifest.tools: + raise ValueError(f"Tool '{tool_name}' not found.") + + return ManifestSchema( + serverVersion=manifest.serverVersion, + tools={tool_name: manifest.tools[tool_name]}, + ) + + async def tool_invoke( + self, tool_name: str, arguments: dict, headers: Optional[Mapping[str, str]] + ) -> str: + """Invokes a specific tool on the server using the MCP protocol.""" + await self._ensure_initialized() + + result = await self._send_request( + url=self._mcp_base_url, + request=types.CallToolRequest( + params=types.CallToolRequestParams(name=tool_name, arguments=arguments) + ), + headers=headers, + ) + if result is None: + raise RuntimeError( + f"Failed to invoke tool '{tool_name}': No response from server." + ) + + return "".join(c.text for c in result.content if c.type == "text") or "null" diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/types.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/types.py new file mode 100644 index 000000000..5cfca277a --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/types.py @@ -0,0 +1,160 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from typing import Any, Generic, Literal, Type, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + + +class _BaseMCPModel(BaseModel): + """Base model with common configuration.""" + + model_config = ConfigDict(extra="allow") + + +class RequestParams(_BaseMCPModel): + pass + + +class JSONRPCRequest(_BaseMCPModel): + jsonrpc: Literal["2.0"] = "2.0" + id: str | int = Field(default_factory=lambda: str(uuid.uuid4())) + method: str + params: dict[str, Any] | None = None + + +class JSONRPCNotification(_BaseMCPModel): + """A notification which does not expect a response (no ID).""" + + jsonrpc: Literal["2.0"] = "2.0" + method: str + params: dict[str, Any] | None = None + + +class JSONRPCResponse(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + result: dict[str, Any] + + +class ErrorData(_BaseMCPModel): + code: int + message: str + data: Any | None = None + + +class JSONRPCError(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + error: ErrorData + + +class BaseMetadata(_BaseMCPModel): + name: str + + +class Implementation(BaseMetadata): + version: str + + +class ClientCapabilities(_BaseMCPModel): + pass + + +class InitializeRequestParams(RequestParams): + protocolVersion: str + capabilities: ClientCapabilities + clientInfo: Implementation + + +class ServerCapabilities(_BaseMCPModel): + prompts: dict[str, Any] | None = None + tools: dict[str, Any] | None = None + + +class InitializeResult(_BaseMCPModel): + protocolVersion: str + capabilities: ServerCapabilities + serverInfo: Implementation + instructions: str | None = None + + +class Tool(BaseMetadata): + description: str | None = None + inputSchema: dict[str, Any] + + +class ListToolsResult(_BaseMCPModel): + tools: list[Tool] + + +class TextContent(_BaseMCPModel): + type: Literal["text"] + text: str + + +class CallToolResult(_BaseMCPModel): + content: list[TextContent] + isError: bool = False + + +ResultT = TypeVar("ResultT", bound=BaseModel) + + +class MCPRequest(_BaseMCPModel, Generic[ResultT]): + method: str + params: dict[str, Any] | BaseModel | None = None + + def get_result_model(self) -> Type[ResultT]: + raise NotImplementedError + + +class MCPNotification(_BaseMCPModel): + method: str + params: dict[str, Any] | BaseModel | None = None + + +class InitializeRequest(MCPRequest[InitializeResult]): + method: Literal["initialize"] = "initialize" + params: InitializeRequestParams + + def get_result_model(self) -> Type[InitializeResult]: + return InitializeResult + + +class InitializedNotification(MCPNotification): + method: Literal["notifications/initialized"] = "notifications/initialized" + params: dict[str, Any] = {} + + +class ListToolsRequest(MCPRequest[ListToolsResult]): + method: Literal["tools/list"] = "tools/list" + params: dict[str, Any] = {} + + def get_result_model(self) -> Type[ListToolsResult]: + return ListToolsResult + + +class CallToolRequestParams(_BaseMCPModel): + name: str + arguments: dict[str, Any] + + +class CallToolRequest(MCPRequest[CallToolResult]): + method: Literal["tools/call"] = "tools/call" + params: CallToolRequestParams + + def get_result_model(self) -> Type[CallToolResult]: + return CallToolResult diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/mcp.py new file mode 100644 index 000000000..60b87594d --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/mcp.py @@ -0,0 +1,205 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Mapping, Optional, TypeVar + +from pydantic import BaseModel + +from ... import version +from ...protocol import ManifestSchema +from ..transport_base import _McpHttpTransportBase +from . import types + +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) + + +class McpHttpTransportV20250326(_McpHttpTransportBase): + """Transport for the MCP v2025-03-26 protocol.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._session_id: Optional[str] = None + + async def _send_request( + self, + url: str, + request: types.MCPRequest[ReceiveResultT] | types.MCPNotification, + headers: Optional[Mapping[str, str]] = None, + ) -> ReceiveResultT | None: + """Sends a JSON-RPC request to the MCP server.""" + raw_params = ( + request.params.model_dump(mode="json", exclude_none=True) + if isinstance(request.params, BaseModel) + else request.params + ) + + # Inject Session ID if available (v2025-03-26 specific) + params = raw_params + if request.method != "initialize" and self._session_id: + if params is None: + params = {} + elif isinstance(params, dict): + params = params.copy() + params["Mcp-Session-Id"] = self._session_id + + rpc_msg: BaseModel + if isinstance(request, types.MCPNotification): + rpc_msg = types.JSONRPCNotification(method=request.method, params=params) + else: + rpc_msg = types.JSONRPCRequest(method=request.method, params=params) + + payload = rpc_msg.model_dump(mode="json", exclude_none=True) + + async with self._session.post( + url, json=payload, headers=dict(headers or {}) + ) as response: + if not response.ok: + error_text = await response.text() + raise RuntimeError( + "API request failed with status" + f" {response.status} ({response.reason}). Server response:" + f" {error_text}" + ) + + if response.status == 204 or response.content.at_eof(): + return None + + json_resp = await response.json() + + # Check for JSON-RPC Error + if "error" in json_resp: + try: + err = types.JSONRPCError.model_validate(json_resp).error + raise RuntimeError( + f"MCP request failed with code {err.code}: {err.message}" + ) + except Exception: + # Fallback if the error doesn't match our schema exactly + raw_error = json_resp.get("error", {}) + raise RuntimeError(f"MCP request failed: {raw_error}") + + # Parse Result + if isinstance(request, types.MCPRequest): + try: + rpc_resp = types.JSONRPCResponse.model_validate(json_resp) + return request.get_result_model().model_validate(rpc_resp.result) + except Exception as e: + raise RuntimeError(f"Failed to parse JSON-RPC response: {e}") + return None + + async def _initialize_session(self): + """Initializes the MCP session.""" + params = types.InitializeRequestParams( + protocolVersion=self._protocol_version, + capabilities=types.ClientCapabilities(), + clientInfo=types.Implementation( + name="toolbox-python-sdk", version=version.__version__ + ), + ) + + result = await self._send_request( + url=self._mcp_base_url, + request=types.InitializeRequest(params=params), + ) + + self._server_version = result.serverInfo.version + + if result.protocolVersion != self._protocol_version: + raise RuntimeError( + "MCP version mismatch: client does not support server version" + f" {result.protocolVersion}" + ) + + if not result.capabilities.tools: + if self._manage_session: + await self.close() + raise RuntimeError("Server does not support the 'tools' capability.") + + # Extract session ID from extra fields (v2025-03-26 specific) + extra = result.model_extra or {} + self._session_id = extra.get("Mcp-Session-Id") + + if not self._session_id: + if self._manage_session: + await self.close() + raise RuntimeError( + "Server did not return a Mcp-Session-Id during initialization." + ) + + await self._send_request( + url=self._mcp_base_url, + request=types.InitializedNotification(), + ) + + async def tools_list( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> ManifestSchema: + """Lists available tools from the server using the MCP protocol.""" + await self._ensure_initialized() + + url = self._mcp_base_url + (toolset_name if toolset_name else "") + result = await self._send_request( + url=url, request=types.ListToolsRequest(), headers=headers + ) + if result is None: + raise RuntimeError("Failed to list tools: No response from server.") + + tools_map = { + t.name: self._convert_tool_schema(t.model_dump(mode="json", by_alias=True)) + for t in result.tools + } + if self._server_version is None: + raise RuntimeError("Server version not available.") + + return ManifestSchema( + serverVersion=self._server_version, + tools=tools_map, + ) + + async def tool_get( + self, tool_name: str, headers: Optional[Mapping[str, str]] = None + ) -> ManifestSchema: + """Gets a single tool from the server by listing all and filtering.""" + manifest = await self.tools_list(headers=headers) + + if tool_name not in manifest.tools: + raise ValueError(f"Tool '{tool_name}' not found.") + + return ManifestSchema( + serverVersion=manifest.serverVersion, + tools={tool_name: manifest.tools[tool_name]}, + ) + + async def tool_invoke( + self, tool_name: str, arguments: dict, headers: Optional[Mapping[str, str]] + ) -> str: + """Invokes a specific tool on the server using the MCP protocol.""" + await self._ensure_initialized() + + result = await self._send_request( + url=self._mcp_base_url, + request=types.CallToolRequest( + params=types.CallToolRequestParams(name=tool_name, arguments=arguments) + ), + headers=headers, + ) + + if result is None: + raise RuntimeError( + f"Failed to invoke tool '{tool_name}': No response from server." + ) + + return "".join(c.text for c in result.content if c.type == "text") or "null" diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/types.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/types.py new file mode 100644 index 000000000..5cfca277a --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/types.py @@ -0,0 +1,160 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from typing import Any, Generic, Literal, Type, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + + +class _BaseMCPModel(BaseModel): + """Base model with common configuration.""" + + model_config = ConfigDict(extra="allow") + + +class RequestParams(_BaseMCPModel): + pass + + +class JSONRPCRequest(_BaseMCPModel): + jsonrpc: Literal["2.0"] = "2.0" + id: str | int = Field(default_factory=lambda: str(uuid.uuid4())) + method: str + params: dict[str, Any] | None = None + + +class JSONRPCNotification(_BaseMCPModel): + """A notification which does not expect a response (no ID).""" + + jsonrpc: Literal["2.0"] = "2.0" + method: str + params: dict[str, Any] | None = None + + +class JSONRPCResponse(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + result: dict[str, Any] + + +class ErrorData(_BaseMCPModel): + code: int + message: str + data: Any | None = None + + +class JSONRPCError(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + error: ErrorData + + +class BaseMetadata(_BaseMCPModel): + name: str + + +class Implementation(BaseMetadata): + version: str + + +class ClientCapabilities(_BaseMCPModel): + pass + + +class InitializeRequestParams(RequestParams): + protocolVersion: str + capabilities: ClientCapabilities + clientInfo: Implementation + + +class ServerCapabilities(_BaseMCPModel): + prompts: dict[str, Any] | None = None + tools: dict[str, Any] | None = None + + +class InitializeResult(_BaseMCPModel): + protocolVersion: str + capabilities: ServerCapabilities + serverInfo: Implementation + instructions: str | None = None + + +class Tool(BaseMetadata): + description: str | None = None + inputSchema: dict[str, Any] + + +class ListToolsResult(_BaseMCPModel): + tools: list[Tool] + + +class TextContent(_BaseMCPModel): + type: Literal["text"] + text: str + + +class CallToolResult(_BaseMCPModel): + content: list[TextContent] + isError: bool = False + + +ResultT = TypeVar("ResultT", bound=BaseModel) + + +class MCPRequest(_BaseMCPModel, Generic[ResultT]): + method: str + params: dict[str, Any] | BaseModel | None = None + + def get_result_model(self) -> Type[ResultT]: + raise NotImplementedError + + +class MCPNotification(_BaseMCPModel): + method: str + params: dict[str, Any] | BaseModel | None = None + + +class InitializeRequest(MCPRequest[InitializeResult]): + method: Literal["initialize"] = "initialize" + params: InitializeRequestParams + + def get_result_model(self) -> Type[InitializeResult]: + return InitializeResult + + +class InitializedNotification(MCPNotification): + method: Literal["notifications/initialized"] = "notifications/initialized" + params: dict[str, Any] = {} + + +class ListToolsRequest(MCPRequest[ListToolsResult]): + method: Literal["tools/list"] = "tools/list" + params: dict[str, Any] = {} + + def get_result_model(self) -> Type[ListToolsResult]: + return ListToolsResult + + +class CallToolRequestParams(_BaseMCPModel): + name: str + arguments: dict[str, Any] + + +class CallToolRequest(MCPRequest[CallToolResult]): + method: Literal["tools/call"] = "tools/call" + params: CallToolRequestParams + + def get_result_model(self) -> Type[CallToolResult]: + return CallToolResult diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/mcp.py new file mode 100644 index 000000000..1cae01f7d --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/mcp.py @@ -0,0 +1,184 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Mapping, Optional, TypeVar + +from pydantic import BaseModel + +from ... import version +from ...protocol import ManifestSchema +from ..transport_base import _McpHttpTransportBase +from . import types + +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) + + +class McpHttpTransportV20250618(_McpHttpTransportBase): + """Transport for the MCP v2025-06-18 protocol.""" + + async def _send_request( + self, + url: str, + request: types.MCPRequest[ReceiveResultT] | types.MCPNotification, + headers: Optional[Mapping[str, str]] = None, + ) -> ReceiveResultT | None: + """Sends a JSON-RPC request to the MCP server.""" + req_headers = dict(headers or {}) + req_headers["MCP-Protocol-Version"] = self._protocol_version + + params = ( + request.params.model_dump(mode="json", exclude_none=True) + if isinstance(request.params, BaseModel) + else request.params + ) + + rpc_msg: BaseModel + if isinstance(request, types.MCPNotification): + rpc_msg = types.JSONRPCNotification(method=request.method, params=params) + else: + rpc_msg = types.JSONRPCRequest(method=request.method, params=params) + + payload = rpc_msg.model_dump(mode="json", exclude_none=True) + + async with self._session.post( + url, json=payload, headers=req_headers + ) as response: + if not response.ok: + error_text = await response.text() + raise RuntimeError( + "API request failed with status" + f" {response.status} ({response.reason}). Server response:" + f" {error_text}" + ) + + if response.status == 204 or response.content.at_eof(): + return None + + json_resp = await response.json() + + # Check for JSON-RPC Error + if "error" in json_resp: + try: + err = types.JSONRPCError.model_validate(json_resp).error + raise RuntimeError( + f"MCP request failed with code {err.code}: {err.message}" + ) + except Exception: + # Fallback if the error doesn't match our schema exactly + raw_error = json_resp.get("error", {}) + raise RuntimeError(f"MCP request failed: {raw_error}") + + # Parse Result + if isinstance(request, types.MCPRequest): + try: + rpc_resp = types.JSONRPCResponse.model_validate(json_resp) + return request.get_result_model().model_validate(rpc_resp.result) + except Exception as e: + raise RuntimeError(f"Failed to parse JSON-RPC response: {e}") + return None + + async def _initialize_session(self): + """Initializes the MCP session.""" + params = types.InitializeRequestParams( + protocolVersion=self._protocol_version, + capabilities=types.ClientCapabilities(), + clientInfo=types.Implementation( + name="toolbox-python-sdk", version=version.__version__ + ), + ) + + result = await self._send_request( + url=self._mcp_base_url, + request=types.InitializeRequest(params=params), + ) + + self._server_version = result.serverInfo.version + + if result.protocolVersion != self._protocol_version: + raise RuntimeError( + "MCP version mismatch: client does not support server version" + f" {result.protocolVersion}" + ) + + if not result.capabilities.tools: + if self._manage_session: + await self.close() + raise RuntimeError("Server does not support the 'tools' capability.") + + await self._send_request( + url=self._mcp_base_url, + request=types.InitializedNotification(), + ) + + async def tools_list( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> ManifestSchema: + """Lists available tools from the server using the MCP protocol.""" + await self._ensure_initialized() + + url = self._mcp_base_url + (toolset_name if toolset_name else "") + result = await self._send_request( + url=url, request=types.ListToolsRequest(), headers=headers + ) + if result is None: + raise RuntimeError("Failed to list tools: No response from server.") + + tools_map = { + t.name: self._convert_tool_schema(t.model_dump(mode="json", by_alias=True)) + for t in result.tools + } + if self._server_version is None: + raise RuntimeError("Server version not available.") + + return ManifestSchema( + serverVersion=self._server_version, + tools=tools_map, + ) + + async def tool_get( + self, tool_name: str, headers: Optional[Mapping[str, str]] = None + ) -> ManifestSchema: + """Gets a single tool from the server by listing all and filtering.""" + manifest = await self.tools_list(headers=headers) + + if tool_name not in manifest.tools: + raise ValueError(f"Tool '{tool_name}' not found.") + + return ManifestSchema( + serverVersion=manifest.serverVersion, + tools={tool_name: manifest.tools[tool_name]}, + ) + + async def tool_invoke( + self, tool_name: str, arguments: dict, headers: Optional[Mapping[str, str]] + ) -> str: + """Invokes a specific tool on the server using the MCP protocol.""" + await self._ensure_initialized() + + result = await self._send_request( + url=self._mcp_base_url, + request=types.CallToolRequest( + params=types.CallToolRequestParams(name=tool_name, arguments=arguments) + ), + headers=headers, + ) + + if result is None: + raise RuntimeError( + f"Failed to invoke tool '{tool_name}': No response from server." + ) + + return "".join(c.text for c in result.content if c.type == "text") or "null" diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/types.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/types.py new file mode 100644 index 000000000..5cfca277a --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/types.py @@ -0,0 +1,160 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from typing import Any, Generic, Literal, Type, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + + +class _BaseMCPModel(BaseModel): + """Base model with common configuration.""" + + model_config = ConfigDict(extra="allow") + + +class RequestParams(_BaseMCPModel): + pass + + +class JSONRPCRequest(_BaseMCPModel): + jsonrpc: Literal["2.0"] = "2.0" + id: str | int = Field(default_factory=lambda: str(uuid.uuid4())) + method: str + params: dict[str, Any] | None = None + + +class JSONRPCNotification(_BaseMCPModel): + """A notification which does not expect a response (no ID).""" + + jsonrpc: Literal["2.0"] = "2.0" + method: str + params: dict[str, Any] | None = None + + +class JSONRPCResponse(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + result: dict[str, Any] + + +class ErrorData(_BaseMCPModel): + code: int + message: str + data: Any | None = None + + +class JSONRPCError(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + error: ErrorData + + +class BaseMetadata(_BaseMCPModel): + name: str + + +class Implementation(BaseMetadata): + version: str + + +class ClientCapabilities(_BaseMCPModel): + pass + + +class InitializeRequestParams(RequestParams): + protocolVersion: str + capabilities: ClientCapabilities + clientInfo: Implementation + + +class ServerCapabilities(_BaseMCPModel): + prompts: dict[str, Any] | None = None + tools: dict[str, Any] | None = None + + +class InitializeResult(_BaseMCPModel): + protocolVersion: str + capabilities: ServerCapabilities + serverInfo: Implementation + instructions: str | None = None + + +class Tool(BaseMetadata): + description: str | None = None + inputSchema: dict[str, Any] + + +class ListToolsResult(_BaseMCPModel): + tools: list[Tool] + + +class TextContent(_BaseMCPModel): + type: Literal["text"] + text: str + + +class CallToolResult(_BaseMCPModel): + content: list[TextContent] + isError: bool = False + + +ResultT = TypeVar("ResultT", bound=BaseModel) + + +class MCPRequest(_BaseMCPModel, Generic[ResultT]): + method: str + params: dict[str, Any] | BaseModel | None = None + + def get_result_model(self) -> Type[ResultT]: + raise NotImplementedError + + +class MCPNotification(_BaseMCPModel): + method: str + params: dict[str, Any] | BaseModel | None = None + + +class InitializeRequest(MCPRequest[InitializeResult]): + method: Literal["initialize"] = "initialize" + params: InitializeRequestParams + + def get_result_model(self) -> Type[InitializeResult]: + return InitializeResult + + +class InitializedNotification(MCPNotification): + method: Literal["notifications/initialized"] = "notifications/initialized" + params: dict[str, Any] = {} + + +class ListToolsRequest(MCPRequest[ListToolsResult]): + method: Literal["tools/list"] = "tools/list" + params: dict[str, Any] = {} + + def get_result_model(self) -> Type[ListToolsResult]: + return ListToolsResult + + +class CallToolRequestParams(_BaseMCPModel): + name: str + arguments: dict[str, Any] + + +class CallToolRequest(MCPRequest[CallToolResult]): + method: Literal["tools/call"] = "tools/call" + params: CallToolRequestParams + + def get_result_model(self) -> Type[CallToolResult]: + return CallToolResult diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index 8cf563cc1..c58caf1d8 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -11,12 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from enum import Enum from inspect import Parameter from typing import Any, Optional, Type, Union from pydantic import BaseModel + +class Protocol(str, Enum): + """Defines how the client should choose between communication protocols.""" + + TOOLBOX = "toolbox" + MCP_v20250618 = "2025-06-18" + MCP_v20250326 = "2025-03-26" + MCP_v20241105 = "2024-11-05" + MCP = MCP_v20250618 + + @staticmethod + def get_supported_mcp_versions() -> list[str]: + """Returns a list of supported MCP protocol versions.""" + return [ + Protocol.MCP_v20250618.value, + Protocol.MCP_v20250326.value, + Protocol.MCP_v20241105.value, + ] + + __TYPE_MAP = { "string": str, "integer": int, diff --git a/packages/toolbox-core/tests/mcp_transport/test_base.py b/packages/toolbox-core/tests/mcp_transport/test_base.py new file mode 100644 index 000000000..fd64ef74a --- /dev/null +++ b/packages/toolbox-core/tests/mcp_transport/test_base.py @@ -0,0 +1,180 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Any +from unittest.mock import AsyncMock + +import pytest +import pytest_asyncio +from aiohttp import ClientSession + +from toolbox_core.mcp_transport.transport_base import _McpHttpTransportBase +from toolbox_core.protocol import ToolSchema + + +class ConcreteTransport(_McpHttpTransportBase): + """A concrete class for testing the abstract base class.""" + + async def _initialize_session(self): + pass + + async def _send_request(self, *args, **kwargs) -> Any: + pass + + async def tools_list(self, *args, **kwargs): + pass + + async def tool_get(self, *args, **kwargs): + pass + + async def tool_invoke(self, *args, **kwargs): + pass + + +@pytest_asyncio.fixture +async def transport(mocker): + """ + A pytest fixture that creates and tears down a ConcreteTransport instance + for each test that uses it. + """ + base_url = "http://fake-server.com" + transport_instance = ConcreteTransport(base_url) + mocker.patch.object( + transport_instance, "_initialize_session", new_callable=AsyncMock + ) + mocker.patch.object(transport_instance, "_send_request", new_callable=AsyncMock) + + yield transport_instance + await transport_instance.close() + + +class TestMcpHttpTransportBase: + @pytest.mark.asyncio + async def test_initialization_properties(self, transport): + """Test constructor properties are set correctly.""" + assert transport.base_url == "http://fake-server.com/mcp/" + assert transport._manage_session is True + assert transport._session is not None + + @pytest.mark.asyncio + async def test_ensure_initialized_calls_initialize(self, transport, mocker): + """Test that _ensure_initialized calls _initialize_session.""" + mocker.patch.object(transport, "_initialize_session", new_callable=AsyncMock) + await transport._ensure_initialized() + transport._initialize_session.assert_called_once() + + @pytest.mark.asyncio + async def test_initialization_with_external_session(self): + """Test that an external session is used and not managed.""" + mock_session = AsyncMock(spec=ClientSession) + transport = ConcreteTransport("http://fake-server.com", session=mock_session) + assert transport._manage_session is False + assert transport._session is mock_session + await transport.close() + + @pytest.mark.asyncio + async def test_ensure_initialized_is_called(self, transport): + """Test that _ensure_initialized calls _initialize_session.""" + await transport._ensure_initialized() + transport._initialize_session.assert_called_once() + + @pytest.mark.asyncio + async def test_initialization_is_only_run_once(self, transport): + """Test the lock ensures initialization only happens once with concurrent calls.""" + init_started = asyncio.Event() + + async def slow_init(): + init_started.set() + await asyncio.sleep(0.01) + + transport._initialize_session.side_effect = slow_init + + task1 = asyncio.create_task(transport._ensure_initialized()) + await init_started.wait() + task2 = asyncio.create_task(transport._ensure_initialized()) + await asyncio.gather(task1, task2) + + transport._initialize_session.assert_called_once() + + def test_convert_tool_schema_valid(self, transport): + """Test converting a valid MCP tool schema.""" + raw_tool = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "type": "object", + "properties": { + "arg1": {"type": "string", "description": "Argument 1"}, + "arg2": {"type": "integer"}, + }, + "required": ["arg1"], + }, + } + + schema = transport._convert_tool_schema(raw_tool) + + assert isinstance(schema, ToolSchema) + assert schema.description == "A test tool" + assert len(schema.parameters) == 2 + + p1 = next(p for p in schema.parameters if p.name == "arg1") + assert p1.type == "string" + assert p1.description == "Argument 1" + assert p1.required is True + + p2 = next(p for p in schema.parameters if p.name == "arg2") + assert p2.type == "integer" + assert p2.required is False + + def test_convert_tool_schema_complex_types(self, transport): + """Test converting schema with array and object types.""" + raw_tool = { + "name": "complex_tool", + "inputSchema": { + "type": "object", + "properties": { + "list_param": {"type": "array", "items": {"type": "string"}}, + "obj_param": { + "type": "object", + "additionalProperties": {"type": "integer"}, + }, + }, + }, + } + + schema = transport._convert_tool_schema(raw_tool) + p_list = next(p for p in schema.parameters if p.name == "list_param") + assert p_list.type == "array" + + p_obj = next(p for p in schema.parameters if p.name == "obj_param") + assert p_obj.type == "object" + assert p_obj.additionalProperties.type == "integer" + + @pytest.mark.asyncio + async def test_close_managed_session(self, mocker): + mock_close = mocker.patch("aiohttp.ClientSession.close", new_callable=AsyncMock) + transport = ConcreteTransport("http://fake-server.com") + # Mock the init task so close() tries to await it + transport._init_task = asyncio.create_task(asyncio.sleep(0)) + await transport.close() + mock_close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_unmanaged_session(self): + mock_session = AsyncMock(spec=ClientSession) + transport = ConcreteTransport("http://fake-server.com", session=mock_session) + transport._init_task = asyncio.create_task(asyncio.sleep(0)) + await transport.close() + mock_session.close.assert_not_called() diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20241105.py b/packages/toolbox-core/tests/mcp_transport/test_v20241105.py new file mode 100644 index 000000000..832dcd4c3 --- /dev/null +++ b/packages/toolbox-core/tests/mcp_transport/test_v20241105.py @@ -0,0 +1,274 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import pytest_asyncio +from aiohttp import ClientSession + +from toolbox_core.mcp_transport.v20241105 import types +from toolbox_core.mcp_transport.v20241105.mcp import McpHttpTransportV20241105 +from toolbox_core.protocol import ManifestSchema, Protocol + + +def create_fake_tools_list_result(): + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Gets the weather.", + inputSchema={ + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + ) + ] + ) + + +@pytest_asyncio.fixture +async def transport(): + mock_session = AsyncMock(spec=ClientSession) + transport = McpHttpTransportV20241105( + "http://fake-server.com", session=mock_session, protocol=Protocol.MCP_v20241105 + ) + yield transport + await transport.close() + + +@pytest.mark.asyncio +class TestMcpHttpTransportV20241105: + + # --- Request Sending Tests --- + + async def test_send_request_success(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 200 + + mock_content = Mock() + mock_content.at_eof.return_value = False + mock_response.content = mock_content + + mock_response.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "result": {"foo": "bar"}, + } + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + foo: str + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + result = await transport._send_request("url", TestRequest()) + assert result == TestResult(foo="bar") + + async def test_send_request_api_error(self, transport): + mock_response = AsyncMock() + mock_response.ok = False + mock_response.status = 500 + mock_response.text.return_value = "Error" + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises(RuntimeError, match="API request failed with status 500"): + await transport._send_request("url", TestRequest()) + + async def test_send_request_mcp_error(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 200 + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + + mock_response.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": {"code": -32601, "message": "Method not found"}, + } + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises(RuntimeError, match="MCP request failed"): + await transport._send_request("url", TestRequest()) + + async def test_send_notification(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 204 + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestNotification(types.MCPNotification): + method: str = "notifications/test" + params: dict = {} + + await transport._send_request("url", TestNotification()) + + call_kwargs = transport._session.post.call_args.kwargs + payload = call_kwargs["json"] + assert "id" not in payload + assert payload["method"] == "notifications/test" + + # --- Initialization Tests --- + + @patch("toolbox_core.mcp_transport.v20241105.mcp.version") + async def test_initialize_session_success(self, mock_version, transport, mocker): + mock_version.__version__ = "1.2.3" + mock_send = mocker.patch.object( + transport, "_send_request", new_callable=AsyncMock + ) + + mock_send.side_effect = [ + types.InitializeResult( + protocolVersion="2024-11-05", + capabilities=types.ServerCapabilities(tools={"listChanged": False}), + serverInfo=types.Implementation(name="test", version="1.0"), + ), + None, + ] + + await transport._initialize_session() + + assert transport._server_version == "1.0" + assert mock_send.call_count == 2 + init_call = mock_send.call_args_list[0] + init_call = mock_send.call_args_list[0] + assert isinstance(init_call.kwargs["request"], types.InitializeRequest) + assert init_call.kwargs["request"].params.protocolVersion == "2024-11-05" + + async def test_initialize_session_protocol_mismatch(self, transport, mocker): + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.InitializeResult( + protocolVersion="2099-01-01", + capabilities=types.ServerCapabilities(tools={"listChanged": True}), + serverInfo=types.Implementation(name="test", version="1.0"), + ), + ) + + with pytest.raises(RuntimeError, match="MCP version mismatch"): + await transport._initialize_session() + + async def test_initialize_session_missing_tools_capability(self, transport, mocker): + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.InitializeResult( + protocolVersion="2024-11-05", + capabilities=types.ServerCapabilities(), + serverInfo=types.Implementation(name="test", version="1.0"), + ), + ) + + with pytest.raises( + RuntimeError, match="Server does not support the 'tools' capability" + ): + await transport._initialize_session() + + # --- Tool Management Tests --- + + async def test_tools_list_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0" + + manifest = await transport.tools_list() + assert isinstance(manifest, ManifestSchema) + assert "get_weather" in manifest.tools + + async def test_tools_list_with_toolset_name(self, transport, mocker): + """Test listing tools with a specific toolset name updates the URL.""" + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0.0" + + manifest = await transport.tools_list(toolset_name="custom_toolset") + + assert isinstance(manifest, ManifestSchema) + # Verify the toolset name was appended to the base URL + # Verify the toolset name was appended to the base URL + expected_url = transport.base_url + "custom_toolset" + + call_args = transport._send_request.call_args + assert call_args.kwargs["url"] == expected_url + assert isinstance(call_args.kwargs["request"], types.ListToolsRequest) + assert call_args.kwargs["headers"] is None + + async def test_tool_invoke_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.CallToolResult( + content=[types.TextContent(type="text", text="Result")] + ), + ) + + result = await transport.tool_invoke("tool", {}, {}) + assert result == "Result" + + async def test_tool_get_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0" + + manifest = await transport.tool_get("get_weather") + assert "get_weather" in manifest.tools + assert len(manifest.tools) == 1 diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20250326.py b/packages/toolbox-core/tests/mcp_transport/test_v20250326.py new file mode 100644 index 000000000..c23f51d4f --- /dev/null +++ b/packages/toolbox-core/tests/mcp_transport/test_v20250326.py @@ -0,0 +1,267 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import pytest_asyncio +from aiohttp import ClientSession + +from toolbox_core.mcp_transport.v20250326 import types +from toolbox_core.mcp_transport.v20250326.mcp import McpHttpTransportV20250326 +from toolbox_core.protocol import ManifestSchema, Protocol + + +def create_fake_tools_list_result(): + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Gets the weather.", + inputSchema={ + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + ) + ] + ) + + +@pytest_asyncio.fixture +async def transport(): + mock_session = AsyncMock(spec=ClientSession) + transport = McpHttpTransportV20250326( + "http://fake-server.com", session=mock_session, protocol=Protocol.MCP_v20250326 + ) + yield transport + await transport.close() + + +@pytest.mark.asyncio +class TestMcpHttpTransportV20250326: + # --- Request Sending Tests (Standard + Session ID) --- + + async def test_send_request_success(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 200 + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + result = await transport._send_request("url", TestRequest()) + assert result == TestResult() + + async def test_send_request_with_session_id(self, transport): + """Test that the session ID is injected into params.""" + transport._session_id = "test-session-id" + mock_response = AsyncMock() + mock_response.ok = True + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + await transport._send_request("url", TestRequest(params={"param": "value"})) + + call_args = transport._session.post.call_args + sent_params = call_args.kwargs["json"]["params"] + assert sent_params["Mcp-Session-Id"] == "test-session-id" + assert sent_params["param"] == "value" + + async def test_send_request_api_error(self, transport): + mock_response = AsyncMock() + mock_response.ok = False + mock_response.status = 500 + mock_response.text.return_value = "Error" + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises(RuntimeError, match="API request failed"): + await transport._send_request("url", TestRequest()) + + async def test_send_request_mcp_error(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 200 + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": {"code": -32601, "message": "Error"}, + } + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises(RuntimeError, match="MCP request failed"): + await transport._send_request("url", TestRequest()) + + async def test_send_notification(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 204 + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestNotification(types.MCPNotification): + method: str = "notifications/test" + params: dict = {} + + await transport._send_request("url", TestNotification()) + payload = transport._session.post.call_args.kwargs["json"] + assert "id" not in payload + + # --- Initialization Tests (Session ID Required) --- + + @patch("toolbox_core.mcp_transport.v20250326.mcp.version") + async def test_initialize_session_success(self, mock_version, transport, mocker): + mock_version.__version__ = "1.2.3" + mock_send = mocker.patch.object( + transport, "_send_request", new_callable=AsyncMock + ) + + mock_send.side_effect = [ + types.InitializeResult.model_validate( + { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {"listChanged": True}}, + "serverInfo": {"name": "test", "version": "1.0"}, + "Mcp-Session-Id": "sess-123", + } + ), + None, + ] + + await transport._initialize_session() + assert transport._session_id == "sess-123" + + async def test_initialize_session_missing_session_id(self, transport, mocker): + """Specific test for 2025-03-26: Error if session ID is missing.""" + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.InitializeResult( + protocolVersion="2025-03-26", + capabilities=types.ServerCapabilities(tools={"listChanged": True}), + serverInfo=types.Implementation(name="test", version="1.0"), + ), + ) + # Mock close since it will be called on failure + mocker.patch.object(transport, "close", new_callable=AsyncMock) + + with pytest.raises( + RuntimeError, match="Server did not return a Mcp-Session-Id" + ): + await transport._initialize_session() + + # --- Tool Management Tests --- + + async def test_tools_list_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0" + manifest = await transport.tools_list() + assert isinstance(manifest, ManifestSchema) + + async def test_tools_list_with_toolset_name(self, transport, mocker): + """Test listing tools with a specific toolset name updates the URL.""" + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0.0" + + manifest = await transport.tools_list(toolset_name="custom_toolset") + + assert isinstance(manifest, ManifestSchema) + expected_url = transport.base_url + "custom_toolset" + + call_args = transport._send_request.call_args + assert call_args.kwargs["url"] == expected_url + assert isinstance(call_args.kwargs["request"], types.ListToolsRequest) + assert call_args.kwargs["headers"] is None + + async def test_tool_invoke_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.CallToolResult( + content=[types.TextContent(type="text", text="Result")] + ), + ) + result = await transport.tool_invoke("tool", {}, {}) + assert result == "Result" + + async def test_tool_get_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0" + manifest = await transport.tool_get("get_weather") + assert "get_weather" in manifest.tools diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20250618.py b/packages/toolbox-core/tests/mcp_transport/test_v20250618.py new file mode 100644 index 000000000..e8ab76f94 --- /dev/null +++ b/packages/toolbox-core/tests/mcp_transport/test_v20250618.py @@ -0,0 +1,275 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import pytest_asyncio +from aiohttp import ClientSession + +from toolbox_core.mcp_transport.v20250618 import types +from toolbox_core.mcp_transport.v20250618.mcp import McpHttpTransportV20250618 +from toolbox_core.protocol import ManifestSchema, Protocol + + +def create_fake_tools_list_result(): + return types.ListToolsResult( + tools=[ + types.Tool( + name="get_weather", + description="Gets the weather.", + inputSchema={ + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + ) + ] + ) + + +@pytest_asyncio.fixture +async def transport(): + mock_session = AsyncMock(spec=ClientSession) + transport = McpHttpTransportV20250618( + "http://fake-server.com", session=mock_session, protocol=Protocol.MCP_v20250618 + ) + yield transport + await transport.close() + + +@pytest.mark.asyncio +class TestMcpHttpTransportV20250618: + + # --- Request Sending Tests (Standard + Header) --- + + async def test_send_request_success(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 200 + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + result = await transport._send_request("url", TestRequest()) + assert result == TestResult() + + async def test_send_request_adds_protocol_header(self, transport): + """Test that the MCP-Protocol-Version header is added.""" + mock_response = AsyncMock() + mock_response.ok = True + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + await transport._send_request("url", TestRequest()) + + call_args = transport._session.post.call_args + headers = call_args.kwargs["headers"] + assert headers["MCP-Protocol-Version"] == "2025-06-18" + + async def test_send_request_api_error(self, transport): + mock_response = AsyncMock() + mock_response.ok = False + mock_response.status = 500 + mock_response.text.return_value = "Error" + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises(RuntimeError, match="API request failed"): + await transport._send_request("url", TestRequest()) + + async def test_send_request_mcp_error(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 200 + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": {"code": -32601, "message": "Error"}, + } + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises(RuntimeError, match="MCP request failed"): + await transport._send_request("url", TestRequest()) + + async def test_send_notification(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 204 + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestNotification(types.MCPNotification): + method: str = "notifications/test" + params: dict = {} + + await transport._send_request("url", TestNotification()) + payload = transport._session.post.call_args.kwargs["json"] + assert "id" not in payload + + # --- Initialization Tests --- + + @patch("toolbox_core.mcp_transport.v20250618.mcp.version") + async def test_initialize_session_success(self, mock_version, transport, mocker): + mock_version.__version__ = "1.2.3" + mock_send = mocker.patch.object( + transport, "_send_request", new_callable=AsyncMock + ) + + mock_send.side_effect = [ + types.InitializeResult( + protocolVersion="2025-06-18", + capabilities=types.ServerCapabilities(tools={"listChanged": True}), + serverInfo=types.Implementation(name="test", version="1.0"), + ), + None, + ] + + await transport._initialize_session() + assert transport._server_version == "1.0" + + async def test_initialize_session_protocol_mismatch(self, transport, mocker): + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.InitializeResult( + protocolVersion="2099-01-01", + capabilities=types.ServerCapabilities(tools={"listChanged": True}), + serverInfo=types.Implementation(name="test", version="1.0"), + ), + ) + + with pytest.raises(RuntimeError, match="MCP version mismatch"): + await transport._initialize_session() + + async def test_initialize_session_missing_tools_capability(self, transport, mocker): + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.InitializeResult( + protocolVersion="2025-06-18", + capabilities=types.ServerCapabilities(), + serverInfo=types.Implementation(name="test", version="1.0"), + ), + ) + + with pytest.raises( + RuntimeError, match="Server does not support the 'tools' capability" + ): + await transport._initialize_session() + + # --- Tool Management Tests --- + + async def test_tools_list_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0" + manifest = await transport.tools_list() + assert isinstance(manifest, ManifestSchema) + + async def test_tools_list_with_toolset_name(self, transport, mocker): + """Test listing tools with a specific toolset name updates the URL.""" + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0.0" + + manifest = await transport.tools_list(toolset_name="custom_toolset") + + assert isinstance(manifest, ManifestSchema) + expected_url = transport.base_url + "custom_toolset" + + call_args = transport._send_request.call_args + assert call_args.kwargs["url"] == expected_url + assert isinstance(call_args.kwargs["request"], types.ListToolsRequest) + assert call_args.kwargs["headers"] is None + + async def test_tool_invoke_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.CallToolResult( + content=[types.TextContent(type="text", text="Result")] + ), + ) + result = await transport.tool_invoke("tool", {}, {}) + assert result == "Result" + + async def test_tool_get_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + transport._server_version = "1.0" + manifest = await transport.tool_get("get_weather") + assert "get_weather" in manifest.tools diff --git a/packages/toolbox-core/tests/test_e2e_mcp.py b/packages/toolbox-core/tests/test_e2e_mcp.py new file mode 100644 index 000000000..821eb8f6a --- /dev/null +++ b/packages/toolbox-core/tests/test_e2e_mcp.py @@ -0,0 +1,359 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from inspect import Parameter, signature +from typing import Any, Optional + +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from toolbox_core.client import ToolboxClient +from toolbox_core.protocol import Protocol +from toolbox_core.tool import ToolboxTool + + +@pytest_asyncio.fixture( + scope="function", + params=[ + Protocol.MCP_v20250618, + Protocol.MCP_v20250326, + Protocol.MCP_v20241105, + ], +) +async def toolbox(request): + """Creates a ToolboxClient instance shared by all tests in this module.""" + toolbox = ToolboxClient("http://localhost:5000", protocol=request.param) + try: + yield toolbox + finally: + await toolbox.close() + + +@pytest_asyncio.fixture(scope="function") +async def get_n_rows_tool(toolbox: ToolboxClient) -> ToolboxTool: + """Load the 'get-n-rows' tool using the shared toolbox client.""" + tool = await toolbox.load_tool("get-n-rows") + assert tool.__name__ == "get-n-rows" + return tool + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestBasicE2E: + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + async def test_load_toolset_specific( + self, + toolbox: ToolboxClient, + toolset_name: str, + expected_length: int, + expected_tools: list[str], + ): + """Load a specific toolset""" + toolset = await toolbox.load_toolset(toolset_name) + assert len(toolset) == expected_length + tool_names = {tool.__name__ for tool in toolset} + assert tool_names == set(expected_tools) + + async def test_load_toolset_default(self, toolbox: ToolboxClient): + """Load the default toolset, i.e. all tools.""" + toolset = await toolbox.load_toolset() + assert len(toolset) == 7 + tool_names = {tool.__name__ for tool in toolset} + expected_tools = [ + "get-row-by-content-auth", + "get-row-by-email-auth", + "get-row-by-id-auth", + "get-row-by-id", + "get-n-rows", + "search-rows", + "process-data", + ] + assert tool_names == set(expected_tools) + + async def test_run_tool(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool.""" + response = await get_n_rows_tool(num_rows="2") + + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" not in response + + async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool with missing params.""" + with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): + await get_n_rows_tool() + + async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool with wrong param type.""" + with pytest.raises( + ValidationError, + match=r"num_rows\s+Input should be a valid string\s+\[type=string_type,\s+input_value=2,\s+input_type=int\]", + ): + await get_n_rows_tool(num_rows=2) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestBindParams: + async def test_bind_params( + self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool + ): + """Bind a param to an existing tool.""" + new_tool = get_n_rows_tool.bind_params({"num_rows": "3"}) + response = await new_tool() + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" in response + assert "row4" not in response + + async def test_bind_params_callable( + self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool + ): + """Bind a callable param to an existing tool.""" + new_tool = get_n_rows_tool.bind_params({"num_rows": lambda: "3"}) + response = await new_tool() + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" in response + assert "row4" not in response + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestOptionalParams: + """ + End-to-end tests for tools with optional parameters. + """ + + async def test_tool_signature_is_correct(self, toolbox: ToolboxClient): + """Verify the client correctly constructs the signature for a tool with optional params.""" + tool = await toolbox.load_tool("search-rows") + sig = signature(tool) + + assert "email" in sig.parameters + assert "data" in sig.parameters + assert "id" in sig.parameters + + # The required parameter should have no default + assert sig.parameters["email"].default is Parameter.empty + assert sig.parameters["email"].annotation is str + + # The optional parameter should have a default of None + assert sig.parameters["data"].default is None + assert sig.parameters["data"].annotation is Optional[str] + + # The optional parameter should have a default of None + assert sig.parameters["id"].default is None + assert sig.parameters["id"].annotation is Optional[int] + + async def test_run_tool_with_optional_params_omitted(self, toolbox: ToolboxClient): + """Invoke a tool providing only the required parameter.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_data_provided(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", data="row3") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" not in response + assert "row3" in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_data_null(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", data=None) + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_id_provided(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=1) + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_optional_id_null(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=None) + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_missing_required_param(self, toolbox: ToolboxClient): + """Invoke a tool without its required parameter.""" + tool = await toolbox.load_tool("search-rows") + with pytest.raises(TypeError, match="missing a required argument: 'email'"): + await tool(id=5, data="row5") + + async def test_run_tool_with_required_param_null(self, toolbox: ToolboxClient): + """Invoke a tool without its required parameter.""" + tool = await toolbox.load_tool("search-rows") + with pytest.raises(ValidationError, match="email"): + await tool(email=None, id=5, data="row5") + + async def test_run_tool_with_all_default_params(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=0, data="row2") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_all_valid_params(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=3, data="row3") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" not in response + assert "row3" in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_different_email(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different email.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="anubhavdhawan@google.com", id=3, data="row3") + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_different_data(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different data.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=3, data="row4") + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_different_id(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different data.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=4, data="row3") + assert isinstance(response, str) + assert response == "null" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestMapParams: + """ + End-to-end tests for tools with map parameters. + """ + + async def test_tool_signature_with_map_params(self, toolbox: ToolboxClient): + """Verify the client correctly constructs the signature for a tool with map params.""" + tool = await toolbox.load_tool("process-data") + sig = signature(tool) + + assert "execution_context" in sig.parameters + assert sig.parameters["execution_context"].annotation == dict[str, Any] + assert sig.parameters["execution_context"].default is Parameter.empty + + assert "user_scores" in sig.parameters + assert sig.parameters["user_scores"].annotation == dict[str, int] + assert sig.parameters["user_scores"].default is Parameter.empty + + assert "feature_flags" in sig.parameters + assert sig.parameters["feature_flags"].annotation == Optional[dict[str, bool]] + assert sig.parameters["feature_flags"].default is None + + async def test_run_tool_with_map_params(self, toolbox: ToolboxClient): + """Invoke a tool with valid map parameters.""" + tool = await toolbox.load_tool("process-data") + + response = await tool( + execution_context={"env": "prod", "id": 1234, "user": 1234.5}, + user_scores={"user1": 100, "user2": 200}, + feature_flags={"new_feature": True}, + ) + assert isinstance(response, str) + assert '"execution_context":{"env":"prod","id":1234,"user":1234.5}' in response + assert '"user_scores":{"user1":100,"user2":200}' in response + assert '"feature_flags":{"new_feature":true}' in response + + async def test_run_tool_with_optional_map_param_omitted( + self, toolbox: ToolboxClient + ): + """Invoke a tool without the optional map parameter.""" + tool = await toolbox.load_tool("process-data") + + response = await tool( + execution_context={"env": "dev"}, user_scores={"user3": 300} + ) + assert isinstance(response, str) + assert '"execution_context":{"env":"dev"}' in response + assert '"user_scores":{"user3":300}' in response + assert '"feature_flags":null' in response + + async def test_run_tool_with_wrong_map_value_type(self, toolbox: ToolboxClient): + """Invoke a tool with a map parameter having the wrong value type.""" + tool = await toolbox.load_tool("process-data") + + with pytest.raises(ValidationError): + await tool( + execution_context={"env": "staging"}, + user_scores={"user4": "not-an-integer"}, + ) diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py index dae95f612..b5f000670 100644 --- a/packages/toolbox-core/tests/test_protocol.py +++ b/packages/toolbox-core/tests/test_protocol.py @@ -18,7 +18,20 @@ import pytest -from toolbox_core.protocol import AdditionalPropertiesSchema, ParameterSchema +from toolbox_core.protocol import AdditionalPropertiesSchema, ParameterSchema, Protocol + + +def test_get_supported_mcp_versions(): + """ + Tests that get_supported_mcp_versions returns the correct list of versions, + sorted from newest to oldest. + """ + expected_versions = ["2025-06-18", "2025-03-26", "2024-11-05"] + supported_versions = Protocol.get_supported_mcp_versions() + + assert supported_versions == expected_versions + # Also verify that the non-MCP members are not included + assert "toolbox" not in supported_versions def test_parameter_schema_float():