Simplify & rectify dequantized B buffer loading for AMX GEMM micro-kernel for WoQ int8 case #140258
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140258
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: There is 1 currently active SEV. If your PR is affected, please view it below.
✅ You can merge normally! (1 Unrelated Failure) As of commit 709adfe with merge base fa63276:
BROKEN TRUNK - The following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
BTW: I think the horizontal traverse strategy doesn't work well with this cache optimization. cc @jgong5 @chunyuan-w
Hi, would the horizontal traverse strategy complement the existing AMX GEMM micro-kernel template (by conditionally using it), or would it replace it? Thanks!
I think we will use it conditionally.
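For context, here is a minimal C++ sketch of the two block-traversal orders being discussed. The block sizes, loop structure, and function names are assumptions for illustration, not the template's actual code; the point is that with the horizontal order the per-`N`-block dequantized-`B` cache would be refilled on every step, which is presumably why it doesn't pair well with this optimization.

```cpp
// Illustrative sketch only: two ways a micro-kernel could walk the M x N
// grid of output blocks. BLOCK_M/BLOCK_N and the callback interface are
// assumptions, not the generated template's real structure.
#include <cstddef>

constexpr std::size_t BLOCK_M = 32;
constexpr std::size_t BLOCK_N = 32;

template <typename MicroKernel>
void gemm_vertical_traverse(std::size_t M, std::size_t N, MicroKernel&& kernel) {
  // N-block outer, M-block inner: the same B block (and its dequantized
  // cache) stays hot while the whole column of output blocks is computed.
  for (std::size_t nb = 0; nb < N; nb += BLOCK_N)
    for (std::size_t mb = 0; mb < M; mb += BLOCK_M)
      kernel(mb, nb);
}

template <typename MicroKernel>
void gemm_horizontal_traverse(std::size_t M, std::size_t N, MicroKernel&& kernel) {
  // M-block outer, N-block inner: the B block changes on every inner
  // iteration, so a per-N-block dequantized-B cache gets little reuse.
  for (std::size_t mb = 0; mb < M; mb += BLOCK_M)
    for (std::size_t nb = 0; nb < N; nb += BLOCK_N)
      kernel(mb, nb);
}
```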
As we discussed offline, please do not assume that `B` is contiguous.
As suggested by @leslie-fang-intel in https://github.com/leslie-fang-intel/pytorch/commit/4c83e4e75138e8fa6e0d58438f75b7718dc8a0cc#diff-139642bd981df977f70f4c18c1c34bd1a85c1d6b9ffa06aaa98426ed83942a31R537
This case can't be covered by the current UTs, though, since the N != block_n case hasn't been implemented.
Don't assume weight-packing at the GEMM template level.
Its value would also be known at runtime, so checking it wouldn't affect performance.
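To make the review point concrete, here is a rough, hypothetical sketch of checking the packing at run time rather than assuming it in the template: a flag selects between the contiguous (pre-packed) path and a strided fallback. The names (`load_b_block`, `b_is_prepacked`, `ldb`) are illustrative assumptions, not the template's actual API; since the flag's value is fixed for a given kernel invocation, the branch is well predicted and shouldn't cost performance.

```cpp
// Hypothetical sketch only: copy one K x block_n slice of B into a staging
// buffer, with the load path chosen by a runtime flag instead of assuming
// that B is pre-packed. Names are illustrative, not from the real template.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

void load_b_block(const std::int8_t* B, std::size_t K, std::size_t block_n,
                  std::size_t ldb, bool b_is_prepacked,
                  std::vector<std::int8_t>& staging) {
  staging.resize(K * block_n);
  if (b_is_prepacked) {
    // Pre-packed weights: the whole K x block_n tile is one contiguous run,
    // so a single bulk copy (or a direct pointer, skipping the copy) suffices.
    std::memcpy(staging.data(), B, K * block_n);
  } else {
    // Unpacked weights: consecutive rows are ldb elements apart, so copy
    // the tile row by row.
    for (std::size_t k = 0; k < K; ++k)
      std::memcpy(staging.data() + k * block_n, B + k * ldb, block_n);
  }
}
```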
Successfully rebased from 3dffe41 to 709adfe.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
As suggested by @leslie-fang-intel in leslie-fang-intel@4c83e4e#diff-139642bd981df977f70f4c18c1c34bd1a85c1d6b9ffa06aaa98426ed83942a31R537, all elements of `B` tiles (not AMX tiles, but tiles at the granularity of the micro-kernel) are contiguous since the `B` matrix is pre-packed, so the dequantized-buffer loading logic can be simplified. While the previous approach kept the elements to be loaded into a `B` AMX tile contiguous, the new approach doesn't entail any performance penalty either: that data is already in L1D, so loading AMX tiles from non-contiguous dequantized `B` elements doesn't adversely affect performance.
Also rectified the size of the dequantized `B` buffer.
Fixes #140208.
A subsequent PR will factor out caching of dequantized int8 weights into a separate codegen function.
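To illustrate the simplification described above, here is a minimal C++ sketch, not the generated template code: because the pre-packed `B` block is one contiguous run of `K * block_n` int8 values, dequantization can be a single flat pass into a buffer of exactly `K * block_n` elements (the "rectified" size), and the AMX tile loads then read that buffer with a row stride. The helper name, the per-column `scales` layout, and the use of `float` instead of the bf16 the real kernel presumably produces are all illustrative assumptions.

```cpp
// Minimal sketch (assumptions noted above): dequantize one pre-packed
// K x block_n int8 B block into a buffer sized exactly K * block_n.
#include <cstddef>
#include <cstdint>
#include <vector>

void dequantize_b_block(const std::int8_t* B_packed,  // contiguous K * block_n int8 weights
                        const float* scales,          // one scale per output column (assumed layout)
                        std::size_t K, std::size_t block_n,
                        std::vector<float>& dequant_buf) {
  dequant_buf.resize(K * block_n);
  // Single contiguous pass over the packed block: no per-AMX-tile gathering
  // or padding of the buffer is needed.
  for (std::size_t k = 0; k < K; ++k)
    for (std::size_t n = 0; n < block_n; ++n)
      dequant_buf[k * block_n + n] =
          static_cast<float>(B_packed[k * block_n + n]) * scales[n];
}

// The AMX tile loads can then address dequant_buf with a row stride of
// block_n * sizeof(element): the rows feeding one AMX tile are no longer
// adjacent when block_n exceeds the tile width, but the whole block stays
// in L1D, so the strided tile load is not a bottleneck.
```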
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov