diff --git a/src/strands/_identifier.py b/src/strands/_identifier.py new file mode 100644 index 000000000..e8b12635c --- /dev/null +++ b/src/strands/_identifier.py @@ -0,0 +1,30 @@ +"""Strands identifier utilities.""" + +import enum +import os + + +class Identifier(enum.Enum): + """Strands identifier types.""" + + AGENT = "agent" + SESSION = "session" + + +def validate(id_: str, type_: Identifier) -> str: + """Validate strands id. + + Args: + id_: Id to validate. + type_: Type of the identifier (e.g., session id, agent id, etc.) + + Returns: + Validated id. + + Raises: + ValueError: If id contains path separators. + """ + if os.path.basename(id_) != id_: + raise ValueError(f"{type_.value}_id={id_} | id cannot contain path separators") + + return id_ diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 43b5cbf8c..38e687af2 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -19,6 +19,7 @@ from opentelemetry import trace as trace_api from pydantic import BaseModel +from .. import _identifier from ..event_loop.event_loop import event_loop_cycle, run_tool from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( @@ -249,12 +250,15 @@ def __init__( Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. + + Raises: + ValueError: If agent id contains path separators. """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] self.system_prompt = system_prompt - self.agent_id = agent_id or _DEFAULT_AGENT_ID + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index fec2f0761..9df86e17a 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -7,6 +7,7 @@ import tempfile from typing import Any, Optional, cast +from .. import _identifier from ..types.exceptions import SessionException from ..types.session import Session, SessionAgent, SessionMessage from .repository_session_manager import RepositorySessionManager @@ -40,8 +41,9 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: """Initialize FileSession with filesystem storage. Args: - session_id: ID for the session - storage_dir: Directory for local filesystem storage (defaults to temp dir) + session_id: ID for the session. + ID is not allowed to contain path separators (e.g., a/b). + storage_dir: Directory for local filesystem storage (defaults to temp dir). **kwargs: Additional keyword arguments for future extensibility. """ self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions") @@ -50,12 +52,29 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: super().__init__(session_id=session_id, session_repository=self) def _get_session_path(self, session_id: str) -> str: - """Get session directory path.""" + """Get session directory path. + + Args: + session_id: ID for the session. + + Raises: + ValueError: If session id contains a path separator. + """ + session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}") def _get_agent_path(self, session_id: str, agent_id: str) -> str: - """Get agent directory path.""" + """Get agent directory path. + + Args: + session_id: ID for the session. + agent_id: ID for the agent. + + Raises: + ValueError: If session id or agent id contains a path separator. + """ session_path = self._get_session_path(session_id) + agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT) return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}") def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str: diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 0cc0a68c1..d15e6e3bd 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -8,6 +8,7 @@ from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError +from .. import _identifier from ..types.exceptions import SessionException from ..types.session import Session, SessionAgent, SessionMessage from .repository_session_manager import RepositorySessionManager @@ -51,6 +52,7 @@ def __init__( Args: session_id: ID for the session + ID is not allowed to contain path separators (e.g., a/b). bucket: S3 bucket name (required) prefix: S3 key prefix for storage organization boto_session: Optional boto3 session @@ -79,12 +81,29 @@ def __init__( super().__init__(session_id=session_id, session_repository=self) def _get_session_path(self, session_id: str) -> str: - """Get session S3 prefix.""" + """Get session S3 prefix. + + Args: + session_id: ID for the session. + + Raises: + ValueError: If session id contains a path separator. + """ + session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) return f"{self.prefix}/{SESSION_PREFIX}{session_id}/" def _get_agent_path(self, session_id: str, agent_id: str) -> str: - """Get agent S3 prefix.""" + """Get agent S3 prefix. + + Args: + session_id: ID for the session. + agent_id: ID for the agent. + + Raises: + ValueError: If session id or agent id contains a path separator. + """ session_path = self._get_session_path(session_id) + agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT) return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/" def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index fdce7c368..ca66ca2bf 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -250,6 +250,18 @@ def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_impo assert tru_tool_names == exp_tool_names +@pytest.mark.parametrize( + "agent_id", + [ + "a/../b", + "a/b", + ], +) +def test_agent__init__invalid_id(agent_id): + with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): + Agent(agent_id=agent_id) + + def test_agent__call__( mock_model, system_prompt, diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index f9fc3ba94..a89222b7e 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -53,310 +53,340 @@ def sample_message(): ) -class TestFileSessionManagerSessionOperations: - """Tests for session operations.""" - - def test_create_session(self, file_manager, sample_session): - """Test creating a session.""" - file_manager.create_session(sample_session) - - # Verify directory structure created - session_path = file_manager._get_session_path(sample_session.session_id) - assert os.path.exists(session_path) - - # Verify session file created - session_file = os.path.join(session_path, "session.json") - assert os.path.exists(session_file) - - # Verify content - with open(session_file, "r") as f: - data = json.load(f) - assert data["session_id"] == sample_session.session_id - assert data["session_type"] == sample_session.session_type - - def test_read_session(self, file_manager, sample_session): - """Test reading an existing session.""" - # Create session first - file_manager.create_session(sample_session) - - # Read it back - result = file_manager.read_session(sample_session.session_id) - - assert result.session_id == sample_session.session_id - assert result.session_type == sample_session.session_type - - def test_read_nonexistent_session(self, file_manager): - """Test reading a session that doesn't exist.""" - result = file_manager.read_session("nonexistent-session") - assert result is None - - def test_delete_session(self, file_manager, sample_session): - """Test deleting a session.""" - # Create session first - file_manager.create_session(sample_session) - session_path = file_manager._get_session_path(sample_session.session_id) - assert os.path.exists(session_path) - - # Delete session - file_manager.delete_session(sample_session.session_id) - - # Verify deletion - assert not os.path.exists(session_path) - - def test_delete_nonexistent_session(self, file_manager): - """Test deleting a session that doesn't exist.""" - # Should raise an error according to the implementation - with pytest.raises(SessionException, match="does not exist"): - file_manager.delete_session("nonexistent-session") - - -class TestFileSessionManagerAgentOperations: - """Tests for agent operations.""" - - def test_create_agent(self, file_manager, sample_session, sample_agent): - """Test creating an agent in a session.""" - # Create session first - file_manager.create_session(sample_session) - - # Create agent - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Verify directory structure - agent_path = file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id) - assert os.path.exists(agent_path) - - # Verify agent file - agent_file = os.path.join(agent_path, "agent.json") - assert os.path.exists(agent_file) - - # Verify content - with open(agent_file, "r") as f: - data = json.load(f) - assert data["agent_id"] == sample_agent.agent_id - assert data["state"] == sample_agent.state - - def test_read_agent(self, file_manager, sample_session, sample_agent): - """Test reading an agent from a session.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Read agent - result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) - - assert result.agent_id == sample_agent.agent_id - assert result.state == sample_agent.state - - def test_read_nonexistent_agent(self, file_manager, sample_session): - """Test reading an agent that doesn't exist.""" - result = file_manager.read_agent(sample_session.session_id, "nonexistent_agent") - assert result is None - - def test_update_agent(self, file_manager, sample_session, sample_agent): - """Test updating an agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Update agent - sample_agent.state = {"updated": "value"} +def test_create_session(file_manager, sample_session): + """Test creating a session.""" + file_manager.create_session(sample_session) + + # Verify directory structure created + session_path = file_manager._get_session_path(sample_session.session_id) + assert os.path.exists(session_path) + + # Verify session file created + session_file = os.path.join(session_path, "session.json") + assert os.path.exists(session_file) + + # Verify content + with open(session_file, "r") as f: + data = json.load(f) + assert data["session_id"] == sample_session.session_id + assert data["session_type"] == sample_session.session_type + + +def test_read_session(file_manager, sample_session): + """Test reading an existing session.""" + # Create session first + file_manager.create_session(sample_session) + + # Read it back + result = file_manager.read_session(sample_session.session_id) + + assert result.session_id == sample_session.session_id + assert result.session_type == sample_session.session_type + + +def test_read_nonexistent_session(file_manager): + """Test reading a session that doesn't exist.""" + result = file_manager.read_session("nonexistent-session") + assert result is None + + +def test_delete_session(file_manager, sample_session): + """Test deleting a session.""" + # Create session first + file_manager.create_session(sample_session) + session_path = file_manager._get_session_path(sample_session.session_id) + assert os.path.exists(session_path) + + # Delete session + file_manager.delete_session(sample_session.session_id) + + # Verify deletion + assert not os.path.exists(session_path) + + +def test_delete_nonexistent_session(file_manager): + """Test deleting a session that doesn't exist.""" + # Should raise an error according to the implementation + with pytest.raises(SessionException, match="does not exist"): + file_manager.delete_session("nonexistent-session") + + +def test_create_agent(file_manager, sample_session, sample_agent): + """Test creating an agent in a session.""" + # Create session first + file_manager.create_session(sample_session) + + # Create agent + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Verify directory structure + agent_path = file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id) + assert os.path.exists(agent_path) + + # Verify agent file + agent_file = os.path.join(agent_path, "agent.json") + assert os.path.exists(agent_file) + + # Verify content + with open(agent_file, "r") as f: + data = json.load(f) + assert data["agent_id"] == sample_agent.agent_id + assert data["state"] == sample_agent.state + + +def test_read_agent(file_manager, sample_session, sample_agent): + """Test reading an agent from a session.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Read agent + result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + + assert result.agent_id == sample_agent.agent_id + assert result.state == sample_agent.state + + +def test_read_nonexistent_agent(file_manager, sample_session): + """Test reading an agent that doesn't exist.""" + result = file_manager.read_agent(sample_session.session_id, "nonexistent_agent") + assert result is None + + +def test_update_agent(file_manager, sample_session, sample_agent): + """Test updating an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Update agent + sample_agent.state = {"updated": "value"} + file_manager.update_agent(sample_session.session_id, sample_agent) + + # Verify update + result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result.state == {"updated": "value"} + + +def test_update_nonexistent_agent(file_manager, sample_session, sample_agent): + """Test updating an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + + # Update agent + with pytest.raises(SessionException): file_manager.update_agent(sample_session.session_id, sample_agent) - # Verify update - result = file_manager.read_agent(sample_session.session_id, sample_agent.agent_id) - assert result.state == {"updated": "value"} - def test_update_nonexistent_agent(self, file_manager, sample_session, sample_agent): - """Test updating an agent.""" - # Create session and agent - file_manager.create_session(sample_session) +def test_create_message(file_manager, sample_session, sample_agent, sample_message): + """Test creating a message for an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) - # Update agent - with pytest.raises(SessionException): - file_manager.update_agent(sample_session.session_id, sample_agent) + # Create message + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + # Verify message file + message_path = file_manager._get_message_path( + sample_session.session_id, sample_agent.agent_id, sample_message.message_id + ) + assert os.path.exists(message_path) + # Verify content + with open(message_path, "r") as f: + data = json.load(f) + assert data["message_id"] == sample_message.message_id -class TestFileSessionManagerMessageOperations: - """Tests for message operations.""" - def test_create_message(self, file_manager, sample_session, sample_agent, sample_message): - """Test creating a message for an agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) +def test_read_message(file_manager, sample_session, sample_agent, sample_message): + """Test reading a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - # Create message - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) + # Create multiple messages when reading + sample_message.message_id = sample_message.message_id + 1 + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - # Verify message file - message_path = file_manager._get_message_path( - sample_session.session_id, sample_agent.agent_id, sample_message.message_id + # Read message + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + + assert result.message_id == sample_message.message_id + assert result.message["role"] == sample_message.message["role"] + assert result.message["content"] == sample_message.message["content"] + + +def test_read_messages_with_new_agent(file_manager, sample_session, sample_agent): + """Test reading a message with with a new agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + + assert result is None + + +def test_read_nonexistent_message(file_manager, sample_session, sample_agent): + """Test reading a message that doesnt exist.""" + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") + assert result is None + + +def test_list_messages_all(file_manager, sample_session, sample_agent): + """Test listing all messages for an agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + messages = [] + for i in range(5): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + message_id=i, ) - assert os.path.exists(message_path) - - # Verify content - with open(message_path, "r") as f: - data = json.load(f) - assert data["message_id"] == sample_message.message_id - - def test_read_message(self, file_manager, sample_session, sample_agent, sample_message): - """Test reading a message.""" - # Create session, agent, and message - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - - # Create multiple messages when reading - sample_message.message_id = sample_message.message_id + 1 - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - - # Read message - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) - - assert result.message_id == sample_message.message_id - assert result.message["role"] == sample_message.message["role"] - assert result.message["content"] == sample_message.message["content"] - - def test_read_messages_with_new_agent(self, file_manager, sample_session, sample_agent): - """Test reading a message with with a new agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") - - assert result is None - - def test_read_nonexistent_message(self, file_manager, sample_session, sample_agent): - """Test reading a message that doesnt exist.""" - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message") - assert result is None - - def test_list_messages_all(self, file_manager, sample_session, sample_agent): - """Test listing all messages for an agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Create multiple messages - messages = [] - for i in range(5): - message = SessionMessage( - message={ - "role": "user", - "content": [ContentBlock(text=f"Message {i}")], - }, - message_id=i, - ) - messages.append(message) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) - - # List all messages - result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) - - assert len(result) == 5 - - def test_list_messages_with_limit(self, file_manager, sample_session, sample_agent): - """Test listing messages with limit.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Create multiple messages - for i in range(10): - message = SessionMessage( - message={ - "role": "user", - "content": [ContentBlock(text=f"Message {i}")], - }, - message_id=i, - ) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) - - # List with limit - result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) - - assert len(result) == 3 - - def test_list_messages_with_offset(self, file_manager, sample_session, sample_agent): - """Test listing messages with offset.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - # Create multiple messages - for i in range(10): - message = SessionMessage( - message={ - "role": "user", - "content": [ContentBlock(text=f"Message {i}")], - }, - message_id=i, - ) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) - - # List with offset - result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) - - assert len(result) == 5 - - def test_list_messages_with_new_agent(self, file_manager, sample_session, sample_agent): - """Test listing messages with new agent.""" - # Create session and agent - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - - result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) - - assert len(result) == 0 - - def test_update_message(self, file_manager, sample_session, sample_agent, sample_message): - """Test updating a message.""" - # Create session, agent, and message - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) - file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - - # Update message - sample_message.message["content"] = [ContentBlock(text="Updated content")] - file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + messages.append(message) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) - # Verify update - result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) - assert result.message["content"][0]["text"] == "Updated content" + # List all messages + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) - def test_update_nonexistent_message(self, file_manager, sample_session, sample_agent, sample_message): - """Test updating a message.""" - # Create session, agent, and message - file_manager.create_session(sample_session) - file_manager.create_agent(sample_session.session_id, sample_agent) + assert len(result) == 5 + + +def test_list_messages_with_limit(file_manager, sample_session, sample_agent): + """Test listing messages with limit.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + message_id=i, + ) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with limit + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, limit=3) + + assert len(result) == 3 - # Update nonexistent message - with pytest.raises(SessionException): - file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) +def test_list_messages_with_offset(file_manager, sample_session, sample_agent): + """Test listing messages with offset.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) -class TestFileSessionManagerErrorHandling: - """Tests for error handling scenarios.""" + # Create multiple messages + for i in range(10): + message = SessionMessage( + message={ + "role": "user", + "content": [ContentBlock(text=f"Message {i}")], + }, + message_id=i, + ) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List with offset + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id, offset=5) + + assert len(result) == 5 + + +def test_list_messages_with_new_agent(file_manager, sample_session, sample_agent): + """Test listing messages with new agent.""" + # Create session and agent + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + result = file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 0 + + +def test_update_message(file_manager, sample_session, sample_agent, sample_message): + """Test updating a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + file_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message) - def test_corrupted_json_file(self, file_manager, temp_dir): - """Test handling of corrupted JSON files.""" - # Create a corrupted session file - session_path = os.path.join(temp_dir, "session_test") - os.makedirs(session_path, exist_ok=True) - session_file = os.path.join(session_path, "session.json") + # Update message + sample_message.message["content"] = [ContentBlock(text="Updated content")] + file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) - with open(session_file, "w") as f: - f.write("invalid json content") + # Verify update + result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id) + assert result.message["content"][0]["text"] == "Updated content" - # Should raise SessionException - with pytest.raises(SessionException, match="Invalid JSON"): - file_manager._read_file(session_file) - def test_permission_error_handling(self, file_manager): - """Test handling of permission errors.""" - with patch("builtins.open", side_effect=PermissionError("Access denied")): - session = Session(session_id="test", session_type=SessionType.AGENT) +def test_update_nonexistent_message(file_manager, sample_session, sample_agent, sample_message): + """Test updating a message.""" + # Create session, agent, and message + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) - with pytest.raises(SessionException): - file_manager.create_session(session) + # Update nonexistent message + with pytest.raises(SessionException): + file_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + +def test_corrupted_json_file(file_manager, temp_dir): + """Test handling of corrupted JSON files.""" + # Create a corrupted session file + session_path = os.path.join(temp_dir, "session_test") + os.makedirs(session_path, exist_ok=True) + session_file = os.path.join(session_path, "session.json") + + with open(session_file, "w") as f: + f.write("invalid json content") + + # Should raise SessionException + with pytest.raises(SessionException, match="Invalid JSON"): + file_manager._read_file(session_file) + + +def test_permission_error_handling(file_manager): + """Test handling of permission errors.""" + with patch("builtins.open", side_effect=PermissionError("Access denied")): + session = Session(session_id="test", session_type=SessionType.AGENT) + + with pytest.raises(SessionException): + file_manager.create_session(session) + + +@pytest.mark.parametrize( + "session_id", + [ + "a/../b", + "a/b", + ], +) +def test__get_session_path_invalid_session_id(session_id, file_manager): + with pytest.raises(ValueError, match=f"session_id={session_id} | id cannot contain path separators"): + file_manager._get_session_path(session_id) + + +@pytest.mark.parametrize( + "agent_id", + [ + "a/../b", + "a/b", + ], +) +def test__get_agent_path_invalid_agent_id(agent_id, file_manager): + with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): + file_manager._get_agent_path("session1", agent_id) diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index fadd0db4b..71bff3050 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -332,3 +332,27 @@ def test_update_nonexistent_message(s3_manager, sample_session, sample_agent, sa # Update message with pytest.raises(SessionException): s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) + + +@pytest.mark.parametrize( + "session_id", + [ + "a/../b", + "a/b", + ], +) +def test__get_session_path_invalid_session_id(session_id, s3_manager): + with pytest.raises(ValueError, match=f"session_id={session_id} | id cannot contain path separators"): + s3_manager._get_session_path(session_id) + + +@pytest.mark.parametrize( + "agent_id", + [ + "a/../b", + "a/b", + ], +) +def test__get_agent_path_invalid_agent_id(agent_id, s3_manager): + with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): + s3_manager._get_agent_path("session1", agent_id) diff --git a/tests/strands/test_identifier.py b/tests/strands/test_identifier.py new file mode 100644 index 000000000..df673baa8 --- /dev/null +++ b/tests/strands/test_identifier.py @@ -0,0 +1,17 @@ +import pytest + +from strands import _identifier + + +@pytest.mark.parametrize("type_", list(_identifier.Identifier)) +def test_validate(type_): + tru_id = _identifier.validate("abc", type_) + exp_id = "abc" + assert tru_id == exp_id + + +@pytest.mark.parametrize("type_", list(_identifier.Identifier)) +def test_validate_invalid(type_): + id_ = "a/../b" + with pytest.raises(ValueError, match=f"{type_.value}={id_} | id cannot contain path separators"): + _identifier.validate(id_, type_) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 246879da7..e490c7bb0 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1064,7 +1064,7 @@ async def _run_context_injection_test(context_tool: AgentTool): "content": [ {"text": "Tool 'context_tool' (ID: test-id)"}, {"text": "injected agent 'test_agent' processed: some_message"}, - {"text": "context agent 'test_agent'"} + {"text": "context agent 'test_agent'"}, ], "toolUseId": "test-id", } @@ -1151,7 +1151,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> dict: assert len(tool_results) == 1 tool_result = tool_results[0] - + # Should get a validation error because tool_context is required but not provided assert tool_result["status"] == "error" assert "tool_context" in tool_result["content"][0]["text"].lower() @@ -1173,10 +1173,7 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: tool_use={ "toolUseId": "test-id-2", "name": "context_tool", - "input": { - "message": "some_message", - "tool_context": "my_custom_context_string" - }, + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, }, invocation_state={ "agent": Agent(name="test_agent"),