Skip to content

Commit 0ea2ce7

Browse files
committed
feat: do not redact input on output guardrail intervention
Change the behavior so that the input is redacted only when an input guardrail intervened, and the output is redacted only when an output guardrail intervened.
1 parent 8a89d91 commit 0ea2ce7

File tree

3 files changed

+29
-23
lines changed

3 files changed

+29
-23
lines changed

src/strands/models/bedrock.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import os
1010
import warnings
11-
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast
11+
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Tuple, Type, TypeVar, Union, cast
1212

1313
import boto3
1414
from botocore.config import Config as BotocoreConfig
@@ -518,7 +518,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
518518

519519
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
520520

521-
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
521+
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> Tuple[bool, bool]:
522522
"""Check if guardrail data contains any blocked policies.
523523
524524
Args:
@@ -530,25 +530,27 @@ def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
530530
input_assessment = guardrail_data.get("inputAssessment", {})
531531
output_assessments = guardrail_data.get("outputAssessments", {})
532532

533+
blocked_input, blocked_output = False, False
534+
533535
# Check input assessments
534536
if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()):
535-
return True
537+
blocked_input = True
536538

537539
# Check output assessments
538540
if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()):
539-
return True
541+
blocked_output = True
540542

541-
return False
543+
return blocked_input, blocked_output
542544

543-
def _generate_redaction_events(self) -> list[StreamEvent]:
545+
def _generate_redaction_events(self, redact_input: bool, redact_output: bool) -> list[StreamEvent]:
544546
"""Generate redaction events based on configuration.
545547
546548
Returns:
547549
List of redaction events to yield.
548550
"""
549551
events: list[StreamEvent] = []
550552

551-
if self.config.get("guardrail_redact_input", True):
553+
if redact_input and self.config.get("guardrail_redact_input", True):
552554
logger.debug("Redacting user input due to guardrail.")
553555
events.append(
554556
{
@@ -560,7 +562,7 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
560562
}
561563
)
562564

563-
if self.config.get("guardrail_redact_output", False):
565+
if redact_output and self.config.get("guardrail_redact_output", False):
564566
logger.debug("Redacting assistant output due to guardrail.")
565567
events.append(
566568
{
@@ -669,9 +671,9 @@ def _stream(
669671
and "guardrail" in chunk["metadata"]["trace"]
670672
):
671673
guardrail_data = chunk["metadata"]["trace"]["guardrail"]
672-
if self._has_blocked_guardrail(guardrail_data):
673-
for event in self._generate_redaction_events():
674-
callback(event)
674+
blocked_input, blocked_output = self._has_blocked_guardrail(guardrail_data)
675+
for event in self._generate_redaction_events(blocked_input, blocked_output):
676+
callback(event)
675677

676678
# Track if we see tool use events
677679
if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"):
@@ -697,12 +699,10 @@ def _stream(
697699
for event in self._convert_non_streaming_to_streaming(response):
698700
callback(event)
699701

700-
if (
701-
"trace" in response
702-
and "guardrail" in response["trace"]
703-
and self._has_blocked_guardrail(response["trace"]["guardrail"])
704-
):
705-
for event in self._generate_redaction_events():
702+
if "trace" in response and "guardrail" in response["trace"]:
703+
guardrail_data = response["trace"]["guardrail"]
704+
blocked_input, blocked_output = self._has_blocked_guardrail(guardrail_data)
705+
for event in self._generate_redaction_events(blocked_input, blocked_output):
706706
callback(event)
707707

708708
except ClientError as e:

tests/strands/models/test_bedrock.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ async def test_stream_stream_output_guardrails(
687687

688688

689689
@pytest.mark.asyncio
690-
async def test_stream_output_guardrails_redacts_input_and_output(
690+
async def test_stream_output_guardrails_redacts_output(
691691
bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist
692692
):
693693
model.update_config(guardrail_redact_output=True)
@@ -735,7 +735,6 @@ async def test_stream_output_guardrails_redacts_input_and_output(
735735

736736
tru_chunks = await alist(response)
737737
exp_chunks = [
738-
{"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
739738
{"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}},
740739
metadata_event,
741740
]
@@ -1070,7 +1069,10 @@ async def test_stream_input_guardrails(bedrock_client, alist, messages):
10701069

10711070
@pytest.mark.asyncio
10721071
async def test_stream_output_guardrails(bedrock_client, alist, messages):
1073-
"""Test stream method with streaming=False."""
1072+
"""Test stream method with streaming=False.
1073+
1074+
Output guardrail should not redact the input.
1075+
"""
10741076
bedrock_client.converse.return_value = {
10751077
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
10761078
"trace": {
@@ -1113,7 +1115,6 @@ async def test_stream_output_guardrails(bedrock_client, alist, messages):
11131115
}
11141116
}
11151117
},
1116-
{"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
11171118
]
11181119
assert tru_events == exp_events
11191120

@@ -1122,7 +1123,7 @@ async def test_stream_output_guardrails(bedrock_client, alist, messages):
11221123

11231124

11241125
@pytest.mark.asyncio
1125-
async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, messages):
1126+
async def test_stream_output_guardrails_does_not_redact_input(bedrock_client, alist, messages):
11261127
"""Test stream method with streaming=False."""
11271128
bedrock_client.converse.return_value = {
11281129
"output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
@@ -1166,7 +1167,6 @@ async def test_stream_output_guardrails_redacts_output(bedrock_client, alist, me
11661167
}
11671168
}
11681169
},
1169-
{"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
11701170
]
11711171
assert tru_events == exp_events
11721172

tests_integ/test_bedrock_guardrails.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def test_guardrail_input_intervention(boto_session, bedrock_guardrail):
105105
guardrail_id=bedrock_guardrail,
106106
guardrail_version="DRAFT",
107107
boto_session=boto_session,
108+
guardrail_redact_input_message="Redacted.",
108109
)
109110

110111
agent = Agent(model=bedrock_model, system_prompt="You are a helpful assistant.", callback_handler=None)
@@ -116,6 +117,7 @@ def test_guardrail_input_intervention(boto_session, bedrock_guardrail):
116117
assert str(response1).strip() == BLOCKED_INPUT
117118
assert response2.stop_reason != "guardrail_intervened"
118119
assert str(response2).strip() != BLOCKED_INPUT
120+
assert agent.messages[0]["content"][0]["text"] == "Redacted."
119121

120122

121123
@pytest.mark.parametrize("processing_mode", ["sync", "async"])
@@ -193,6 +195,10 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi
193195
assert REDACT_MESSAGE in str(response1)
194196
assert response2.stop_reason != "guardrail_intervened"
195197
assert REDACT_MESSAGE not in str(response2)
198+
# Input is not redacted because this is an output intervention
199+
assert agent.messages[0]["content"][0]["text"] != REDACT_MESSAGE
200+
# Output correctly redacted
201+
assert agent.messages[1]["content"][0]["text"] == REDACT_MESSAGE
196202
else:
197203
cactus_returned_in_response1_blocked_by_input_guardrail = BLOCKED_INPUT in str(response2)
198204
cactus_blocked_in_response1_allows_next_response = (

0 commit comments

Comments
 (0)