Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/strands/_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Strands identifier utilities."""

import enum
import os


class Identifier(enum.Enum):
"""Strands identifier types."""

AGENT = "agent"
SESSION = "session"


def validate(id_: str, type_: Identifier) -> str:
"""Validate strands id.
Args:
id_: Id to validate.
type_: Type of the identifier (e.g., session id, agent id, etc.)
Returns:
Validated id.
Raises:
ValueError: If id contains path separators.
"""
if os.path.basename(id_) != id_:
raise ValueError(f"{type_.value}_id={id_} | id cannot contain path separators")

return id_
6 changes: 5 additions & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from opentelemetry import trace as trace_api
from pydantic import BaseModel

from .. import _identifier
from ..event_loop.event_loop import event_loop_cycle, run_tool
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..hooks import (
Expand Down Expand Up @@ -249,12 +250,15 @@ def __init__(
Defaults to None.
session_manager: Manager for handling agent sessions including conversation history and state.
If provided, enables session-based persistence and state management.

Raises:
ValueError: If agent id contains path separators.
"""
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
self.messages = messages if messages is not None else []

self.system_prompt = system_prompt
self.agent_id = agent_id or _DEFAULT_AGENT_ID
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
self.name = name or _DEFAULT_AGENT_NAME
self.description = description

Expand Down
27 changes: 23 additions & 4 deletions src/strands/session/file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile
from typing import Any, Optional, cast

from .. import _identifier
from ..types.exceptions import SessionException
from ..types.session import Session, SessionAgent, SessionMessage
from .repository_session_manager import RepositorySessionManager
Expand Down Expand Up @@ -40,8 +41,9 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs:
"""Initialize FileSession with filesystem storage.

Args:
session_id: ID for the session
storage_dir: Directory for local filesystem storage (defaults to temp dir)
session_id: ID for the session.
ID is not allowed to contain path separators (e.g., a/b).
storage_dir: Directory for local filesystem storage (defaults to temp dir).
**kwargs: Additional keyword arguments for future extensibility.
"""
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
Expand All @@ -50,12 +52,29 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs:
super().__init__(session_id=session_id, session_repository=self)

def _get_session_path(self, session_id: str) -> str:
"""Get session directory path."""
"""Get session directory path.

Args:
session_id: ID for the session.

Raises:
ValueError: If session id contains a path separator.
"""
session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION)
return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}")

def _get_agent_path(self, session_id: str, agent_id: str) -> str:
"""Get agent directory path."""
"""Get agent directory path.

Args:
session_id: ID for the session.
agent_id: ID for the agent.

Raises:
ValueError: If session id or agent id contains a path separator.
"""
session_path = self._get_session_path(session_id)
agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT)
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")

def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
Expand Down
23 changes: 21 additions & 2 deletions src/strands/session/s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from botocore.config import Config as BotocoreConfig
from botocore.exceptions import ClientError

from .. import _identifier
from ..types.exceptions import SessionException
from ..types.session import Session, SessionAgent, SessionMessage
from .repository_session_manager import RepositorySessionManager
Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(

Args:
session_id: ID for the session
ID is not allowed to contain path separators (e.g., a/b).
bucket: S3 bucket name (required)
prefix: S3 key prefix for storage organization
boto_session: Optional boto3 session
Expand Down Expand Up @@ -79,12 +81,29 @@ def __init__(
super().__init__(session_id=session_id, session_repository=self)

def _get_session_path(self, session_id: str) -> str:
"""Get session S3 prefix."""
"""Get session S3 prefix.

Args:
session_id: ID for the session.

Raises:
ValueError: If session id contains a path separator.
"""
session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION)
return f"{self.prefix}/{SESSION_PREFIX}{session_id}/"

def _get_agent_path(self, session_id: str, agent_id: str) -> str:
"""Get agent S3 prefix."""
"""Get agent S3 prefix.

Args:
session_id: ID for the session.
agent_id: ID for the agent.

Raises:
ValueError: If session id or agent id contains a path separator.
"""
session_path = self._get_session_path(session_id)
agent_id = _identifier.validate(agent_id, _identifier.Identifier.AGENT)
return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/"

def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
Expand Down
12 changes: 12 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,18 @@ def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_impo
assert tru_tool_names == exp_tool_names


@pytest.mark.parametrize(
"agent_id",
[
"a/../b",
"a/b",
],
)
def test_agent__init__invalid_id(agent_id):
with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"):
Agent(agent_id=agent_id)


def test_agent__call__(
mock_model,
system_prompt,
Expand Down
Loading
Loading