Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ intelrealsense = [
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]

# Policies
wallx = [
"torch==2.6.0",
"torchvision==0.21.0",
"torchaudio==2.6.0",
"transformers==4.49.0",
"accelerate==1.10.1",
"peft==0.17.1",
"scipy==1.15.3",
"torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11"
]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [
Expand Down Expand Up @@ -159,6 +170,7 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
Expand Down
2 changes: 2 additions & 0 deletions src/lerobot/policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig

__all__ = [
"ACTConfig",
Expand All @@ -33,4 +34,5 @@
"VQBeTConfig",
"GrootConfig",
"XVLAConfig",
"WallXConfig",
]
20 changes: 18 additions & 2 deletions src/lerobot/policies/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
Expand All @@ -61,7 +62,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:

Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".

Returns:
The policy class corresponding to the given name.
Expand Down Expand Up @@ -113,6 +114,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy

return XVLAPolicy
elif name == "wall_x":
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy

return WallXPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
Expand All @@ -130,7 +135,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier".
"reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.

Returns:
Expand Down Expand Up @@ -161,6 +166,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
elif policy_type == "wall_x":
return WallXConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
Expand Down Expand Up @@ -344,6 +351,7 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)

elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
Expand All @@ -353,6 +361,14 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)

elif isinstance(policy_cfg, WallXConfig):
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors

processors = make_wall_x_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)

else:
try:
Expand Down
35 changes: 35 additions & 0 deletions src/lerobot/policies/wall_x/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# WALL-OSS

This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.

---

## Model Overview

| Feature | Description |
| -------------------- | ------------------------------------------------------------------------ |
| Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture         | Mixture of Experts (MoE) with action-specific routing                     |
| Multi-Modal Inputs   | Vision (images/videos), Language, Proprioception                          |

---

## Citation

If you use this work, please cite:

```bibtex
@article{zhai2025igniting,
title = {Igniting VLMs Toward the Embodied Space},
author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach},
journal = {arXiv preprint arXiv:2509.11766},
year = {2025}
}
```

---

## License

This port follows the **Apache 2.0 License**.

21 changes: 21 additions & 0 deletions src/lerobot/policies/wall_x/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/usr/bin/env python

# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_wall_x import WallXConfig
from .modeling_wall_x import WallXPolicy
from .processor_wall_x import make_wall_x_pre_post_processors

__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]
169 changes: 169 additions & 0 deletions src/lerobot/policies/wall_x/configuration_wall_x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig


@PreTrainedConfig.register_subclass("wall_x")
@dataclass
class WallXConfig(PreTrainedConfig):
    """
    Configuration class for Wall-X policy.

    Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching.
    It supports cross-embodiment robotic control through unified action representations.

    This config supports multi-modal learning with vision, language, and action data.

    Raises:
        ValueError: in `__post_init__` if `n_action_steps` exceeds `chunk_size` or if
            `prediction_mode` is not one of {"diffusion", "fast"}; in `validate_features`
            if no visual feature is present or state/action dims exceed their maxima.
    """

    # ==================== Input / Output Structure ====================
    n_obs_steps: int = 1
    chunk_size: int = 32  # action_horizon in wall-x
    n_action_steps: int = 32

    # Action dimension - wall-x uses 20
    max_action_dim: int = 20
    max_state_dim: int = 20  # For proprioception

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # ==================== Action Prediction ====================
    # Pretrained model paths
    pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"

    # Tokenizer settings (cleared in __post_init__ when prediction_mode == "diffusion")
    action_tokenizer_path: str | None = "physical-intelligence/fast"

    # Action prediction mode: "diffusion" or "fast"
    prediction_mode: str = "diffusion"

    # Attention Implementation, options: "eager", "flash_attention_2", "sdpa"
    # NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation
    attn_implementation: str = "eager"

    # ==================== Optimizer Presets ====================
    optimizer_lr: float = 2e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.01
    optimizer_grad_clip_norm: float = 1.0

    scheduler_warmup_steps: int = 1000
    scheduler_decay_steps: int = 100000
    scheduler_decay_lr: float = 1e-6

    def __post_init__(self):
        super().__post_init__()

        # Input validation
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )

        if self.prediction_mode not in ("diffusion", "fast"):
            raise ValueError(
                f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}"
            )

        # Derive tokenizer usage from the (now validated) prediction mode. The
        # previous implementation re-checked prediction_mode here with an
        # unreachable else branch; the single check above is sufficient.
        self.use_fast_tokenizer = self.prediction_mode == "fast"
        if self.prediction_mode == "diffusion":
            self.action_tokenizer_path = None  # disable action tokenizer for diffusion mode

    def validate_features(self) -> None:
        """Validate and set up input/output features.

        Ensures at least one visual input exists, and fills in default
        state/action features (padded to `max_state_dim` / `max_action_dim`)
        when they are absent. Raises ValueError when provided state/action
        dimensions exceed the configured maxima.
        """
        image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
        if not image_features:
            raise ValueError(
                "Wall-X policy requires at least one visual input feature. "
                "No features of type FeatureType.VISUAL found in input_features."
            )

        if "observation.state" not in self.input_features:
            state_feature = PolicyFeature(
                type=FeatureType.STATE,
                shape=(self.max_state_dim,),  # Padded to max_state_dim
            )
            self.input_features["observation.state"] = state_feature
        else:
            state_shape = self.input_features["observation.state"].shape
            state_dim = state_shape[0] if state_shape else 0
            if state_dim > self.max_state_dim:
                raise ValueError(
                    f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
                    f"Either reduce state dimension or increase max_state_dim in config."
                )

        if "action" not in self.output_features:
            action_feature = PolicyFeature(
                type=FeatureType.ACTION,
                shape=(self.max_action_dim,),  # Padded to max_action_dim
            )
            self.output_features["action"] = action_feature
        else:
            action_shape = self.output_features["action"].shape
            action_dim = action_shape[0] if action_shape else 0
            if action_dim > self.max_action_dim:
                raise ValueError(
                    f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
                    f"Either reduce action dimension or increase max_action_dim in config."
                )

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the AdamW optimizer configuration built from the optimizer_* fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self):
        """Return a cosine-decay-with-warmup scheduler config using the scheduler_* fields."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> list | None:
        # No temporal deltas on observations: only the current step is used.
        return None

    @property
    def action_delta_indices(self) -> list:
        # One delta index per action in the predicted chunk: [0, 1, ..., chunk_size - 1].
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        # Rewards are not used by this policy.
        return None
41 changes: 41 additions & 0 deletions src/lerobot/policies/wall_x/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python

# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Wall-X Constants and Configuration Data.
"""

# Maps internal camera feature names to the human-readable view names used in
# prompts. Note both "move1_view" and "move2_view" map to the same label,
# "move view".
CAMERA_NAME_MAPPING = {
    "face_view": "front view",
    "left_wrist_view": "left wrist view",
    "right_wrist_view": "right wrist view",
    "move1_view": "move view",
    "move2_view": "move view",
    "wall_view": "wall view",
    "top_view": "top view",
}

# Target image resolution (presumably square, in pixels) — TODO confirm against the preprocessor.
RESOLUTION = 256

# Parameters for preprocessing
# Pixel-count bounds expressed as multiples of 28x28 patches (28 is the
# Qwen2.5-VL vision patch/image factor).
MAX_PIXELS = 16384 * 28 * 28
MIN_PIXELS = 4 * 28 * 28
IMAGE_FACTOR = 28
# No camera priority ordering by default — NOTE(review): semantics of a non-None
# value are defined by the consumer of this constant; verify there.
PRIORITY_ORDER = None
# Fraction of samples for which a subtask is generated; 0.0 disables it.
GENERATE_SUBTASK_RATIO = 0.0
MODEL_TYPE = "qwen2_5"

# Maximum token length for the tokenizer input sequence.
TOKENIZER_MAX_LENGTH = 768
Loading