1818_singleton_lock = Lock ()
1919
2020
21+ class MuxMatchingError (Exception ):
22+ """An exception for muxing matching errors."""
23+
24+ pass
25+
26+
2127async def get_muxing_rules_registry ():
2228 """Returns a singleton instance of the muxing rules registry."""
2329
@@ -48,9 +54,9 @@ def __init__(
4854class MuxingRuleMatcher (ABC ):
4955 """Base class for matching muxing rules."""
5056
51- def __init__ (self , route : ModelRoute , matcher_blob : str ):
57+ def __init__ (self , route : ModelRoute , mux_rule : mux_models . MuxRule ):
5258 self ._route = route
53- self ._matcher_blob = matcher_blob
59+ self ._mux_rule = mux_rule
5460
5561 @abstractmethod
5662 def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
@@ -67,18 +73,20 @@ class MuxingMatcherFactory:
6773 """Factory for creating muxing matchers."""
6874
6975 @staticmethod
70- def create (mux_rule : db_models .MuxRule , route : ModelRoute ) -> MuxingRuleMatcher :
76+ def create (db_mux_rule : db_models .MuxRule , route : ModelRoute ) -> MuxingRuleMatcher :
7177 """Create a muxing matcher for the given endpoint and model."""
7278
7379 factory : Dict [mux_models .MuxMatcherType , MuxingRuleMatcher ] = {
7480 mux_models .MuxMatcherType .catch_all : CatchAllMuxingRuleMatcher ,
7581 mux_models .MuxMatcherType .filename_match : FileMuxingRuleMatcher ,
76- mux_models .MuxMatcherType .request_type_match : RequestTypeMuxingRuleMatcher ,
82+ mux_models .MuxMatcherType .fim_filename : RequestTypeAndFileMuxingRuleMatcher ,
83+ mux_models .MuxMatcherType .chat_filename : RequestTypeAndFileMuxingRuleMatcher ,
7784 }
7885
7986 try :
8087 # Initialize the MuxingRuleMatcher
81- return factory [mux_rule .matcher_type ](route , mux_rule .matcher_blob )
88+ mux_rule = mux_models .MuxRule .from_db_mux_rule (db_mux_rule )
89+ return factory [mux_rule .matcher_type ](route , mux_rule )
8290 except KeyError :
8391 raise ValueError (f"Unknown matcher type: { mux_rule .matcher_type } " )
8492
@@ -103,47 +111,63 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) ->
103111 return body_extractor .extract_unique_filenames (data )
104112 except BodyCodeSnippetExtractorError as e :
105113 logger .error (f"Error extracting filenames from request: { e } " )
106- return set ()
114+ raise MuxMatchingError ("Error extracting filenames from request" )
115+
116+ def _is_matcher_in_filenames (self , detected_client : ClientType , data : dict ) -> bool :
117+ """
118+ Check if the matcher is in the request filenames.
119+ """
120+ # Empty matcher_blob means we match everything
121+ if not self ._mux_rule .matcher :
122+ return True
123+ filenames_to_match = self ._extract_request_filenames (detected_client , data )
124+ # _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames
125+ # match the rule.
126+ is_filename_match = any (
127+ self ._mux_rule .matcher == filename or filename .endswith (self ._mux_rule .matcher )
128+ for filename in filenames_to_match
129+ )
130+ return is_filename_match
107131
108132 def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
109133 """
110- Retun True if there is a filename in the request that matches the matcher_blob.
111- The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
134+ Return True if the matcher is in one of the request filenames.
112135 """
113- # If there is no matcher_blob, we don't match
114- if not self ._matcher_blob :
115- return False
116- filenames_to_match = self ._extract_request_filenames (
136+ is_rule_matched = self ._is_matcher_in_filenames (
117137 thing_to_match .client_type , thing_to_match .body
118138 )
119- is_filename_match = any (self ._matcher_blob in filename for filename in filenames_to_match )
120- if is_filename_match :
121- logger .info (
122- "Filename rule matched" , filenames = filenames_to_match , matcher = self ._matcher_blob
123- )
124- return is_filename_match
139+ if is_rule_matched :
140+ logger .info ("Filename rule matched" , matcher = self ._mux_rule .matcher )
141+ return is_rule_matched
125142
126143
127- class RequestTypeMuxingRuleMatcher (MuxingRuleMatcher ):
128- """A catch all muxing rule matcher."""
144+ class RequestTypeAndFileMuxingRuleMatcher (FileMuxingRuleMatcher ):
145+ """A request type and file muxing rule matcher."""
146+
147+ def _is_request_type_match (self , is_fim_request : bool ) -> bool :
148+ """
149+ Check if the request type matches the MuxMatcherType.
150+ """
151+ incoming_request_type = "fim_filename" if is_fim_request else "chat_filename"
152+ if incoming_request_type == self ._mux_rule .matcher_type :
153+ return True
154+ return False
129155
130156 def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
131157 """
132- Return True if the request type matches the matcher_blob.
133- The matcher_blob is either "fim" or "chat" .
158+ Return True if the matcher is in one of the request filenames and
159+ if the request type matches the MuxMatcherType .
134160 """
135- # If there is no matcher_blob, we don't match
136- if not self ._matcher_blob :
137- return False
138- incoming_request_type = "fim" if thing_to_match .is_fim_request else "chat"
139- is_request_type_match = self ._matcher_blob == incoming_request_type
140- if is_request_type_match :
161+ is_rule_matched = self ._is_matcher_in_filenames (
162+ thing_to_match .client_type , thing_to_match .body
163+ ) and self ._is_request_type_match (thing_to_match .is_fim_request )
164+ if is_rule_matched :
141165 logger .info (
142- "Request type rule matched" ,
143- matcher = self ._matcher_blob ,
144- request_type = incoming_request_type ,
166+ "Request type and rule matched" ,
167+ matcher = self ._mux_rule . matcher ,
168+ is_fim_request = thing_to_match . is_fim_request ,
145169 )
146- return is_request_type_match
170+ return is_rule_matched
147171
148172
149173class MuxingRulesinWorkspaces :
0 commit comments