@@ -54,6 +54,8 @@ def my_tool(param1: str, param2: int = 42) -> dict:
5454 TypeVar ,
5555 Union ,
5656 cast ,
57+ get_args ,
58+ get_origin ,
5759 get_type_hints ,
5860 overload ,
5961)
@@ -103,6 +105,35 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
103105 doc_str = inspect .getdoc (func ) or ""
104106 self .doc = docstring_parser .parse (doc_str )
105107
108+ def _contains_tool_context (tp : Any ) -> bool :
109+ """Return True if the annotation `tp` (possibly Union/Optional) includes ToolContext."""
110+ if tp is None :
111+ return False
112+ origin = get_origin (tp )
113+ if origin is Union :
114+ return any (_contains_tool_context (a ) for a in get_args (tp ))
115+ # Handle direct ToolContext type
116+ return tp is ToolContext
117+
118+ for param in self .signature .parameters .values ():
119+ # Prefer resolved type hints (handles forward refs); fall back to annotation
120+ ann = self .type_hints .get (param .name , param .annotation )
121+ if ann is inspect ._empty :
122+ continue
123+
124+ if _contains_tool_context (ann ):
125+ # If decorator didn't opt-in to context injection, complain
126+ if self ._context_param is None :
127+ raise TypeError (
128+ f"Parameter '{ param .name } ' is of type 'ToolContext' but '@tool(context=True)' is missing."
129+ )
130+ # If decorator specified a different param name, complain
131+ if param .name != self ._context_param :
132+ raise TypeError (
133+ f"Parameter '{ param .name } ' is of type 'ToolContext' but has the wrong name. "
134+ f"It should be named '{ self ._context_param } '."
135+ )
136+
106137 # Get parameter descriptions from parsed docstring
107138 self .param_descriptions = {
108139 param .arg_name : param .description or f"Parameter { param .arg_name } " for param in self .doc .params
0 commit comments