 from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
 from lerobot.utils.constants import ACTION, OBS_STATE

-from lerobot.policies.wall_x.utils import *
-from lerobot.policies.wall_x.constant import *
+from lerobot.policies.wall_x.utils import (
+    replace_action_token,
+    preprocesser_call,
+    get_wallx_normal_text,
+    process_grounding_points,
+)
+from lerobot.policies.wall_x.constant import (
+    MODEL_TYPE,
+    TOKENIZER_MAX_LENGTH,
+    PRIORITY_ORDER,
+    GENERATE_SUBTASK_RATIO,
+    RESOLUTION,
+    MAX_PIXELS,
+    MIN_PIXELS,
+    IMAGE_FACTOR,
+)
 from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration,
@@ -261,6 +275,7 @@ def from_pretrained(
         pretrained_name_or_path,
         config=None,
         action_tokenizer_path=None,
+        attn_implementation: str = 'eager',
         cache_dir: str | PathLike | None = None,
         force_download: bool = False,
         local_files_only: bool = False,
@@ -276,6 +291,7 @@ def from_pretrained(
             pretrained_model_path (str): Model directory path containing model.safetensors file
             config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path
             action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config
+            attn_implementation (str, optional): Attention implementation to use (e.g. 'eager'); if None, the implementation from the loaded config is kept
             **kwargs: Additional arguments

         Returns:
@@ -292,14 +308,18 @@ def from_pretrained(
             strict=strict,
             **kwargs,
         )
+        if attn_implementation is not None:
+            config._attn_implementation = attn_implementation
         processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True)
         if action_tokenizer_path is not None:
-            processor.action_processor = AutoProcessor.from_pretrained(
+            action_tokenizer = AutoProcessor.from_pretrained(
                 action_tokenizer_path, trust_remote_code=True
             )
-
+            processor.action_processor = action_tokenizer
+        else:
+            action_tokenizer = None
         # Initialize model with configuration and processor
-        model = cls(config, processor=processor, **kwargs)
+        model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)

         # Resize token embeddings to match processor tokenizer vocabulary size
         model.resize_token_embeddings(len(processor.tokenizer))
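
For reference, the updated loading path can be exercised roughly as below. This is a minimal sketch only: the checkpoint and tokenizer paths are placeholders, the import location of `Qwen2_5_VLMoEForAction` is assumed, and `attn_implementation` takes whatever values the underlying transformers backend supports (e.g. 'eager' or 'sdpa').

```python
# Hedged usage sketch; paths and the import location are placeholders/assumptions.
from lerobot.policies.wall_x.modeling_wall_x import Qwen2_5_VLMoEForAction  # module path assumed

model = Qwen2_5_VLMoEForAction.from_pretrained(
    "path/to/wall_x_checkpoint",                        # directory containing model.safetensors
    action_tokenizer_path="path/to/action_tokenizer",   # optional; wired into processor.action_processor
    attn_implementation="eager",                         # overrides config._attn_implementation when not None
)
```
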
@@ -379,6 +399,7 @@ def __init__(
         self.flow_loss_weight = flow_loss_weight
         self.use_fast_tokenizer = use_fast_tokenizer
         self.processor = processor
+        self.action_tokenizer = action_tokenizer

         # Define action token IDs
         self.define_action_token_id()
@@ -1279,7 +1300,7 @@ def predict(
         if labels is not None:
             labels = labels[:, split_pos + 3 :]
         else:
-            raise Warning(
+            raise ValueError(
                 "input_ids does not contain the generation prompt tokens <|im_start|>assistant"
             )

@@ -1826,7 +1847,7 @@ def __init__(self, config: WallXConfig):
         self.config = config

         # Initialize the wall-x model
-        self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
+        self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path, attn_implementation=config.attn_implementation)
         self.model.to(config.device)
         self.model.to_bfloat16_for_selected_params()

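
The policy wrapper forwards the attention implementation from its config when loading the model. A hedged configuration sketch, using only the `WallXConfig` attributes referenced in this hunk (other required fields, if any, are omitted):

```python
# Hedged sketch: only the fields visible above are set; WallXConfig may require more.
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig

config = WallXConfig(
    pretrained_name_or_path="path/to/wall_x_checkpoint",  # placeholder path
    attn_implementation="eager",   # passed through to Qwen2_5_VLMoEForAction.from_pretrained
    device="cuda",
)
```
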
@@ -1950,22 +1971,23 @@ def preprocess_inputs(
         ], dim=-1)

         # ==================== PROCESS ACTIONS ====================
-        action = batch[ACTION]  # (batch_size, chunk_size, action_dim)
-        if action.dim() == 2:
-            action = action.unsqueeze(1)
-        dof_mask = (~torch.isnan(action)).float()
-        action = action.nan_to_num(nan=0.0)
-
-        if action.shape[-1] != 20:
-            pad_size = 20 - action.shape[-1]
-            action = torch.cat([
-                action,
-                torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
-            ], dim=-1)
-            dof_mask = torch.cat([
-                dof_mask,
-                torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
-            ], dim=-1)
+        action = batch.get(ACTION, None)  # (batch_size, chunk_size, action_dim)
+        if action is not None:
+            if action.dim() == 2:
+                action = action.unsqueeze(1)
+            dof_mask = (~torch.isnan(action)).float()
+            action = action.nan_to_num(nan=0.0)
+
+            if action.shape[-1] != 20:
+                pad_size = 20 - action.shape[-1]
+                action = torch.cat([
+                    action,
+                    torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
+                ], dim=-1)
+                dof_mask = torch.cat([
+                    dof_mask,
+                    torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device)
+                ], dim=-1)

         # ==================== ACTION TOKEN REPLACEMENT ====================
         all_texts = replace_action_token(
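
The action block above pads every action chunk to a fixed 20-dimensional layout and builds a DoF mask from NaN entries, now only when `ACTION` is present in the batch. A standalone sketch of what that produces for a smaller action space (shapes and values are illustrative only):

```python
import torch

# Illustrative only: a 7-DoF action chunk padded to the 20-dim layout used above.
action = torch.randn(2, 16, 7)           # (batch_size, chunk_size, action_dim)
action[0, :, 5:] = float("nan")          # unused DoFs arrive as NaN upstream

dof_mask = (~torch.isnan(action)).float()  # 1.0 where a DoF is real, 0.0 where it was NaN
action = action.nan_to_num(nan=0.0)

pad_size = 20 - action.shape[-1]
pad = torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)
action = torch.cat([action, pad], dim=-1)                        # -> (2, 16, 20)
dof_mask = torch.cat([dof_mask, torch.zeros_like(pad)], dim=-1)  # padded dims stay masked out
```
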