[Feat] Layerwise casting hook #3427
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Co-authored-by: a-r-r-o-w <[email protected]>
SunMarc left a comment:
So far so good, thanks for starting this work!
SunMarc left a comment:
Thanks for the work @sayakpaul, just a small edge case.
def test_layerwise_upcasting_inference(self, storage_dtype, compute_dtype, skip_modules_pattern=None):
    test_model = ModelForTest()
    inputs = torch.randn(2, 3)
    inputs = inputs.to(compute_dtype) if inputs.dtype == torch.float32 else inputs

    attach_layerwise_casting_hooks(
        test_model,
        storage_dtype=storage_dtype,
        compute_dtype=compute_dtype,
        skip_modules_pattern=skip_modules_pattern,
    )
    patterns_to_check = skip_modules_pattern if skip_modules_pattern else None
    self.check_dtype_for_layerwise_upcasting(test_model, storage_dtype, compute_dtype, patterns_to_check)

    with torch.no_grad():
        _ = test_model(inputs)
Let's add a case where a module shouldn't be downcasted at all even if we don't pass skip_modules_pattern, i.e. one that is not in the SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING list.
You mean we call

attach_layerwise_casting_hooks(
    test_model,
    storage_dtype=storage_dtype,
    compute_dtype=compute_dtype,
    skip_modules_pattern=None,
)

and the test_model isn't downcasted at all? Not sure how we could do that. Could you elaborate more?
For example, a model that contains an RMSNorm layer shouldn't see its weights being downcasted after applying attach_layerwise_casting_hooks.
We don't have any default modules to not skip. I don't know if we can do what you're asking for without defining a default module name that will never get downcasted.
You have that, no? SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING
Ah okay. Sorry for the confusion.
Done
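For reference, a minimal sketch of the kind of case discussed above, assuming the hook casts supported layers to the storage dtype at attach time (as the test earlier checks). The model class, the import path, and the float8 storage dtype are illustrative assumptions, not the actual test added in this PR:

import torch
import torch.nn as nn

# The import path is an assumption; the function name and kwargs come from the test in this PR.
from accelerate.hooks import attach_layerwise_casting_hooks


class ModelWithNorm(nn.Module):
    # Hypothetical model: nn.RMSNorm (PyTorch >= 2.4) is not in
    # SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING, so its weight should keep its
    # original dtype even with skip_modules_pattern=None.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 4)
        self.norm = nn.RMSNorm(4)

    def forward(self, x):
        return self.norm(self.linear(x))


model = ModelWithNorm()
attach_layerwise_casting_hooks(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.float32,
    skip_modules_pattern=None,
)

# Supported layers are stored in the low-precision storage dtype...
assert model.linear.weight.dtype == torch.float8_e4m3fn
# ...while the unsupported RMSNorm keeps its original dtype.
assert model.norm.weight.dtype == torch.float32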
@SunMarc seems like there's something wrong with the CI:
Yeah, should be fixed in this PR! Thanks for the heads-up.
SunMarc left a comment:
A few nits!
What does this PR do?
Introduces a layerwise casting method that lets users run a model in separate storage and compute dtypes, leading to memory savings. See internal thread:
https://huggingface.slack.com/archives/C01Q6JPP6NA/p1739198698861009
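As a rough sketch of the intended usage (the import path and the float8 storage dtype are assumptions for illustration; the function name and keyword arguments match the tests in this PR):

import torch
import torch.nn as nn

# Import path is an assumption for illustration.
from accelerate.hooks import attach_layerwise_casting_hooks

model = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))

# Parameters live in a low-precision storage dtype between forward passes and
# are upcast to the compute dtype on the fly inside each module's forward.
attach_layerwise_casting_hooks(
    model,
    storage_dtype=torch.float8_e4m3fn,  # assumed storage dtype
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=None,  # optionally, patterns for modules to leave untouched
)

with torch.no_grad():
    out = model(torch.randn(1, 64, dtype=torch.bfloat16))

The memory savings reported below come from the parameters staying in the storage dtype whenever a layer is not actively computing.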
I tested a sufficiently large dummy model with it and here are some numbers:
script
I haven't updated the tests yet as this PR is for gathering feedback to know if it's headed in the right direction.
TODOs
skip_modules_pattern arg