
Commit db08634

remove flash-attn requirement && fix bug in inference and fast mode
1 parent ea33ebe commit db08634

File tree

8 files changed: +63 −44 lines changed

pyproject.toml

Lines changed: 1 addition & 2 deletions

@@ -129,8 +129,7 @@ wallx = [
     "peft==0.17.1",
     "scipy==1.15.3",
     "torchdiffeq==0.2.5",
-    "qwen_vl_utils==0.0.11",
-    "flash-attn==2.7.4.post1"
+    "qwen_vl_utils==0.0.11"
 ]
 pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
 smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
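
Since flash-attn is no longer pulled in by the wallx extra, the flash_attention_2 backend becomes opt-in. A minimal sketch, assuming the previously pinned wheel is installed manually and that WallXConfig constructs with its dataclass defaults:

# Opt back in (the wallx extra no longer installs this):
#   pip install flash-attn==2.7.4.post1
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig

cfg = WallXConfig(attn_implementation="flash_attention_2")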

src/lerobot/policies/wall_x/configuration_wall_x.py

Lines changed: 15 additions & 8 deletions

@@ -53,13 +53,15 @@ class WallXConfig(PreTrainedConfig):
     # Pretrained model paths
     pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
 
+    # Tokenizer settings
+    action_tokenizer_path: str | None = "physical-intelligence/fast"
+
     # Action prediction mode: "diffusion" or "fast"
     prediction_mode: str = "diffusion"
 
-    # Tokenizer settings
-    use_fast_tokenizer: bool = False  # True: train FAST, False: train Flow
-    action_tokenizer_path: str | None = None  # Path to action tokenizer (for FAST mode)
-
+    # Attention Implementation, options: "eager", "flash_attention_2", "sdpa"
+    # NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation
+    attn_implementation: str = "eager"
 
     # ==================== Optimizer Presets ====================
     optimizer_lr: float = 2e-5
@@ -87,11 +89,16 @@ def __post_init__(self):
                 f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
             )
 
-        # Sync prediction_mode with use_fast_tokenizer
-        if self.use_fast_tokenizer:
-            self.prediction_mode = "fast"
+        # Assign use_fast_tokenizer based on prediction_mode
+        if self.prediction_mode == "fast":
+            self.use_fast_tokenizer = True
+        elif self.prediction_mode == "diffusion":
+            self.use_fast_tokenizer = False
+            self.action_tokenizer_path = None  # disable action tokenizer for diffusion mode
         else:
-            self.prediction_mode = "diffusion"
+            raise ValueError(
+                f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
+            )
 
     def validate_features(self) -> None:
         """Validate and set up input/output features."""

src/lerobot/policies/wall_x/constant.py

Lines changed: 0 additions & 2 deletions

@@ -18,8 +18,6 @@
 Wall-X Constants and Configuration Data.
 """
 
-from lerobot.utils.constants import OBS_STATE, OBS_IMAGES, ACTION
-
 CAMERA_NAME_MAPPING = {
     "face_view": "front view",
     "left_wrist_view": "left wrist view",

src/lerobot/policies/wall_x/modeling_wall_x.py

Lines changed: 45 additions & 23 deletions

@@ -63,8 +63,22 @@
 from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
 from lerobot.utils.constants import ACTION, OBS_STATE
 
-from lerobot.policies.wall_x.utils import *
-from lerobot.policies.wall_x.constant import *
+from lerobot.policies.wall_x.utils import (
+    replace_action_token,
+    preprocesser_call,
+    get_wallx_normal_text,
+    process_grounding_points,
+)
+from lerobot.policies.wall_x.constant import (
+    MODEL_TYPE,
+    TOKENIZER_MAX_LENGTH,
+    PRIORITY_ORDER,
+    GENERATE_SUBTASK_RATIO,
+    RESOLUTION,
+    MAX_PIXELS,
+    MIN_PIXELS,
+    IMAGE_FACTOR,
+)
 from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration,
@@ -261,6 +275,7 @@ def from_pretrained(
         pretrained_name_or_path,
         config=None,
         action_tokenizer_path=None,
+        attn_implementation: str = 'eager',
         cache_dir: str | PathLike | None = None,
         force_download: bool = False,
         local_files_only: bool = False,
@@ -276,6 +291,7 @@ def from_pretrained(
             pretrained_model_path (str): Model directory path containing model.safetensors file
             config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
             action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
+            attn_implementation (str, optional): Attention implementation, if None will load from default config
             **kwargs: Additional arguments
 
         Returns:
@@ -292,14 +308,18 @@ def from_pretrained(
             strict=strict,
             **kwargs,
         )
+        if attn_implementation is not None:
+            config._attn_implementation = attn_implementation
         processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
         if action_tokenizer_path is not None:
-            processor.action_processor = AutoProcessor.from_pretrained(
+            action_tokenizer = AutoProcessor.from_pretrained(
                 action_tokenizer_path, trust_remote_code=True
             )
-
+            processor.action_processor = action_tokenizer
+        else:
+            action_tokenizer = None
         # Initialize model with configuration and processor
-        model = cls(config, processor=processor, **kwargs)
+        model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
 
         # Resize token embeddings to match processor tokenizer vocabulary size
         model.resize_token_embeddings(len(processor.tokenizer))
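
A usage sketch of the reworked loader; the import path is assumed from this file, and "sdpa" is one of the backends named in the config:

from lerobot.policies.wall_x.modeling_wall_x import Qwen2_5_VLMoEForAction

# "eager" is the default; "sdpa" needs no extra install, while
# "flash_attention_2" additionally requires flash-attn==2.7.4.post1.
model = Qwen2_5_VLMoEForAction.from_pretrained(
    "x-square-robot/wall-oss-flow",
    action_tokenizer_path="physical-intelligence/fast",  # FAST mode; omit for flow
    attn_implementation="sdpa",
)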
@@ -379,6 +399,7 @@ def __init__(
         self.flow_loss_weight = flow_loss_weight
         self.use_fast_tokenizer = use_fast_tokenizer
         self.processor = processor
+        self.action_tokenizer = action_tokenizer
 
         # Define action token IDs
         self.define_action_token_id()
@@ -1279,7 +1300,7 @@ def predict(
         if labels is not None:
             labels = labels[:, split_pos + 3 :]
         else:
-            raise Warning(
+            raise ValueError(
                 "input_ids does not contain the generation prompt tokens <|im_start|>assistant"
             )
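
The fix above matters because raise Warning(...) is legal Python (Warning subclasses Exception) yet signals the wrong category: callers expecting a non-fatal notice get a crash, and except ValueError handlers never match. A minimal contrast:

import warnings

# Non-fatal notice -- what raise Warning(...) looked like it did:
warnings.warn("input_ids missing <|im_start|>assistant prompt tokens")

# Hard failure on invalid input -- what the code actually needs:
# raise ValueError("input_ids does not contain the generation prompt tokens <|im_start|>assistant")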

@@ -1826,7 +1847,7 @@ def __init__(self, config: WallXConfig):
         self.config = config
 
         # Initialize the wall-x model
-        self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
+        self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation)
         self.model.to(config.device)
         self.model.to_bfloat16_for_selected_params()
 
@@ -1950,22 +1971,23 @@ def preprocess_inputs(
         ], dim=-1)
 
         # ==================== PROCESS ACTIONS ====================
-        action = batch[ACTION]  # (batch_size, chunk_size, action_dim)
-        if action.dim() == 2:
-            action = action.unsqueeze(1)
-        dof_mask = (~torch.isnan(action)).float()
-        action = action.nan_to_num(nan=0.0)
-
-        if action.shape[-1] != 20:
-            pad_size = 20 - action.shape[-1]
-            action = torch.cat([
-                action,
-                torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
-            ], dim=-1)
-            dof_mask = torch.cat([
-                dof_mask,
-                torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
-            ], dim=-1)
+        action = batch.get(ACTION, None)  # (batch_size, chunk_size, action_dim)
+        if action is not None:
+            if action.dim() == 2:
+                action = action.unsqueeze(1)
+            dof_mask = (~torch.isnan(action)).float()
+            action = action.nan_to_num(nan=0.0)
+
+            if action.shape[-1] != 20:
+                pad_size = 20 - action.shape[-1]
+                action = torch.cat([
+                    action,
+                    torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
+                ], dim=-1)
+                dof_mask = torch.cat([
+                    dof_mask,
+                    torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
+                ], dim=-1)
 
         # ==================== ACTION TOKEN REPLACEMENT ====================
         all_texts = replace_action_token(
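
The guard on batch.get(ACTION) is the inference fix: at inference time no ground-truth actions are present, so the padding branch is skipped entirely. The training-time branch restated as a self-contained sketch (the function name is illustrative, not part of the commit):

import torch

def pad_action_to_20_dof(action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    if action.dim() == 2:                      # (B, D) -> (B, 1, D)
        action = action.unsqueeze(1)
    dof_mask = (~torch.isnan(action)).float()  # 1.0 where a DoF is populated
    action = action.nan_to_num(nan=0.0)
    if action.shape[-1] != 20:                 # zero-pad trailing DoFs up to 20
        pad = torch.zeros(*action.shape[:2], 20 - action.shape[-1], device=action.device)
        action = torch.cat([action, pad], dim=-1)
        dof_mask = torch.cat([dof_mask, pad], dim=-1)
    return action, dof_mask

# e.g. a 7-DoF arm over a 16-step chunk:
padded, mask = pad_action_to_20_dof(torch.randn(2, 16, 7))
assert padded.shape == (2, 16, 20) and mask[..., 7:].sum() == 0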

src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py

Lines changed: 0 additions & 5 deletions

@@ -231,11 +231,6 @@ def __init__(
         self.attention_moe = attention_moe
         self.mlp_moe = mlp_moe
 
-        # Validate the correctness of rotary position embeddings parameters
-        # BC: if there is a 'type' field, move it to 'rope_type'.
-        # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
-        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
-        # TODO: @raushan update config in the hub
         if self.rope_scaling is not None and "type" in self.rope_scaling:
             if self.rope_scaling["type"] == "mrope":
                 self.rope_scaling["type"] = "default"
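
The deleted comments documented a backward-compatibility shim that the surviving code still applies; a standalone sketch (the mrope_section value is illustrative, not from the source):

# Legacy checkpoints may ship rope_scaling={"type": "mrope", ...};
# the retained code normalizes the deprecated "mrope" tag to "default".
rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
if rope_scaling is not None and "type" in rope_scaling:
    if rope_scaling["type"] == "mrope":
        rope_scaling["type"] = "default"
assert rope_scaling["type"] == "default"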

src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py

Lines changed: 1 addition & 1 deletion

@@ -1506,7 +1506,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         dtype (`torch.dtype`):
             The dtype to use for the 4D attention mask.
         device (`torch.device`):
-            The device to plcae the 4D attention mask on.
+            The device to place the 4D attention mask on.
         cache_position (`torch.Tensor`):
             Indices depicting the position of the input sequence tokens in the sequence.
         batch_size (`torch.Tensor`):

tests/policies/wall_x/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

tests/policies/wall_x/test_wallx.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@
 from lerobot.policies.wall_x import (  # noqa: E402
     WallXConfig,
     WallXPolicy,
-    make_wall_x_pre_post_processors,  # noqa: E402
+    make_wall_x_pre_post_processors,
 )
 from lerobot.utils.random_utils import set_seed  # noqa: E402
