@kmehant commented on Nov 24, 2025

In this PR, we add context parallel (CP) support for mamba by swapping the HF mamba module with a CP-capable implementation (see the external dependencies summarized below).

## Results

### Legend

| abbreviation | meaning |
| --- | --- |
| cp | context parallel degree |
| ep | expert parallel degree |
| dp | data parallel degree |
| gas | gradient accumulation steps |
| ebs | effective batch size |
| s | sequence length |

## Ablations

### Parity Experiments

| model | experiment setting | loss | tps per GPU |
| --- | --- | --- | --- |
| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas1 | 0.8059140625 | 973.6 |
| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas1-ep8 | 0.80224609375 | 2367.6 |
| ibm-granite/granite-4.0-h-tiny | cp8-ebs4-s8192-gas2 | 0.8059765625 | NA |
| ibm-granite/granite-4.0-h-tiny | cp4-dp2-ebs4-s8192-gas1 | 0.802953125 | 953.4 |
| ibm-granite/granite-4.0-h-tiny | cp1-dp4-ep4-ebs4-s8192-gas1 | 0.7967056884765625 | 2576 |

### Long Context (sequence length 131072, i.e. 128k)

| model | experiment setting | tps per GPU | GPU memory util ratio |
| --- | --- | --- | --- |
| ibm-granite/granite-4.0-h-tiny | cp8-ebs1-s131072-gas1-ep8 | 1462.8 | 0.5140136719 |
| ibm-granite/granite-4.0-h-small | cp8-ebs1-s131072-gas1-ep8 | 682.7 | 0.9887207031 |

## Training Resumption

Settings used: mk-cp8-ebs4-s8192-gas1

*(Screenshot of the resumed run attached in the PR, captured 2025-11-28 at 1:25 PM.)*

## Summary of external dependencies

fsdp2-nov: https://github.com/kmehant/transformers.git (additional changes not in upstream)

Changes:

  1. Preparing batches to be compatible with CP when CP is enabled.
  2. Wrapping the training loop with the torch CP context (see the sketch after this list).
  3. Preparing shifted labels for correct loss calculation.
  4. Specific loss reduction when a combination of CP and DP is used.
  5. Model saving fix.
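As a concrete picture of items 1 to 4, here is a minimal sketch, assuming torch's experimental `context_parallel` API, a 1-D `DeviceMesh` named `cp_mesh` over the CP ranks, and an HF-style causal LM. `cp_train_step`, `batch`, and `IGNORE_INDEX` are placeholder names, and this is not the fork's actual code.

```python
# Minimal sketch of items 1-4; illustrative only, not the fork's implementation.
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.tensor.experimental import context_parallel

IGNORE_INDEX = -100  # HF's default ignored label id


def cp_train_step(model, batch, cp_mesh):
    input_ids, labels = batch["input_ids"], batch["labels"]  # (bs, seq)

    # (3) Shift labels BEFORE sequence sharding, so the target token sitting
    # just past each shard boundary is not lost to a per-shard shift later.
    shift_labels = F.pad(labels, (0, 1), value=IGNORE_INDEX)[:, 1:]

    # (1)+(2) context_parallel shards the listed buffers in place along their
    # sequence dim and routes attention through the CP-aware implementation.
    with context_parallel(cp_mesh,
                          buffers=[input_ids, shift_labels],
                          buffer_seq_dims=[1, 1]):
        logits = model(input_ids).logits  # local shard: (bs, seq/cp, vocab)
        loss_sum = F.cross_entropy(logits.flatten(0, 1), shift_labels.flatten(),
                                   ignore_index=IGNORE_INDEX, reduction="sum")

        # (4) Normalize by the GLOBAL number of label tokens, not the local
        # shard's count. With CP x DP the exact scaling must also match how
        # gradients are reduced across ranks, which is what the fork handles.
        n_tokens = (shift_labels != IGNORE_INDEX).sum()
        dist.all_reduce(n_tokens, group=cp_mesh.get_group())
        loss = loss_sum / n_tokens
        loss.backward()
    return loss.detach()
```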

fsdp2-fix: https://github.com/kmehant/accelerate.git (additional changes not in upstream)

Changes:

  1. Mixed precision fix when using FSDP2 (a sketch of the policy it configures follows).
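For context, a typical FSDP2 mixed-precision setup looks like the sketch below, assuming torch >= 2.6 (where `fully_shard` and `MixedPrecisionPolicy` are exported from `torch.distributed.fsdp`) and a decoder-style `model` with a `model.layers` ModuleList; the fork's actual fix lives in accelerate's FSDP2 plumbing and is not reproduced here.

```python
# Sketch of the FSDP2 mixed-precision policy the fix concerns; assumes a
# decoder-style `model` and an initialized process group. Illustrative only.
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # params are cast to bf16 for compute/comm
    reduce_dtype=torch.float32,   # gradients are reduced in fp32 for stability
)
for block in model.layers:        # shard each transformer block...
    fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)  # ...then the root module
```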


mamba-cp: https://github.com/kmehant/fms-acceleration.git (will be main after merging #164)

  1. Enables CP for mamba layers to go hand in hand with self-attention CP (see the module-swap sketch below).
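The swap itself follows the usual recursive module-replacement pattern. The sketch below is generic and hypothetical: `should_swap` and `build_replacement` are placeholders, and the real CP-capable mamba class and its constructor live in the fms-acceleration plugin.

```python
# Generic module-swap helper of the kind such a plugin uses; illustrative only.
import torch.nn as nn


def swap_modules(model: nn.Module, should_swap, build_replacement) -> nn.Module:
    """Replace every submodule for which should_swap(module) is True."""
    for name, module in list(model.named_modules()):
        if name and should_swap(module):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, build_replacement(module))
    return model


# Hypothetical usage: swap HF mamba mixers for a CP-aware drop-in.
# `CPMambaMixer` is a stand-in name, not the plugin's real class.
# swap_modules(model,
#              should_swap=lambda m: "Mamba" in type(m).__name__,
#              build_replacement=lambda m: CPMambaMixer.from_hf(m, cp_group))
```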

mamba-cp: https://github.com/garrett361/mamba (thanks to Garrett)

  1. The mamba_ssm kernels must be installed from this fork and branch to leverage CP.

## Summary of PRs merged into HF repos to enable CP and FSDP2

  1. feat: add ignored_params support for fsdp2 huggingface/accelerate#3731
  2. fix: CPU RAM efficient loading for nd or HSDP parallelisms huggingface/accelerate#3740
  3. feat: allow mixed precision policy as dtype huggingface/accelerate#3751
  4. refactor: nit change for get_parameters_from_modules (code debt) huggingface/accelerate#3815
  5. nit: needed sanity checks for fsdp2 huggingface/accelerate#3499
  6. feat: support tensor parallel & Data loader huggingface/accelerate#3173 (dataloader part is reused for CP)
  7. fix: fsdp sharded state dict wont work for save_only_model knob huggingface/transformers#36627

@kmehant changed the title from "CP support for mamba layer" to "feat: CP support for mamba layer" on Nov 24, 2025.
@kmehant force-pushed the mamba-cp branch 2 times, most recently from 3c0f767 to 1990451 on November 27, 2025 17:40.
Review comment (Collaborator): can we remove this?

Signed-off-by: Mehant Kammakomati <[email protected]>
@kmehant merged commit d451073 into foundation-model-stack:main on Nov 28, 2025. 9 checks passed.