[Inductor][CPP] Optimize WOQ INT8 wgt dequant in AMX GEMM template #136630
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136630
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit b8b7923 (merge base: failed to retrieve merge base, please contact dev infra).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
// Load 32 int8 weight elements and widen them to one cache line of 32 bf16 elements.
auto b_int8 = at::vec::Vectorized<int8_t>::loadu(src, static_cast<int64_t>(32));
auto b_bf16 = at::vec::convert<{{input_t}}>(b_int8);
b_bf16.store(dst);
Thank you! I didn't know this was possible with at::vec, since the conversion happens from int8 -> int32 -> fp32 -> bf16. Looks like at::vec::convert can use multiple intermediate vector registers (two in this case, for holding the int32 and fp32 values).
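For anyone following along, here is a minimal standalone sketch of the same one-shot widening outside the Jinja template. The function name, the buffer sizes, and the use of c10::BFloat16 in place of {{input_t}} are illustrative assumptions, the lane counts assume an AVX512 build, and this is not the generated kernel code.

```cpp
// Minimal sketch (assumptions noted above): widen 32 int8 weight elements to
// 32 bf16 lanes in one shot with at::vec, mirroring the snippet under review.
#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h>
#include <array>
#include <cstdint>

void dequant_row_int8_to_bf16(const int8_t* src, c10::BFloat16* dst) {
  // Partial load of 32 int8 elements (Vectorized<int8_t> holds 64 lanes on AVX512).
  auto b_int8 = at::vec::Vectorized<int8_t>::loadu(src, static_cast<int64_t>(32));
  // Widen to bf16; internally this goes through int32/fp32 intermediates.
  auto b_bf16 = at::vec::convert<c10::BFloat16>(b_int8);
  // Store one cache line (32 x 2 bytes) of bf16 values.
  b_bf16.store(dst);
}

int main() {
  std::array<int8_t, 32> src{};
  std::array<c10::BFloat16, 32> dst{};
  for (int i = 0; i < 32; ++i) {
    src[i] = static_cast<int8_t>(i - 16);
  }
  dequant_row_int8_to_bf16(src.data(), dst.data());
  return 0;
}
```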
@leslie-fang-intel, do you know if such a change is possible for the AVX512 micro-kernel as well, so that it could load multiple vector registers of B at a time? I mean by using at::vec, not intrinsics. Thanks!
Currently, it loads only one FP32 vector register of B at a time:

Yes, I guess so. Probably something like:
auto b_int8 = at::vec::Vectorized<int8_t>::loadu(src, static_cast<int64_t>(16)); // load the first 128 bits: 16 x int8
auto b_fp32 = convert_int8_to_float<int8_t>(b_int8); // convert to 16 x FP32
Could you give it a try and see if it benefits the AVX micro GEMM performance?
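For reference, a rough standalone sketch of that suggestion follows. The function and buffer names are made up for illustration, the at::vec:: qualification of convert_int8_to_float is an assumption, and this is not the actual generated AVX micro-kernel.

```cpp
// Rough sketch of the suggestion above: dequantize one 16-wide block of B from
// int8 to fp32 with at::vec rather than raw intrinsics.
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

void load_dequant_b_block(const int8_t* b_src, float* b_dst) {
  // Load the first 128 bits of B: 16 x int8.
  auto b_int8 = at::vec::Vectorized<int8_t>::loadu(b_src, static_cast<int64_t>(16));
  // Convert those 16 int8 lanes to 16 x FP32 (one full Vectorized<float> on AVX512).
  auto b_fp32 = at::vec::convert_int8_to_float<int8_t>(b_int8);
  // The micro-kernel would feed b_fp32 into its FMAs; here it is just stored.
  b_fp32.store(b_dst);
}
```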
But this one is similar to the current implementation.
@leslie-fang-intel, could there have been some copy-paste error pertaining to the perf data on
I guess it may be due to system differences. Re-running it still shows similar performance on my test system.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Summary
Optimize the WOQ int8 AMX performance by changing the int8 -> bf16 conversion.
Earlier, 16 int8 elements were loaded at a time and converted to 16 BF16 elements.
With this change, 32 int8 elements are loaded at a time and converted to a cache line of 32 BF16 elements more efficiently.
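As a quick sanity check of the cache-line arithmetic above (the 64-byte line size is the usual x86 value and is stated here as an assumption):

```cpp
// 32 bf16 elements at 2 bytes each fill one typical 64-byte cache line,
// whereas 16 elements only fill half of one.
#include <cstddef>

int main() {
  constexpr std::size_t kBf16Bytes = 2;        // size of a bf16 element
  constexpr std::size_t kCacheLineBytes = 64;  // typical x86 cache line (assumption)
  static_assert(32 * kBf16Bytes == kCacheLineBytes, "new path: one full cache line per convert");
  static_assert(16 * kBf16Bytes == kCacheLineBytes / 2, "old path: half a cache line per convert");
  return 0;
}
```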
Performance before
Performance after this PR
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang