
Conversation

@sayakpaul (Member) commented Mar 7, 2025

What does this PR do?

Introduces a layerwise casting method that lets users run a model in separate storage and compute dtypes, leading to memory savings. See internal thread:
https://huggingface.slack.com/archives/C01Q6JPP6NA/p1739198698861009
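
In short, weights stay in a low-precision storage dtype and are upcast to the compute dtype only around each layer's forward pass. A minimal usage sketch, mirroring the benchmark script below (the checkpoint is just the one used there; any model works):

import torch
from transformers import SiglipVisionModel
from accelerate.big_modeling import apply_layerwise_casting

model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to("cuda")

# Weights are stored in fp8 (e5m2) and upcast to bf16 per layer during the forward pass.
apply_layerwise_casting(model, storage_dtype=torch.float8_e5m2, compute_dtype=torch.bfloat16)

with torch.inference_mode():
    _ = model(pixel_values=torch.randn(1, 3, 384, 384).to("cuda", torch.bfloat16))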

I tested a sufficiently large dummy model with it and here are some numbers:

Batch Size | Layerwise Casting | Autocast | Model Memory Footprint (MB) | Peak Inference Memory Allocated (MB)
---|---|---|---|---
1 | True | False | 428.23 | 1635.86
1 | False | True | 1633.56 | 2475.92
4 | True | False | 408.39 | 1639.11
4 | False | True | 1633.56 | 2539.73
16 | True | False | 408.39 | 1649.02
16 | False | True | 1633.56 | 2848.93
Benchmark script:
import torch
from transformers import SiglipVisionModel
from accelerate.big_modeling import apply_layerwise_casting
from contextlib import nullcontext
import argparse
import json

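# Layer types whose parameters are expected to be cast to the storage dtype (mirrors the layers targeted by layerwise casting).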
SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)

def get_memory_footprint(model, return_buffers=True):
    mem = sum([param.nelement() * param.element_size() for param in model.parameters()])
    if return_buffers:
        mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
        mem = mem + mem_bufs
    return mem


def get_dummy_inputs(batch_size=1, target_dtype=torch.float32):
    inputs = {"pixel_values": torch.randn(batch_size, 3, 384, 384).to("cuda", target_dtype)}
    return inputs

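# Sanity check: after casting, supported layers should hold their weights/biases in the storage dtype.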
def check_linear_dtype(module, storage_dtype, compute_dtype):
    for name, submodule in module.named_modules():
        if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
            continue
        dtype_to_check = storage_dtype
        if getattr(submodule, "weight", None) is not None:
            assert submodule.weight.dtype == dtype_to_check, submodule.weight.dtype
        if getattr(submodule, "bias", None) is not None:
            assert submodule.bias.dtype == dtype_to_check, submodule.bias.dtype


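# Runs a forward pass and reports the model footprint and peak inference memory, both in MB.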
@torch.no_grad()
@torch.inference_mode()
def run(model, inputs):
    model(**inputs)
    model_mem_footprint = get_memory_footprint(model) / 1024**2
    peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
    return {
        "model_mem_footprint_mb": model_mem_footprint,
        "peak_inference_memory_allocated_gb": peak_inference_memory_allocated_mb
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--apply_layerwise_casting", action="store_true")
    parser.add_argument("--use_autocast", action="store_true")
    args = parser.parse_args()

    model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384").to("cuda")
    
    if args.apply_layerwise_casting:
        inputs = get_dummy_inputs(batch_size=args.batch_size, target_dtype=torch.bfloat16)
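        # Store weights in fp8 (e5m2); they are upcast to bf16 on the fly at compute time.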
        apply_layerwise_casting(model, storage_dtype=torch.float8_e5m2, compute_dtype=torch.bfloat16)
        context = nullcontext()
    elif args.use_autocast:
        inputs = get_dummy_inputs(batch_size=args.batch_size)
        context = torch.autocast("cuda", dtype=torch.bfloat16, enabled=True)
    else:
        # Baseline: plain fp32 weights and compute.
        inputs = get_dummy_inputs(batch_size=args.batch_size)
        context = nullcontext()
    
    with context:
        output = run(model, inputs)

    print(f"{args.batch_size=}, {args.apply_layerwise_casting=}, {args.use_autocast=}")
    print(json.dumps(output, indent=4))

I haven't updated the tests yet, as this PR is primarily for gathering feedback on whether it's headed in the right direction.

TODOs

  • Incorporate skip_modules_pattern arg
  • Docs
  • Tests

@sayakpaul sayakpaul requested a review from SunMarc March 7, 2025 12:51

@SunMarc (Member) left a comment

So far so good, thanks for starting this work!

@sayakpaul sayakpaul marked this pull request as ready for review March 27, 2025 15:34
@SunMarc (Member) left a comment

Thanks for the work @sayakpaul, just a small edge case:

Comment on lines 427 to 442
def test_layerwise_upcasting_inference(self, storage_dtype, compute_dtype, skip_modules_pattern=None):
    test_model = ModelForTest()
    inputs = torch.randn(2, 3)
    inputs = inputs.to(compute_dtype) if inputs.dtype == torch.float32 else inputs

    attach_layerwise_casting_hooks(
        test_model,
        storage_dtype=storage_dtype,
        compute_dtype=compute_dtype,
        skip_modules_pattern=skip_modules_pattern,
    )
    patterns_to_check = skip_modules_pattern if skip_modules_pattern else None
    self.check_dtype_for_layerwise_upcasting(test_model, storage_dtype, compute_dtype, patterns_to_check)

    with torch.no_grad():
        _ = test_model(inputs)
@SunMarc (Member):

Let's add a case where a module shouldn't be downcast at all even if we don't pass skip_modules_pattern -> one that is not in the SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING list.

@sayakpaul (Member, Author):

You mean we call

attach_layerwise_casting_hooks(
    test_model,
    storage_dtype=storage_dtype,
    compute_dtype=compute_dtype,
    skip_modules_pattern=None,
)

and the test_model isn't downcast at all? Not sure how we could do that. Could you elaborate?

@SunMarc (Member):

For example, a model that contains an RMSNorm shouldn't see its weights downcast after applying attach_layerwise_casting_hooks.
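
For illustration, a minimal sketch of such a test case (not the exact test added in the PR; it assumes torch.nn.RMSNorm is available in the installed PyTorch, and the import path for attach_layerwise_casting_hooks is an assumption):

import torch
from accelerate.hooks import attach_layerwise_casting_hooks  # import path assumed

class ModelWithNorm(torch.nn.Module):  # hypothetical test model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)
        self.norm = torch.nn.RMSNorm(3)  # not in the supported-layers list

    def forward(self, x):
        return self.norm(self.linear(x))

model = ModelWithNorm()
attach_layerwise_casting_hooks(model, storage_dtype=torch.float8_e5m2, compute_dtype=torch.float32)

# The Linear weight should be stored in the storage dtype...
assert model.linear.weight.dtype == torch.float8_e5m2
# ...while the RMSNorm weight stays untouched because its type isn't supported for casting.
assert model.norm.weight.dtype == torch.float32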

@sayakpaul (Member, Author):

We don't have any modules that are skipped by default. I don't know if we can do what you're asking without defining a default module name that will never get downcast.

@SunMarc (Member) commented Apr 18, 2025:

You have that, no? SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING

@sayakpaul (Member, Author):

Ah okay. Sorry for the confusion.

@sayakpaul (Member, Author):

Done

@sayakpaul (Member, Author) commented:

@SunMarc seems like there's something wrong with the CI:
https://github.com/huggingface/accelerate/actions/runs/14529383877/job/40766557692?pr=3427#step:3:11

@SunMarc (Member) commented Apr 18, 2025:

Yeah, should be fixed in this PR! Thanks for the heads-up.

@sayakpaul sayakpaul requested a review from SunMarc April 21, 2025 08:51
@SunMarc (Member) left a comment

A few nits!

@SunMarc SunMarc merged commit 6a9a615 into main Apr 22, 2025
28 checks passed
@SunMarc SunMarc deleted the layerwise-casting-hook branch April 22, 2025 11:49