[Inductor][CPP] Cache weight tiles in L1D for AMX int8 WoQ GEMM #136688
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136688
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3b3f56e with merge base e6e140c. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from d9d6a90 to 1b3c163 (compare).
On `for (int64_t n = 0; n < N; n += {{block_n}}) {`:
nit: do we need to add an assert here that N == block_n?
In the CPP template, this is being done for all dtypes: when calling a micro-kernel, the cache blocking size for the N dimension is equal to the register blocking size for N. I added a comment in the code so that we can make the necessary changes (buffer allocation & index computation) if that ceases to be the case. Thanks!
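For illustration only, here is a tiny C++ sketch (hypothetical, not the generated template code) of how that invariant could be made explicit with an assert, which is what the nit above asks about; `block_n` stands for the register-blocking value the template substitutes:

```cpp
#include <cassert>
#include <cstdint>

// Sketch: the caller hands the micro-kernel loop an N extent that is assumed
// to equal the register blocking size block_n (currently also the cache
// blocking size for N). If that ever changes, the dequantized-weight buffer
// allocation and index computation must change with it.
template <int64_t block_n>
void run_n_blocks(int64_t N /*, ... activation/weight/output pointers ... */) {
  assert(N == block_n && "cache blocking for N assumed equal to register blocking");
  for (int64_t n = 0; n < N; n += block_n) {
    // invoke the AMX micro-kernel on output columns [n, n + block_n)
  }
}
```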
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased |
Force-pushed from 83a7699 to d0e7815 (compare).
Force-pushed from d0e7815 to 6d4497d (compare).
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few are: inductor-periodic / cuda12.1-py3.10-gcc9-sm86-periodic-dynamo-benchmarks / test (aot_eager_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[Inductor][CPP] Cache weight tiles in L1D for AMX int8 WoQ GEMM (pytorch#136688)

# Summary
The AMX-ISA-based GEMM micro-kernel template for int8 weight-only quantization (BF16 activation, int8 weights) should cache dequantized weights (int8 -> int32 -> fp32 -> bf16) so that they would not have to be dequantized again in subsequent calls to the _inner kernel_ that use the same weights.

This change leverages the fact that even for the BF16 x BF16 GEMM template, cache blocking ensures that `Nr * Kc` weight elements are cached in L1D (more info [here](https://static.sched.com/hosted_files/pytorch2024/59/TorchInductor%20CPU%20Backend%20Advancements%20-%20New%20Features%20and%20Performance%20Improvements_20240915.pdf)). Here, `Nr` is the register blocking size for the `N` dimension (at the granularity of the GEMM micro-kernel, it's currently also the cache blocking size for `N`, although that may change in the future), and `Kc` is the cache blocking size for the `K` dimension.

The figure below is from the document linked above -

<img width="476" alt="image" src="https://github.com/user-attachments/assets/e23e5476-d910-46d1-a9b3-cbf77de76d94">

## Performance data
Collected on 48 physical cores of one socket of Intel Xeon Platinum 8468H (Xeon SP 4th gen). Intel OpenMP & tcmalloc were preloaded.

| M | N | K | Latency with ATen _weight_int8pack_mm | Latency with codegened templated GEMM (current main branch) | Latency with codegened templated GEMM (this PR) |
|------|-------|-------|------------|------------|------------|
| 4096 | 4096  | 4096  | 45.844 ms  | 9.322 ms   | 5.2181 ms  |
| 4096 | 11008 | 4096  | 127.618 ms | 24.6258 ms | 13.6046 ms |
| 4096 | 4096  | 11008 | 121.953 ms | 25.4692 ms | 10.2669 ms |
| 4096 | 32000 | 4096  | 478.450 ms | 75.3942 ms | 48.21 ms   |

Pull Request resolved: pytorch#136688
Approved by: https://github.com/jgong5
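To make the intent concrete, below is a minimal, self-contained C++ sketch of the idea (hypothetical names and a scalar reference loop; the actual Inductor CPP template generates AMX tile intrinsics and has its own packing and blocking helpers). The int8 weights for one `Nr x Kc` cache block are dequantized to bf16 once into a small scratch buffer, which then stays L1D-resident and is reused by every M-row block handled by the inner kernel, instead of being re-dequantized on every call.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

using bf16 = uint16_t;  // stand-in for a real bfloat16 type

static inline bf16 fp32_to_bf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<bf16>(u >> 16);  // truncating conversion, fine for a sketch
}

static inline float bf16_to_fp32(bf16 h) {
  uint32_t u = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// One cache block of the WoQ GEMM: C[Mc x Nr] += A[Mc x Kc] * dequant(B[Kc x Nr]).
// 'scales' holds one per-output-channel dequantization scale (an assumption here).
void woq_cache_block(const bf16* A, const int8_t* B_int8, const float* scales,
                     float* C, int64_t Mc, int64_t Kc, int64_t Nr,
                     int64_t lda, int64_t ldc) {
  // Dequantize the weight block ONCE (int8 -> int32 -> fp32 -> bf16).
  // Nr * Kc * sizeof(bf16) is chosen by cache blocking to fit in L1D,
  // so this buffer stays hot while the whole M extent is processed.
  std::vector<bf16> B_bf16(static_cast<size_t>(Kc) * Nr);
  for (int64_t k = 0; k < Kc; ++k) {
    for (int64_t n = 0; n < Nr; ++n) {
      int32_t q = static_cast<int32_t>(B_int8[k * Nr + n]);
      B_bf16[k * Nr + n] = fp32_to_bf16(static_cast<float>(q) * scales[n]);
    }
  }

  // The "inner kernel": every M-row block reuses the cached bf16 weights.
  // The real micro-kernel would issue AMX tile loads and TDPBF16PS here instead.
  for (int64_t m = 0; m < Mc; ++m) {
    for (int64_t k = 0; k < Kc; ++k) {
      float a = bf16_to_fp32(A[m * lda + k]);
      for (int64_t n = 0; n < Nr; ++n) {
        C[m * ldc + n] += a * bf16_to_fp32(B_bf16[k * Nr + n]);
      }
    }
  }
}
```

Before this change, the dequantization conceptually lived inside each per-M-block call, so the int8 -> bf16 conversion was repeated for every reuse of the same weights; hoisting it out of that loop is what the latency table above reflects.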
…ntized (#139906)
@frost-intel discovered that some Inductor auto-tuning UTs for CPU are currently broken on machines supporting the AMX ISA. That's because in #136688, I had reverted a change in the AMX GEMM micro-kernel that was introduced in #131887, but some other implementations introduced after that change rely upon it, so it should not have been reverted. Added a fix.
Ideally, a CI machine that supports AMX should cover these UTs (test/inductor/test_cpu_select_algorithm.py). We do have at least one CI machine that supports AMX.
Pull Request resolved: #139906
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
@pytorchbot revert -m "correctness issue in #140208"
❌ 🤖 pytorchbot command failed:
Try
@pytorchbot revert -m "correctness issue in #140208" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
Reverting PR 136688 failed. Reason: Command
Details for Dev Infra team: raised by workflow job.
Also, it seems we need more UTs to cover this feature...
An existing UT caught it :( |
Thanks for the info! The size of the buffer should not have been multiplied by 2.
Thanks! The patch assumes that all elements of a
@chunyuan-w explained that the weights are being pre-packed so that each tile can be accessed contiguously. So, I'm guessing when @jgong5 will add
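For context, here is a rough sketch (hypothetical layout and tile sizes, not the actual Inductor packing code) of what "pre-packed so each tile is contiguous" can look like: the K x N int8 weight matrix is rearranged so every TILE_K x TILE_N block consumed by one AMX tile load occupies a contiguous run of memory.

```cpp
#include <cstdint>
#include <vector>

// Assumed tile shape for the sketch; the real packing also accounts for the
// VNNI layout expected by the AMX int8/bf16 dot-product instructions.
constexpr int64_t TILE_K = 64;
constexpr int64_t TILE_N = 16;

// Rearranges row-major B[K x N] so each TILE_K x TILE_N tile is contiguous.
// Assumes K % TILE_K == 0 and N % TILE_N == 0 for brevity.
std::vector<int8_t> pack_weight_tiles(const int8_t* B, int64_t K, int64_t N) {
  std::vector<int8_t> packed(static_cast<size_t>(K) * N);
  int64_t out = 0;
  for (int64_t nb = 0; nb < N; nb += TILE_N)      // tile column blocks
    for (int64_t kb = 0; kb < K; kb += TILE_K)    // tile row blocks
      for (int64_t k = 0; k < TILE_K; ++k)        // rows within a tile
        for (int64_t n = 0; n < TILE_N; ++n)      // cols within a tile
          packed[out++] = B[(kb + k) * N + (nb + n)];
  return packed;
}
```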
Summary
The AMX ISA based GEMM micro-kernel template for int8 weight-only quantization (BF16 activation, int8 weights) should cache dequantized weights (int8 -> int32 -> fp32 -> bf16) so that they would not have to be dequantized again in subsequent calls to the inner-kernel that uses the same weights.
This change leverages the fact that even for the BF16 x BF16 GEMM template, cache blocking ensures that Nr * Kc weight elements are cached in L1D (more info here). Here, Nr is the register blocking size for the N dimension (at the granularity of the GEMM micro-kernel, it's currently also the cache blocking size for the N dimension, although that may change in the future), and Kc is the cache blocking size for the K dimension. The figure below is from the document linked above.
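As a back-of-the-envelope check (with assumed example blocking values, since the real Nr and Kc are chosen by the CPP backend per platform), the cached weight block is sized to fit in L1D:

```cpp
#include <cstdio>

int main() {
  // Assumed example values; Nr and Kc are picked by Inductor's cache blocking.
  constexpr long Nr = 32;               // register blocking for N
  constexpr long Kc = 512;              // cache blocking for K
  constexpr long bytes = Nr * Kc * 2;   // bf16 weights, 2 bytes each
  // 32 * 512 * 2 = 32 KB, which fits in the 48 KB L1D of a 4th-gen Xeon SP
  // core with room left over for activation and output tiles.
  std::printf("dequantized weight block: %ld KB\n", bytes / 1024);
  return 0;
}
```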
Performance data
Collected on 48 physical cores of one socket of Intel Xeon Platinum 8468H (Xeon SP 4th gen). Intel OpenMP & tcmalloc were preloaded.
cc @jgong5 @mingfeima @XiaobingSuper @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @rec