Skip to content

Commit 859a5af

Browse files
committed
session manager - prevent file path injection
1 parent 606f657 commit 859a5af

File tree

11 files changed

+491
-300
lines changed

11 files changed

+491
-300
lines changed

src/strands/agent/agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ..types.exceptions import ContextWindowOverflowException
4141
from ..types.tools import ToolResult, ToolUse
4242
from ..types.traces import AttributeValue
43+
from . import identifier
4344
from .agent_result import AgentResult
4445
from .conversation_manager import (
4546
ConversationManager,
@@ -249,12 +250,15 @@ def __init__(
249250
Defaults to None.
250251
session_manager: Manager for handling agent sessions including conversation history and state.
251252
If provided, enables session-based persistence and state management.
253+
254+
Raises:
255+
ValueError: If agent id contains path separators.
252256
"""
253257
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
254258
self.messages = messages if messages is not None else []
255259

256260
self.system_prompt = system_prompt
257-
self.agent_id = agent_id or _DEFAULT_AGENT_ID
261+
self.agent_id = identifier.validate(agent_id or _DEFAULT_AGENT_ID)
258262
self.name = name or _DEFAULT_AGENT_NAME
259263
self.description = description
260264

src/strands/agent/identifier.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Agent identifier utilities."""
2+
3+
import os
4+
5+
6+
def validate(agent_id: str) -> str:
7+
"""Validate agent id.
8+
9+
Args:
10+
agent_id: Id to validate.
11+
12+
Returns:
13+
Validated id.
14+
15+
Raises:
16+
ValueError: If id contains path separators.
17+
"""
18+
if os.path.basename(agent_id) != agent_id:
19+
raise ValueError(f"agent_id={agent_id} | agent id cannot contain path separators")
20+
21+
return agent_id

src/strands/session/file_session_manager.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import tempfile
88
from typing import Any, Optional, cast
99

10+
from ..agent.identifier import validate as validate_agent_id
1011
from ..types.exceptions import SessionException
1112
from ..types.session import Session, SessionAgent, SessionMessage
13+
from .identifier import validate as validate_session_id
1214
from .repository_session_manager import RepositorySessionManager
1315
from .session_repository import SessionRepository
1416

@@ -40,8 +42,9 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs:
4042
"""Initialize FileSession with filesystem storage.
4143
4244
Args:
43-
session_id: ID for the session
44-
storage_dir: Directory for local filesystem storage (defaults to temp dir)
45+
session_id: ID for the session.
46+
ID is not allowed to contain path separators (e.g., a/b).
47+
storage_dir: Directory for local filesystem storage (defaults to temp dir).
4548
**kwargs: Additional keyword arguments for future extensibility.
4649
"""
4750
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
@@ -50,12 +53,29 @@ def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs:
5053
super().__init__(session_id=session_id, session_repository=self)
5154

5255
def _get_session_path(self, session_id: str) -> str:
53-
"""Get session directory path."""
56+
"""Get session directory path.
57+
58+
Args:
59+
session_id: ID for the session.
60+
61+
Raises:
62+
ValueError: If session id contains a path separator.
63+
"""
64+
session_id = validate_session_id(session_id)
5465
return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}")
5566

5667
def _get_agent_path(self, session_id: str, agent_id: str) -> str:
57-
"""Get agent directory path."""
68+
"""Get agent directory path.
69+
70+
Args:
71+
session_id: ID for the session.
72+
agent_id: ID for the agent.
73+
74+
Raises:
75+
ValueError: If session id or agent id contains a path separator.
76+
"""
5877
session_path = self._get_session_path(session_id)
78+
agent_id = validate_agent_id(agent_id)
5979
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")
6080

6181
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:

src/strands/session/identifier.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Session identifier utilities."""
2+
3+
import os
4+
5+
6+
def validate(session_id: str) -> str:
7+
"""Validate session id.
8+
9+
Args:
10+
session_id: Id to validate.
11+
12+
Returns:
13+
Validated id.
14+
15+
Raises:
16+
ValueError: If id contains path separators.
17+
"""
18+
if os.path.basename(session_id) != session_id:
19+
raise ValueError(f"session_id={session_id} | session id cannot contain path separators")
20+
21+
return session_id

src/strands/session/s3_session_manager.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from botocore.config import Config as BotocoreConfig
99
from botocore.exceptions import ClientError
1010

11+
from ..agent.identifier import validate as validate_agent_id
1112
from ..types.exceptions import SessionException
1213
from ..types.session import Session, SessionAgent, SessionMessage
14+
from .identifier import validate as validate_session_id
1315
from .repository_session_manager import RepositorySessionManager
1416
from .session_repository import SessionRepository
1517

@@ -51,6 +53,7 @@ def __init__(
5153
5254
Args:
5355
session_id: ID for the session
56+
ID is not allowed to contain path separators (e.g., a/b).
5457
bucket: S3 bucket name (required)
5558
prefix: S3 key prefix for storage organization
5659
boto_session: Optional boto3 session
@@ -79,12 +82,29 @@ def __init__(
7982
super().__init__(session_id=session_id, session_repository=self)
8083

8184
def _get_session_path(self, session_id: str) -> str:
82-
"""Get session S3 prefix."""
85+
"""Get session S3 prefix.
86+
87+
Args:
88+
session_id: ID for the session.
89+
90+
Raises:
91+
ValueError: If session id contains a path separator.
92+
"""
93+
session_id = validate_session_id(session_id)
8394
return f"{self.prefix}/{SESSION_PREFIX}{session_id}/"
8495

8596
def _get_agent_path(self, session_id: str, agent_id: str) -> str:
86-
"""Get agent S3 prefix."""
97+
"""Get agent S3 prefix.
98+
99+
Args:
100+
session_id: ID for the session.
101+
agent_id: ID for the agent.
102+
103+
Raises:
104+
ValueError: If session id or agent id contains a path separator.
105+
"""
87106
session_path = self._get_session_path(session_id)
107+
agent_id = validate_agent_id(agent_id)
88108
return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/"
89109

90110
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:

tests/strands/agent/test_agent.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,18 @@ def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_impo
250250
assert tru_tool_names == exp_tool_names
251251

252252

253+
@pytest.mark.parametrize(
254+
"agent_id",
255+
[
256+
"a/../b",
257+
"a/b",
258+
],
259+
)
260+
def test_agent__init__invalid_id(agent_id):
261+
with pytest.raises(ValueError):
262+
Agent(agent_id=agent_id)
263+
264+
253265
def test_agent__call__(
254266
mock_model,
255267
system_prompt,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
3+
from strands.agent import identifier
4+
5+
6+
def test_validate():
7+
tru_id = identifier.validate("abc")
8+
exp_id = "abc"
9+
assert tru_id == exp_id
10+
11+
12+
@pytest.mark.parametrize(
13+
"agent_id",
14+
[
15+
"a/../b",
16+
"a/b",
17+
],
18+
)
19+
def test_validate_invalid(agent_id):
20+
with pytest.raises(ValueError):
21+
identifier.validate(agent_id)

0 commit comments

Comments
 (0)