diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/transport_base.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/transport_base.py index d6e2ca67c..ce5708136 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/transport_base.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/transport_base.py @@ -57,7 +57,24 @@ def base_url(self) -> str: return self._mcp_base_url def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: - """Converts a raw MCP tool dictionary into the Toolbox ToolSchema.""" + """ + Safely converts the raw tool dictionary from the server into a ToolSchema object, + robustly handling optional authentication metadata. + """ + param_auth = None + invoke_auth = [] + + if "_meta" in tool_data and isinstance(tool_data["_meta"], dict): + meta = tool_data["_meta"] + if "toolbox/authParam" in meta and isinstance( + meta["toolbox/authParam"], dict + ): + param_auth = meta["toolbox/authParam"] + if "toolbox/authInvoke" in meta and isinstance( + meta["toolbox/authInvoke"], list + ): + invoke_auth = meta["toolbox/authInvoke"] + parameters = [] input_schema = tool_data.get("inputSchema", {}) properties = input_schema.get("properties", {}) @@ -71,6 +88,11 @@ def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: ) else: additional_props = True + + if param_auth and name in param_auth: + auth_sources = param_auth[name] + else: + auth_sources = None parameters.append( ParameterSchema( name=name, @@ -78,11 +100,14 @@ def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: description=schema.get("description", ""), required=name in required, additionalProperties=additional_props, + authSources=auth_sources, ) ) return ToolSchema( - description=tool_data.get("description") or "", parameters=parameters + description=tool_data.get("description") or "", + parameters=parameters, + authRequired=invoke_auth, ) async def close(self): diff --git a/packages/toolbox-core/tests/mcp_transport/test_base.py b/packages/toolbox-core/tests/mcp_transport/test_base.py index fd64ef74a..1e1c8d036 100644 --- a/packages/toolbox-core/tests/mcp_transport/test_base.py +++ b/packages/toolbox-core/tests/mcp_transport/test_base.py @@ -162,6 +162,34 @@ def test_convert_tool_schema_complex_types(self, transport): assert p_obj.type == "object" assert p_obj.additionalProperties.type == "integer" + def test_convert_tool_schema_with_auth_metadata(self, transport): + """Test converting tool schema with auth metadata fields.""" + raw_tool = { + "name": "auth_tool", + "description": "Tool with auth params", + "inputSchema": { + "type": "object", + "properties": { + "apiKey": {"type": "string"}, + }, + }, + "_meta": { + "toolbox/authParam": {"apiKey": ["my-auth-source"]}, + "toolbox/authInvoke": ["my-auth-invoke"], + }, + } + + schema = transport._convert_tool_schema(raw_tool) + + assert isinstance(schema, ToolSchema) + + # Check that authRequired (from toolbox/authInvoke) was populated + assert schema.authRequired == ["my-auth-invoke"] + + # Check that authSources (from toolbox/authParam) was populated on the parameter + p_api_key = next(p for p in schema.parameters if p.name == "apiKey") + assert p_api_key.authSources == ["my-auth-source"] + @pytest.mark.asyncio async def test_close_managed_session(self, mocker): mock_close = mocker.patch("aiohttp.ClientSession.close", new_callable=AsyncMock) diff --git a/packages/toolbox-core/tests/test_e2e_mcp.py b/packages/toolbox-core/tests/test_e2e_mcp.py index 821eb8f6a..2f8d7193d 100644 --- a/packages/toolbox-core/tests/test_e2e_mcp.py +++ b/packages/toolbox-core/tests/test_e2e_mcp.py @@ -139,6 +139,99 @@ async def test_bind_params_callable( assert "row4" not in response +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestAuth: + async def test_run_tool_unauth_with_auth( + self, toolbox: ToolboxClient, auth_token2: str + ): + """Tests running a tool that doesn't require auth, with auth provided.""" + + with pytest.raises( + ValueError, + match=rf"Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth", + ): + await toolbox.load_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) + + async def test_run_tool_no_auth(self, toolbox: ToolboxClient): + """Tests running a tool requiring auth without providing auth.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool(id="2") + + async def test_run_tool_wrong_auth(self, toolbox: ToolboxClient, auth_token2: str): + """Tests running a tool with incorrect auth. The tool + requires a different authentication than the one provided.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token2}) + with pytest.raises( + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", + ): + await auth_tool(id="2") + + async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token1}) + response = await auth_tool(id="2") + assert "row2" in response + + @pytest.mark.asyncio + async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth using an async token getter.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + + async def get_token_asynchronously(): + return auth_token1 + + auth_tool = tool.add_auth_token_getters( + {"my-test-auth": get_token_asynchronously} + ) + response = await auth_tool(id="2") + assert "row2" in response + + async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient): + """Tests running a tool with a param requiring auth, without auth.""" + tool = await toolbox.load_tool("get-row-by-email-auth") + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool() + + async def test_run_tool_param_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = await toolbox.load_tool( + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + response = await tool() + assert "row4" in response + assert "row5" in response + assert "row6" in response + + async def test_run_tool_param_auth_no_field( + self, toolbox: ToolboxClient, auth_token1: str + ): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = await toolbox.load_tool( + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + with pytest.raises( + Exception, + match="no field named row_data in claims", + ): + await tool() + + @pytest.mark.asyncio @pytest.mark.usefixtures("toolbox_server") class TestOptionalParams: