88import logging
99import os
1010import warnings
11- from typing import Any , AsyncGenerator , Callable , Iterable , Literal , Optional , Type , TypeVar , Union , cast
11+ from typing import Any , AsyncGenerator , Callable , Iterable , Literal , Optional , Tuple , Type , TypeVar , Union , cast
1212
1313import boto3
1414from botocore .config import Config as BotocoreConfig
@@ -518,7 +518,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
518518
519519 raise TypeError (f"content_type=<{ next (iter (content ))} > | unsupported type" )
520520
521- def _has_blocked_guardrail (self , guardrail_data : dict [str , Any ]) -> bool :
521+ def _has_blocked_guardrail (self , guardrail_data : dict [str , Any ]) -> Tuple [ bool , bool ] :
522522 """Check if guardrail data contains any blocked policies.
523523
524524 Args:
@@ -530,25 +530,27 @@ def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
530530 input_assessment = guardrail_data .get ("inputAssessment" , {})
531531 output_assessments = guardrail_data .get ("outputAssessments" , {})
532532
533+ blocked_input , blocked_output = False , False
534+
533535 # Check input assessments
534536 if any (self ._find_detected_and_blocked_policy (assessment ) for assessment in input_assessment .values ()):
535- return True
537+ blocked_input = True
536538
537539 # Check output assessments
538540 if any (self ._find_detected_and_blocked_policy (assessment ) for assessment in output_assessments .values ()):
539- return True
541+ blocked_output = True
540542
541- return False
543+ return blocked_input , blocked_output
542544
543- def _generate_redaction_events (self ) -> list [StreamEvent ]:
545+ def _generate_redaction_events (self , redact_input : bool , redact_output : bool ) -> list [StreamEvent ]:
544546 """Generate redaction events based on configuration.
545547
546548 Returns:
547549 List of redaction events to yield.
548550 """
549551 events : list [StreamEvent ] = []
550552
551- if self .config .get ("guardrail_redact_input" , True ):
553+ if redact_input and self .config .get ("guardrail_redact_input" , True ):
552554 logger .debug ("Redacting user input due to guardrail." )
553555 events .append (
554556 {
@@ -560,7 +562,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
560562 }
561563 )
562564
563- if self .config .get ("guardrail_redact_output" , False ):
565+ if redact_output and self .config .get ("guardrail_redact_output" , False ):
564566 logger .debug ("Redacting assistant output due to guardrail." )
565567 events .append (
566568 {
@@ -669,9 +671,9 @@ def _stream(
669671 and "guardrail" in chunk ["metadata" ]["trace" ]
670672 ):
671673 guardrail_data = chunk ["metadata" ]["trace" ]["guardrail" ]
672- if self ._has_blocked_guardrail (guardrail_data ):
673- for event in self ._generate_redaction_events ():
674- callback (event )
674+ blocked_input , blocked_output = self ._has_blocked_guardrail (guardrail_data )
675+ for event in self ._generate_redaction_events (blocked_input , blocked_output ):
676+ callback (event )
675677
676678 # Track if we see tool use events
677679 if "contentBlockStart" in chunk and chunk ["contentBlockStart" ].get ("start" , {}).get ("toolUse" ):
@@ -697,12 +699,10 @@ def _stream(
697699 for event in self ._convert_non_streaming_to_streaming (response ):
698700 callback (event )
699701
700- if (
701- "trace" in response
702- and "guardrail" in response ["trace" ]
703- and self ._has_blocked_guardrail (response ["trace" ]["guardrail" ])
704- ):
705- for event in self ._generate_redaction_events ():
702+ if "trace" in response and "guardrail" in response ["trace" ]:
703+ guardrail_data = response ["trace" ]["guardrail" ]
704+ blocked_input , blocked_output = self ._has_blocked_guardrail (guardrail_data )
705+ for event in self ._generate_redaction_events (blocked_input , blocked_output ):
706706 callback (event )
707707
708708 except ClientError as e :
0 commit comments