Skip to content

Commit d777b96

Browse files
committed
fix bugs
1 parent cc621a3 commit d777b96

File tree

1 file changed

+73
-7
lines changed

1 file changed

+73
-7
lines changed

packages/toolbox-core/src/toolbox_core/mcp_transport.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
# limitations under the License.
1414

1515
import uuid
16-
import asyncio
17-
from typing import Any, Mapping, Optional
16+
from typing import Any, Mapping, Optional, Union
1817

1918
from aiohttp import ClientSession
2019

2120
from .itransport import ITransport
22-
from .protocol import ManifestSchema, Protocol
21+
from .protocol import (
22+
AdditionalPropertiesSchema,
23+
ManifestSchema,
24+
ParameterSchema,
25+
Protocol,
26+
ToolSchema,
27+
)
2328

2429

2530
class McpHttpTransport(ITransport):
@@ -50,6 +55,49 @@ def __init__(
5055
def base_url(self) -> str:
5156
return self.__base_url
5257

58+
def _convert_tool_schema(self, tool_data: dict) -> ToolSchema:
59+
parameters = []
60+
input_schema = tool_data.get("inputSchema", {})
61+
properties = input_schema.get("properties", {})
62+
required = input_schema.get("required", [])
63+
64+
for name, schema in properties.items():
65+
additional_props_value = schema.get("additionalProperties")
66+
final_additional_properties: Union[bool, AdditionalPropertiesSchema] = True
67+
68+
if isinstance(additional_props_value, dict):
69+
final_additional_properties = AdditionalPropertiesSchema(
70+
type=additional_props_value["type"]
71+
)
72+
parameters.append(
73+
ParameterSchema(
74+
name=name,
75+
type=schema["type"],
76+
description=schema.get("description", ""),
77+
required=name in required,
78+
additionalProperties=final_additional_properties,
79+
)
80+
)
81+
82+
return ToolSchema(description=tool_data["description"], parameters=parameters)
83+
84+
async def _list_tools(
85+
self,
86+
toolset_name: Optional[str] = None,
87+
headers: Optional[Mapping[str, str]] = None,
88+
) -> Any:
89+
"""Private helper to fetch the raw tool list from the server."""
90+
# TODO: Do not use lazy initialisation
91+
if not self.__mcp_initialized:
92+
await self._initialize_session()
93+
if toolset_name:
94+
url = f"{self.__base_url}/mcp/{toolset_name}"
95+
else:
96+
url = f"{self.__base_url}/mcp/"
97+
return await self._send_request(
98+
url=url, method="tools/list", params={}, headers=headers
99+
)
100+
53101
async def tool_get(
54102
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
55103
) -> ManifestSchema:
@@ -63,6 +111,12 @@ async def tool_get(
63111
else:
64112
raise ValueError(f"Tool '{tool_name}' not found.")
65113

114+
tool_details = ManifestSchema(
115+
serverVersion=self.__server_version,
116+
tools={tool_name: tool_def},
117+
)
118+
return tool_details
119+
66120
async def tools_list(
67121
self,
68122
toolset_name: Optional[str] = None,
@@ -80,15 +134,21 @@ async def tools_list(
80134
return ManifestSchema(**result)
81135

82136
async def tool_invoke(
83-
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
84-
) -> dict:
137+
self, tool_name: str, arguments: dict, headers: Optional[Mapping[str, str]]
138+
) -> str:
85139
"""Invokes a specific tool on the server using the MCP protocol."""
86140
url = f"{self.__base_url}/mcp/"
87141
params = {"name": tool_name, "arguments": arguments}
88142
result = await self._send_request(
89143
url=url, method="tools/call", params=params, headers=headers
90144
)
91-
return result
145+
all_content = result.get("content", result)
146+
content_str = "".join(
147+
content.get("text", "")
148+
for content in all_content
149+
if isinstance(content, dict)
150+
)
151+
return content_str or "null"
92152

93153
async def close(self):
94154
if self.__manage_session and not self.__session.closed:
@@ -109,7 +169,13 @@ async def _send_request(
109169
"params": params,
110170
"id": request_id,
111171
}
112-
async with self.__session.post(url, json=payload, headers=headers) as response:
172+
173+
if not method.startswith("notifications/"):
174+
payload["id"] = str(uuid.uuid4())
175+
176+
async with self.__session.post(
177+
url, json=payload, headers=req_headers
178+
) as response:
113179
if not response.ok:
114180
error_text = await response.text()
115181
raise RuntimeError(

0 commit comments

Comments
 (0)