[PyTorch] Hook up fp16_gemv_trans to x86 fp16 GEMM #137918
Conversation
This is the first big milestone we've been building towards! (TODO: also hook it up to GEMV in the same way fp16_gemv_trans is hooked up) Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)! [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137918
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 06b6c11 with merge base 86602a6.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D64280688
Do you have performance numbers?
It improves decoding performance about 5x.
aten/src/ATen/native/CPUBlas.cpp
Outdated
// is to upconvert to fp32 and call sgemm. We can do better by
// fusing the conversion.
const bool fp16_gemv_trans_fast_path_would_be_beneficial =
    cpuinfo_initialize() && cpuinfo_has_x86_f16c() && !cpuinfo_has_x86_avx512fp16();
I guess checking cpuinfo_has_x86_avx512fp16 is not necessary, since oneDNN (mkldnn) won't use avx512fp16 to compute GEMMs by default, because the avx512fp16 FMA would incur accuracy loss.
@jgong5 https://community.intel.com/t5/Intel-oneAPI-Math-Kernel-Library/FP16-GEMM-using-AVX512-on-Sapphire-Rapids/m-p/1570739 doesn't seem to agree. I'm also confused -- I thought FMA was supposed to improve accuracy because it has high internal precision, so the result of the multiply doesn't have to be rounded to FP16 before the addition.
@swolchok The link you referred to is about MKL, not oneDNN (or mkldnn). MKL has a dedicated API (hgemm) that uses the AVX512_FP16 instructions, but users should be aware of the accuracy loss due to FP16 accumulators. It is not about the accumulator for a single FMA (which has high internal precision, as you mentioned) but about accumulation along the K-dim across multiple FMAs. On the other hand, oneDNN uses FP32 accumulators to keep high accuracy.
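To make the K-dim point concrete, here is a tiny standalone illustration (not code from this PR; it assumes a compiler with the _Float16 extension, e.g. recent GCC/Clang on x86): an fp16 running sum stops absorbing unit-sized updates once it reaches 2048, while an fp32 accumulator does not.

```cpp
#include <cstdio>

int main() {
  _Float16 acc_h = 0;  // fp16 accumulator (10-bit mantissa)
  float acc_f = 0;     // fp32 accumulator
  for (int k = 0; k < 4096; ++k) {
    acc_h = acc_h + (_Float16)1;  // at 2048 the fp16 spacing is 2, so +1 is lost
    acc_f = acc_f + 1.0f;         // exact in fp32 at this magnitude
  }
  // Expected output: fp16 acc = 2048, fp32 acc = 4096
  std::printf("fp16 acc = %g, fp32 acc = %g\n", (double)acc_h, (double)acc_f);
  return 0;
}
```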
oh, we use FP32 accumulation as well.
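For readers following along, a minimal sketch of what a fused-conversion fp16 dot product with fp32 accumulation can look like using F16C + FMA intrinsics. This is purely illustrative, not the kernel added by this PR; it assumes n is a multiple of 8 and a build with -mf16c -mfma.

```cpp
#include <immintrin.h>
#include <cstdint>

// Dot product of two fp16 vectors (stored as raw 16-bit words), converting
// to fp32 on the fly and accumulating in fp32. Assumes n % 8 == 0.
float dot_fp16_fused(const uint16_t* x, const uint16_t* y, int64_t n) {
  __m256 acc = _mm256_setzero_ps();  // fp32 accumulator
  for (int64_t i = 0; i < n; i += 8) {
    __m256 xf = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x + i)));  // F16C: fp16 -> fp32
    __m256 yf = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(y + i)));
    acc = _mm256_fmadd_ps(xf, yf, acc);  // multiply-add in fp32
  }
  // Horizontal sum of the 8 fp32 lanes.
  __m128 s = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
  s = _mm_hadd_ps(s, s);
  s = _mm_hadd_ps(s, s);
  return _mm_cvtss_f32(s);
}
```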
mkldnn_fp16_gemm, despite the name, is also available on ARM. It looks like your change will skip MKLDNN unless it is on an x86 platform, wouldn't it? What is the motivation for that?
> mkldnn_fp16_gemm despite the name is also available on ARM,

news to me!

> it looks like your change will skip MKLDNN unless it is on x86 platform, wouldn't it?

fortunately no, because ARM machines won't pass cpuinfo_has_x86_f16c().
> also available on ARM

I don't think so? https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mkldnn/Utils.h#L120
Something to look at for BF16 though.
> oh, we use FP32 accumulation as well.

Will you remove this cpuinfo_has_x86_avx512fp16() check then? I don't think it is relevant.
> remove this cpuinfo_has_x86_avx512fp16() check then? I don't think it is relevant.

I am surprised, but you're definitely the authority on this and I don't have a Sapphire Rapids machine to test on. I'll leave a note for posterity though.
…rchitectures (#138005) Following up on previous rev to use fp16_gemv_trans in gemv, not just gemm-used-for-gemv. Differential Revision: [D64351092](https://our.internmc.facebook.com/intern/diff/D64351092/) Pull Request resolved: #138005 Approved by: https://github.com/malfet ghstack dependencies: #139082, #139083, #137918
No real reason to have the zero-beta restriction, so let's lift it. Testing: intentionally broke new paths locally to verify test coverage existed Differential Revision: [D64407752](https://our.internmc.facebook.com/intern/diff/D64407752/) Pull Request resolved: #138275 Approved by: https://github.com/malfet ghstack dependencies: #139082, #139083, #137918, #138005
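For context on what the lifted restriction means: BLAS-style gemv computes y := alpha * op(A) * x + beta * y, and beta == 0 is the special case where y is simply overwritten (so it may even start uninitialized). A plain reference version of the transposed case, illustrative only and not the PR's kernel, might look like:

```cpp
#include <cstdint>

// y[j] = alpha * sum_i A[i, j] * x[i] + beta * y[j], with A column-major (m x n).
// Per BLAS convention, when beta == 0 the old contents of y are ignored entirely.
void gemv_trans_ref(int64_t m, int64_t n, float alpha, const float* a, int64_t lda,
                    const float* x, float beta, float* y) {
  for (int64_t j = 0; j < n; ++j) {
    float acc = 0.0f;
    for (int64_t i = 0; i < m; ++i) {
      acc += a[j * lda + i] * x[i];  // column j of A dotted with x
    }
    y[j] = (beta == 0.0f) ? alpha * acc : alpha * acc + beta * y[j];
  }
}
```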
Caused by #137918. Fixed by guarding all cpuinfo use with `!defined(__s390x__) && !defined(__powerpc__)`. Pull Request resolved: #139491 Approved by: https://github.com/huydhn, https://github.com/Skylion007
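A minimal sketch of that guard pattern, assuming the usual cpuinfo entry points (cpuinfo_initialize, cpuinfo_has_x86_f16c); the function name and exact condition are illustrative, not the actual fix:

```cpp
// cpuinfo is not available on s390x/powerpc builds, so every use is fenced off.
#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif

static bool fp16_fast_path_available() {
#if !defined(__s390x__) && !defined(__powerpc__)
  // F16C gives cheap fp16 <-> fp32 conversion on x86.
  return cpuinfo_initialize() && cpuinfo_has_x86_f16c();
#else
  return false;
#endif
}
```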
This is the first big milestone we've been building towards! (Following rev also hooks this up to actual gemv.) Testing: To check perf, I ran python torchchat.py generate stories110M --dtype fp16 --device cpu on an x86 machine without AVX512FP16. Observed roughly 5x tokens/sec increase. Differential Revision: [D64280688](https://our.internmc.facebook.com/intern/diff/D64280688/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D64280688/)! Pull Request resolved: pytorch#137918 Approved by: https://github.com/malfet ghstack dependencies: pytorch#139082, pytorch#139083
Stack from ghstack (oldest at bottom):
This is the first big milestone we've been building towards!
(Following rev also hooks this up to actual gemv.)
Testing: To check perf, I ran `python torchchat.py generate stories110M --dtype fp16 --device cpu` on an x86 machine without AVX512FP16. Observed roughly a 5x tokens/sec increase.
Differential Revision: D64280688
NOTE FOR REVIEWERS: This PR has internal Meta-specific changes or comments, please review them on Phabricator!