Add host-side TMA support to AOTInductor #138878
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138878
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e417138 with merge base 72ea7ba.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -251,7 +255,29 @@ def generate_user_defined_triton_kernel(
     )

+    def generate_tma_descriptor(self, desc):
```
Because AOTI still uses a two-pass run at the moment, does the cpp wrapper codegen here need to read any information from the Python run? I have added things like `DeferredGpuGridLine` for that purpose. Do you think it is necessary to do something like that here?
I believe the Python and C++ codegens for the TMA descriptor creation are independent.
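To make that independence concrete, here is a minimal sketch of the two code paths. The class structure and the `desc.name` / `desc.rank` / `desc.args` attributes are illustrative stand-ins, not the actual Inductor internals:

```python
# Illustrative sketch only: the real Inductor wrapper-codegen classes differ,
# and `desc.name`, `desc.rank`, `desc.args` are hypothetical attributes
# standing in for whatever the TMA descriptor IR node carries.
class PythonWrapperCodegenSketch:
    def __init__(self):
        self.lines = []

    def writeline(self, line):
        self.lines.append(line)

    def generate_tma_descriptor(self, desc):
        # The Python wrapper emits a call to Triton's host-side API, e.g.
        #   desc1 = triton.tools.experimental_descriptor.create_2d_tma_descriptor(...)
        self.writeline(
            f"{desc.name} = triton.tools.experimental_descriptor."
            f"create_{desc.rank}d_tma_descriptor({', '.join(desc.args)})"
        )


class CppWrapperCodegenSketch(PythonWrapperCodegenSketch):
    def generate_tma_descriptor(self, desc):
        # The C++ wrapper instead emits a call to the generated
        # init{1,2}DTMADescriptor helper, e.g.
        #   init2DTMADescriptor(&desc1, buf0_ptr, ...);
        self.writeline(
            f"init{desc.rank}DTMADescriptor({', '.join(desc.args)});"
        )
```

Because each backend renders the descriptor creation directly from the same IR node, the C++ pass never has to consult values recorded during the Python pass, so no `DeferredGpuGridLine`-style deferred line appears to be needed here.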
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
This adds host-side Triton TMA support to AOTInductor. Notes:

- Two helper functions, `init1DTMADescriptor` and `init2DTMADescriptor`, are added to the C++ wrapper codegen on GPU, conditioned on the model having user-defined Triton kernels with host-side TMA (CUDA-specific).
- The C++ wrapper codegen on GPU emits TMA descriptor initialization via the aforementioned helper functions.
- Special handling is added for TMA descriptors in the Python wrapper codegen during compile-time autotuning, as the underlying tensor can't be passed directly to the user-defined Triton kernel. TMA descriptors are generated in between the source tensor's buffer and the kernel call, as in the full Python wrapper codegen (see the sketch below).
- This PR concludes the host-side Triton TMA support in PT2.

Pull Request resolved: pytorch#138878
Approved by: https://github.com/desertfire, https://github.com/chenyang78
ghstack dependencies: pytorch#138759, pytorch#138877
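As a rough illustration of the autotuning-time handling described above, the generated benchmarking code materializes a descriptor between the buffer and the kernel call. Buffer names, shapes, and the kernel call are made up for illustration; `create_2d_tma_descriptor` is Triton's experimental host-side TMA API:

```python
# Shape of the code emitted for compile-time autotuning (illustrative names).
import torch
import triton.tools.experimental_descriptor as tma

buf0 = torch.empty((256, 256), device="cuda", dtype=torch.float32)

# The raw tensor cannot be passed to a TMA-consuming user-defined Triton
# kernel, so a descriptor is created in between the source tensor's buffer
# and the kernel call:
desc0 = tma.create_2d_tma_descriptor(
    buf0.data_ptr(),      # global memory base address
    256, 256,             # global dims
    64, 64,               # block dims
    buf0.element_size(),  # element size in bytes
)
# my_user_defined_kernel.run(desc0, ..., grid=grid, stream=stream0)
```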
Stack from ghstack (oldest at bottom):
- #138878 (this PR)
- #138877
- #138759

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
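Putting it together, here is a hedged end-to-end sketch of the flow this stack enables: a user-defined Triton kernel that consumes a host-side TMA descriptor, compiled ahead of time with AOTInductor. The module, kernel body, and the `torch._export.aot_compile` entry point are assumptions for illustration; host-side TMA also requires a Hopper-class GPU and a recent Triton:

```python
import torch
import triton
import triton.language as tl
import triton.tools.experimental_descriptor as tma


@triton.jit
def add_one_kernel(desc, out_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Load a BLOCK-sized tile through the host-created TMA descriptor.
    x = tl._experimental_descriptor_load(desc, [pid * BLOCK], [BLOCK], tl.float32)
    tl.store(out_ptr + pid * BLOCK + tl.arange(0, BLOCK), x + 1.0)


class AddOne(torch.nn.Module):
    def forward(self, x):
        out = torch.empty_like(x)
        BLOCK = 128
        # Host-side TMA descriptor over the flat input tensor.
        desc = tma.create_1d_tma_descriptor(
            x.data_ptr(), x.numel(), BLOCK, x.element_size()
        )
        add_one_kernel[(x.numel() // BLOCK,)](desc, out, BLOCK=BLOCK)
        return out


x = torch.randn(1024, device="cuda")
# AOTInductor compiles this to a shared library whose C++ wrapper creates
# the descriptor via the generated init1DTMADescriptor helper.
so_path = torch._export.aot_compile(AddOne(), (x,))
```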