Add host-side Triton TMA support to Inductor #137950
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137950
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0d93a4e with merge base 4a8e493.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Looks good! Would you mind adding a couple of SymInt uses? Also, it would be cool to have a deduping mechanism for TMA descriptors on the same tensor.
@@ -262,6 +262,9 @@ def generate_user_defined_triton_kernel(
            autotune_configs=configs,
        )

    def generate_tma_descriptor(self, desc):
constant_args = [
    *self.dims,
    *self.block_dims,
    self.element_size,
]
Is `dims` here actually constant?
Not necessarily, they can be SymInts. I was following what the `UserDefinedTritonKernel` does, where all non-`TensorBox` args are put into the `constant_args` here.

I was unsure about the semantics of the `constant_args` parameter of `ExternKernel`. Looking into the code, it seems `self.constant_args` is mostly used in the codegen-related methods, which are not relevant for `TMADescriptor` (nor for `UserDefinedTritonKernel`), as the codegen is overridden at the root. Although I also see it being used as a potential source of unbacked SymInts here. So perhaps I should keep this code as is?
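For illustration only, here is a tiny self-contained sketch of the convention described above (not the actual Inductor code): tensor-like arguments become real inputs of the node, while everything else, including SymInts, goes into `constant_args`. The helper name `split_kernel_args` is made up.

```python
# Illustrative sketch of the UserDefinedTritonKernel convention referenced
# above: non-tensor args (ints, SymInts, ...) are routed into constant_args,
# while tensor-like args become inputs of the node. Not the Inductor code.
from typing import Any, Callable, List, Tuple


def split_kernel_args(
    args: List[Any], is_tensor_like: Callable[[Any], bool]
) -> Tuple[List[Any], List[Any]]:
    inputs = [a for a in args if is_tensor_like(a)]
    constant_args = [a for a in args if not is_tensor_like(a)]
    return inputs, constant_args


# Toy usage: dims, block_dims, and element_size all land in constant_args,
# even when some of them are symbolic (SymInt).
inputs, constant_args = split_kernel_args(
    ["<TensorBox>", 1024, 1024, 64, 64, 4],
    is_tensor_like=lambda a: isinstance(a, str),  # toy predicate for the sketch
)
assert constant_args == [1024, 1024, 64, 64, 4]
```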
# link back to the underlying tensor in terms of ownership
# to avoid getting the underlying tensor deleted *before*
# the TMADescriptor node can be deleted.
NonOwningLayout(ReinterpretView(tensor, tensor.get_layout())),
I guess this works and I can't think of anything better - but if this comes up as a common pattern maybe we can loop back
You mean test cases? The unit tests run with
TMA descriptors are immutable, so this shouldn't be hard to do: would just need hashing on the underlying tensor and all the args. Let me try.
    block_dims: List[Union[int, torch.SymInt]],
    element_size: Optional[int] = None,
):
    key = (id(tensor), dims, block_dims, element_size)
@eellison I'm using `id(tensor)` here because `TensorBox` happens to be non-hashable. Although this looks correct, it's very restrictive: we require the same `TensorBox` Python object to hit the cache. Is there a more canonical way to do this in Inductor IR (that would allow, e.g., different `TensorBox`es referring to the same underlying storage)? Maybe I should unwrap the storage before doing `id(...)`, or would that ignore offsets in the views, which can lead to different `data_ptr()` values?
You could try to cover different Tensors with the same strides and the same underlying storage, but even then it would require us to fix the layout, and I'm not sure how common that is. I think this is good.
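Following up on the deduping exchange above, here is a minimal, self-contained sketch of the caching scheme being described. The `TMADescriptor` dataclass and the `get_or_create_tma_descriptor` helper are stand-ins for illustration, not the actual Inductor implementation.

```python
# Minimal sketch of the dedup cache discussed above. TMA descriptors are
# immutable, so the cache key is the identity of the underlying tensor plus
# all constructor args. The TMADescriptor class below is a stand-in for the
# Inductor IR node; names here are illustrative only.
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple


@dataclass(frozen=True)
class TMADescriptor:  # stand-in for ir.TMADescriptor
    tensor_id: int
    dims: Tuple[int, ...]
    block_dims: Tuple[int, ...]
    element_size: Optional[int]


_descriptor_cache: Dict[tuple, TMADescriptor] = {}


def get_or_create_tma_descriptor(
    tensor: object,
    dims: Sequence[int],
    block_dims: Sequence[int],
    element_size: Optional[int] = None,
) -> TMADescriptor:
    # id(tensor) is used because TensorBox is not hashable; note that two
    # distinct TensorBox objects over the same storage would still miss the
    # cache, as pointed out in the discussion above.
    key = (id(tensor), tuple(dims), tuple(block_dims), element_size)
    if key not in _descriptor_cache:
        _descriptor_cache[key] = TMADescriptor(
            id(tensor), tuple(dims), tuple(block_dims), element_size
        )
    return _descriptor_cache[key]
```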
Landing this as the signals look good and all comments resolved. Happy to address further requests in a follow-up PR.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in triton-lang/triton#4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). The OSS Triton pin update should be done in December 2024.
- Due to the Dynamo support implemented in the previous PR, the `tma_descriptor_metadata` dict is delivered to the `triton_kernel_wrap_` lowering and passed to the `ir.UserDefinedTritonKernel` as an additional argument.
- Looking into the `tma_descriptor_metadata`, `ir.UserDefinedTritonKernel` substitutes the corresponding `TensorBox` arguments of the kernel (swapped upstream in Dynamo) by the new `ir.TMADescriptor` nodes implementing TMA descriptors in Inductor IR.
- `ir.TMADescriptor.__init__` provides the wiring between the upstream underlying `ir.TensorBox` and the downstream `ir.UserDefinedTritonKernel` kernel. In particular, we use `ir.NonOwningLayout` wrapping `ir.ReinterpretView` to avoid the upstream tensor's buffer being deleted prematurely (before the TMA descriptor is used in the Triton kernel).
- Via `ir.TMADescriptor.codegen`, Triton's `create_{1d,2d}_tma_descriptor` function call is codegened in the wrapper (in the host code).
- A new `TMADescriptorArg` dataclass is added to handle the Triton kernel metadata pertinent to host-side TMA.
- AOT Inductor support will be implemented in a follow-up PR.

Pull Request resolved: pytorch#137950
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#137677
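To make the user-facing flow concrete, here is a rough sketch (not code from this PR) of a user-defined Triton kernel taking host-side TMA descriptors and compiled with `torch.compile`. The module path `triton.tools.experimental_descriptor` and the `tl._experimental_descriptor_load`/`_store` builtins are assumptions based on the Triton tutorial and pin referenced above; the kernel itself is a toy elementwise example.

```python
# Hedged sketch: host-side TMA descriptors created in a torch.compile'd
# function and consumed by a user-defined Triton kernel. With this PR, Dynamo
# records the descriptor calls into tma_descriptor_metadata and Inductor
# codegens the create_2d_tma_descriptor calls in the wrapper's host code.
# Requires a GPU with TMA support (e.g., H100) and a recent enough Triton.
import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_2d_tma_descriptor


@triton.jit
def add_one_kernel(in_desc, out_desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    off_m = pid_m * BLOCK_M
    off_n = pid_n * BLOCK_N
    # Load a (BLOCK_M, BLOCK_N) tile through the TMA descriptor.
    x = tl._experimental_descriptor_load(in_desc, [off_m, off_n], [BLOCK_M, BLOCK_N], tl.float32)
    tl._experimental_descriptor_store(out_desc, x + 1.0, [off_m, off_n])


def add_one(x: torch.Tensor) -> torch.Tensor:
    M, N = x.shape
    BLOCK_M, BLOCK_N = 64, 64
    out = torch.empty_like(x)
    # Host-side TMA descriptors: the API this PR teaches Inductor to handle.
    in_desc = create_2d_tma_descriptor(x.data_ptr(), M, N, BLOCK_M, BLOCK_N, x.element_size())
    out_desc = create_2d_tma_descriptor(out.data_ptr(), M, N, BLOCK_M, BLOCK_N, out.element_size())
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    add_one_kernel[grid](in_desc, out_desc, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return out


compiled = torch.compile(add_one, fullgraph=True)
y = compiled(torch.randn(256, 256, device="cuda"))
```

With the changes above, the generated wrapper would emit the corresponding `create_2d_tma_descriptor` calls on the host (via `ir.TMADescriptor.codegen`) rather than treating the descriptors as opaque constants.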
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec