Skip to content

Commit cb97f35

Browse files
authored
feat: Add mcp transport protocol (#345)
* add basic code * fixes * test fix * new unit tests * rename ToolboxTransport * add py3.9 support * fix langchain tool tests * test fix * lint * fix tests * move manage session into transport * move warning to diff file * avoid code duplication * fix tests * lint * remove redundant tests * make invoke method return str * lint * fix return type * small refactor * refactor: remove transport logic from client tests * try * version negotiation * small changes * lint * fix endpoint * add some todos * lint * initialise in init * lint * add support for 'Mcp-session-id' * lint * add todo * add mcp protocol version to the latest protocol * add test coverage * small fix * small fix * small fix * thread fixes * try * add tests * lint * change small * nit * small debugging * add todos * small bug fixes * add todo * remove id field from notifications * refactor * preprocess tools with empty params * fix types * fix bugs * better error log * small cleanup * handle notifications * fix unit tests * lint * decouple client from transport * lint * use toolbox protocol for e2e tests * add e2e tests for mcp * lint * remove mcp as default protocol * remove auth tests from mcp * remove redundant lines * remove redundant lines * lint * revert some changes * initialise session in a better way * small fix * added more test cov * lint * rename private method * Made methods private * lint * rename base url * resolve comment * better readability * fix tests * lint * fix tests * lint * refactor mcp versions * lint * added test coverage * refactor mcp * lint * improve cov * lint * removed process id * Update class name * remove mcp latest * rename mcp.py * have a single method for session init * lint * better type checks for v20241105 * Revert "better type checks for v20241105" This reverts commit bc6da15. * update type checking * lint * clean file * refactor files * refactor all versions * fix mypy errors * refactor properly * lint * run mcp e2e tests on all versions
1 parent e810ec3 commit cb97f35

File tree

16 files changed

+2573
-4
lines changed

16 files changed

+2573
-4
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from deprecated import deprecated
2121

2222
from .itransport import ITransport
23-
from .protocol import ToolSchema
23+
from .mcp_transport import (
24+
McpHttpTransportV20241105,
25+
McpHttpTransportV20250326,
26+
McpHttpTransportV20250618,
27+
)
28+
from .protocol import Protocol, ToolSchema
2429
from .tool import ToolboxTool
2530
from .toolbox_transport import ToolboxTransport
2631
from .utils import identify_auth_requirements, resolve_value
@@ -44,6 +49,7 @@ def __init__(
4449
client_headers: Optional[
4550
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]
4651
] = None,
52+
protocol: Protocol = Protocol.TOOLBOX,
4753
):
4854
"""
4955
Initializes the ToolboxClient.
@@ -54,8 +60,21 @@ def __init__(
5460
If None (default), a new session is created internally. Note that
5561
if a session is provided, its lifecycle (including closing)
5662
should typically be managed externally.
57-
client_headers: Headers to include in each request sent through this client.
63+
client_headers: Headers to include in each request sent through this
64+
client.
65+
protocol: The communication protocol to use.
5866
"""
67+
if protocol == Protocol.TOOLBOX:
68+
self.__transport = ToolboxTransport(url, session)
69+
elif protocol in Protocol.get_supported_mcp_versions():
70+
if protocol == Protocol.MCP_v20250618:
71+
self.__transport = McpHttpTransportV20250618(url, session, protocol)
72+
elif protocol == Protocol.MCP_v20250326:
73+
self.__transport = McpHttpTransportV20250326(url, session, protocol)
74+
elif protocol == Protocol.MCP_v20241105:
75+
self.__transport = McpHttpTransportV20241105(url, session, protocol)
76+
else:
77+
raise ValueError(f"Unsupported MCP protocol version: {protocol}")
5978

6079
self.__transport = ToolboxTransport(url, session)
6180
self.__client_headers = client_headers if client_headers is not None else {}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .v20241105.mcp import McpHttpTransportV20241105
16+
from .v20250326.mcp import McpHttpTransportV20250326
17+
from .v20250618.mcp import McpHttpTransportV20250618
18+
19+
__all__ = [
20+
"McpHttpTransportV20241105",
21+
"McpHttpTransportV20250326",
22+
"McpHttpTransportV20250618",
23+
]
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
from abc import ABC, abstractmethod
17+
from typing import Optional
18+
19+
from aiohttp import ClientSession
20+
21+
from ..itransport import ITransport
22+
from ..protocol import (
23+
AdditionalPropertiesSchema,
24+
ParameterSchema,
25+
Protocol,
26+
ToolSchema,
27+
)
28+
29+
30+
class _McpHttpTransportBase(ITransport, ABC):
31+
"""Base transport for MCP protocols."""
32+
33+
def __init__(
34+
self,
35+
base_url: str,
36+
session: Optional[ClientSession] = None,
37+
protocol: Protocol = Protocol.MCP,
38+
):
39+
self._mcp_base_url = f"{base_url}/mcp/"
40+
self._protocol_version = protocol.value
41+
self._server_version: Optional[str] = None
42+
43+
self._manage_session = session is None
44+
self._session = session or ClientSession()
45+
self._init_lock = asyncio.Lock()
46+
self._init_task: Optional[asyncio.Task] = None
47+
48+
async def _ensure_initialized(self):
49+
"""Ensures the session is initialized before making requests."""
50+
async with self._init_lock:
51+
if self._init_task is None:
52+
self._init_task = asyncio.create_task(self._initialize_session())
53+
await self._init_task
54+
55+
@property
56+
def base_url(self) -> str:
57+
return self._mcp_base_url
58+
59+
def _convert_tool_schema(self, tool_data: dict) -> ToolSchema:
60+
"""Converts a raw MCP tool dictionary into the Toolbox ToolSchema."""
61+
parameters = []
62+
input_schema = tool_data.get("inputSchema", {})
63+
properties = input_schema.get("properties", {})
64+
required = input_schema.get("required", [])
65+
66+
for name, schema in properties.items():
67+
additional_props = schema.get("additionalProperties")
68+
if isinstance(additional_props, dict):
69+
additional_props = AdditionalPropertiesSchema(
70+
type=additional_props["type"]
71+
)
72+
else:
73+
additional_props = True
74+
parameters.append(
75+
ParameterSchema(
76+
name=name,
77+
type=schema["type"],
78+
description=schema.get("description", ""),
79+
required=name in required,
80+
additionalProperties=additional_props,
81+
)
82+
)
83+
84+
return ToolSchema(
85+
description=tool_data.get("description") or "", parameters=parameters
86+
)
87+
88+
async def close(self):
89+
async with self._init_lock:
90+
if self._init_task:
91+
try:
92+
await self._init_task
93+
except Exception:
94+
# If initialization failed, we can still try to close.
95+
pass
96+
if self._manage_session and self._session and not self._session.closed:
97+
await self._session.close()
98+
99+
@abstractmethod
100+
async def _initialize_session(self):
101+
"""Initializes the MCP session."""
102+
pass
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Mapping, Optional, TypeVar
16+
17+
from pydantic import BaseModel
18+
19+
from ... import version
20+
from ...protocol import ManifestSchema
21+
from ..transport_base import _McpHttpTransportBase
22+
from . import types
23+
24+
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
25+
26+
27+
class McpHttpTransportV20241105(_McpHttpTransportBase):
28+
"""Transport for the MCP v2024-11-05 protocol."""
29+
30+
async def _send_request(
31+
self,
32+
url: str,
33+
request: types.MCPRequest[ReceiveResultT] | types.MCPNotification,
34+
headers: Optional[Mapping[str, str]] = None,
35+
) -> ReceiveResultT | None:
36+
"""Sends a JSON-RPC request to the MCP server."""
37+
params = (
38+
request.params.model_dump(mode="json", exclude_none=True)
39+
if isinstance(request.params, BaseModel)
40+
else request.params
41+
)
42+
rpc_msg: BaseModel
43+
if isinstance(request, types.MCPNotification):
44+
rpc_msg = types.JSONRPCNotification(method=request.method, params=params)
45+
else:
46+
rpc_msg = types.JSONRPCRequest(method=request.method, params=params)
47+
48+
payload = rpc_msg.model_dump(mode="json", exclude_none=True)
49+
50+
async with self._session.post(
51+
url, json=payload, headers=dict(headers or {})
52+
) as response:
53+
if not response.ok:
54+
error_text = await response.text()
55+
raise RuntimeError(
56+
f"API request failed with status {response.status} "
57+
f"({response.reason}). Server response: {error_text}"
58+
)
59+
60+
if response.status == 204 or response.content.at_eof():
61+
return None
62+
63+
json_resp = await response.json()
64+
65+
# Check for JSON-RPC Error
66+
if "error" in json_resp:
67+
try:
68+
err = types.JSONRPCError.model_validate(json_resp).error
69+
raise RuntimeError(
70+
f"MCP request failed with code {err.code}: {err.message}"
71+
)
72+
except Exception:
73+
raise RuntimeError(f"MCP request failed: {json_resp.get('error')}")
74+
75+
# Parse Result
76+
if isinstance(request, types.MCPRequest):
77+
try:
78+
rpc_resp = types.JSONRPCResponse.model_validate(json_resp)
79+
return request.get_result_model().model_validate(rpc_resp.result)
80+
except Exception as e:
81+
raise RuntimeError(f"Failed to parse JSON-RPC response: {e}")
82+
return None
83+
84+
async def _initialize_session(self):
85+
"""Initializes the MCP session."""
86+
params = types.InitializeRequestParams(
87+
protocolVersion=self._protocol_version,
88+
capabilities=types.ClientCapabilities(),
89+
clientInfo=types.Implementation(
90+
name="toolbox-python-sdk", version=version.__version__
91+
),
92+
)
93+
94+
result = await self._send_request(
95+
url=self._mcp_base_url, request=types.InitializeRequest(params=params)
96+
)
97+
98+
self._server_version = result.serverInfo.version
99+
if result.protocolVersion != self._protocol_version:
100+
raise RuntimeError(
101+
f"MCP version mismatch: client does not support server version {result.protocolVersion}"
102+
)
103+
if not result.capabilities.tools:
104+
if self._manage_session:
105+
await self.close()
106+
raise RuntimeError("Server does not support the 'tools' capability.")
107+
108+
await self._send_request(
109+
url=self._mcp_base_url, request=types.InitializedNotification()
110+
)
111+
112+
async def tools_list(
113+
self,
114+
toolset_name: Optional[str] = None,
115+
headers: Optional[Mapping[str, str]] = None,
116+
) -> ManifestSchema:
117+
"""Lists available tools from the server using the MCP protocol."""
118+
await self._ensure_initialized()
119+
120+
url = self._mcp_base_url + (toolset_name if toolset_name else "")
121+
result = await self._send_request(
122+
url=url, request=types.ListToolsRequest(), headers=headers
123+
)
124+
if result is None:
125+
raise RuntimeError("Failed to list tools: No response from server.")
126+
127+
tools_map = {
128+
t.name: self._convert_tool_schema(t.model_dump(mode="json", by_alias=True))
129+
for t in result.tools
130+
}
131+
if self._server_version is None:
132+
raise RuntimeError("Server version not available.")
133+
134+
return ManifestSchema(serverVersion=self._server_version, tools=tools_map)
135+
136+
async def tool_get(
137+
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
138+
) -> ManifestSchema:
139+
"""Gets a single tool from the server by listing all and filtering."""
140+
manifest = await self.tools_list(headers=headers)
141+
142+
if tool_name not in manifest.tools:
143+
raise ValueError(f"Tool '{tool_name}' not found.")
144+
145+
return ManifestSchema(
146+
serverVersion=manifest.serverVersion,
147+
tools={tool_name: manifest.tools[tool_name]},
148+
)
149+
150+
async def tool_invoke(
151+
self, tool_name: str, arguments: dict, headers: Optional[Mapping[str, str]]
152+
) -> str:
153+
"""Invokes a specific tool on the server using the MCP protocol."""
154+
await self._ensure_initialized()
155+
156+
result = await self._send_request(
157+
url=self._mcp_base_url,
158+
request=types.CallToolRequest(
159+
params=types.CallToolRequestParams(name=tool_name, arguments=arguments)
160+
),
161+
headers=headers,
162+
)
163+
if result is None:
164+
raise RuntimeError(
165+
f"Failed to invoke tool '{tool_name}': No response from server."
166+
)
167+
168+
return "".join(c.text for c in result.content if c.type == "text") or "null"

0 commit comments

Comments
 (0)