Adding lowering to persistent-tma device kernel for _scaled_mm #142045
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142045
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 Unrelated Failures) As of commit da679d2 with merge base 9dffd12.
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Would be nice to specify this as an epilogue or reuse the TMA lowering in a future PR.
"""Defines the grid for persistent kernels.""" | ||
return ( | ||
min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])), | ||
1, |
What are the extra 1s here, just for my own knowledge?
I think this has to do w/ how we thread the launch config to Triton; without it I get:

  File "/home/drisspg/meta/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 1079, in run
    return launcher(
           ^^^^^^^^^
  File "<string>", line 5, in launcher
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
ValueError: not enough values to unpack (expected 3, got 1)
I believe the grid is assumed to be a 3-tuple (or a callable returning one) in Triton grid launches.
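To make that concrete, here is a minimal, self-contained sketch (illustrative only, not the template's actual code; `cdiv` and the `meta` keys mirror the snippet quoted above) of a persistent-kernel grid callable that pads the launch out to the 3-tuple Triton expects:

```python
# Sketch of a persistent-kernel grid: launch at most one program per SM and
# pad the unused y/z grid dimensions with 1s so Triton gets a 3-tuple.
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def persistent_grid(M: int, N: int):
    def grid(meta):
        return (
            min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),
            1,  # unused y dimension
            1,  # unused z dimension
        )

    return grid


# Example: for an 8192x8192 output the tile count far exceeds the SM count,
# so the grid caps at NUM_SMS and each program loops over multiple tiles.
print(persistent_grid(8192, 8192)({"NUM_SMS": 132, "BLOCK_M": 128, "BLOCK_N": 128}))
# -> (132, 1, 1)
```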
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ch#142045)

# Summary

This PR adds an alternative Triton lowering for `_scaled_mm`. It uses an updated mm template that utilizes persistent scheduling + TMAs on the A and B matrices.

Limitations:
* This implementation does not work with bias values: https://github.com/pytorch/pytorch/blob/0602676c8df2d1f85b28a16ec650fbfa844145ce/torch/_inductor/kernel/mm_scaled.py#L106. The plan is to remove this workaround and enforce that both scaling and bias are properly done as epilogues onto the existing templates.
* The K dim must be 32 or greater for these to take effect.
* Gated by a config flag (currently defaults to off; maybe it should be on).

## Testing

We don't have any tests exercising this code in CI/CD, but I updated the relevant tests in test_fp8 and they are all green:
<img width="1680" alt="Screenshot 2024-12-05 at 7 24 07 PM" src="https://github.com/user-attachments/assets/9c520541-d97a-416f-9af7-e68b366ec90f">

## Follow Ups

* Work to update the base mm Triton templates and utilize the same template from mm/addmm/scaled_mm with the respective epilogues.
* Tuning of persistent kernel configs. I found ones that work for my problem shapes but need to do some more NCU work.

### Some profiling code I was using

Code I am using to iterate with:

```Python
import torch
from dataclasses import dataclass
from jsonargparse import CLI
import logging
from pathlib import Path
from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function
from torchao.float8.inference import (
    addmm_float8_unwrapped_inference,
    preprocess_data,
    Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
    matmul_persistent,
    matmul_tma_persistent,
    matmul_device_tma_persistent,
)
from enum import Enum

logging.getLogger("transformer_nuggets").setLevel(logging.INFO)


class FP8Kernel(Enum):
    PERSISTENT = "Persistent"
    PERSISTENT_TMA = "Persistent-TMA"
    DEVICE_TMA = "Device-TMA"
    SCALED_MM = "Scaled-MM"


class ScalingStrategy(Enum):
    PER_TENSOR = "PerTensor"
    PER_ROW = "PerRow"


@dataclass(frozen=True)
class ExperimentConfig:
    M: int
    K: int
    N: int
    scaling_strategy: ScalingStrategy
    fp8_kernel: FP8Kernel
    compile: bool


def get_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    scaling_strategy: ScalingStrategy,
    fp8_kernel: FP8Kernel,
):
    A_fp8 = A.to(torch.float8_e4m3fn)
    B_fp8 = B.to(torch.float8_e4m3fn)
    A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))
    if scaling_strategy == ScalingStrategy.PER_TENSOR:
        a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    elif scaling_strategy == ScalingStrategy.PER_ROW:
        a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
        b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
    else:
        raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")
    assert fp8_kernel == FP8Kernel.SCALED_MM
    return lambda: addmm_float8_unwrapped_inference(
        A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
    )


def run_matmul(config: ExperimentConfig):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    A = torch.randn(config.M, config.K, device=device, dtype=torch.bfloat16)
    B = torch.randn(config.K, config.N, device=device, dtype=torch.bfloat16)
    fp8_matmul = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)
    if config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM:
        fp8_matmul = torch.compile(fp8_matmul, mode="max-autotune-no-cudagraphs")
    _ = fp8_matmul()
    return


def main():
    torch.random.manual_seed(123)
    # Define your experiment configuration here
    config = ExperimentConfig(
        M=8192,
        K=8192,
        N=8192,
        scaling_strategy=ScalingStrategy.PER_TENSOR,
        fp8_kernel=FP8Kernel.SCALED_MM,
        compile=True,
    )
    run_matmul(config)


if __name__ == "__main__":
    CLI(main)
```

Pull Request resolved: pytorch#142045
Approved by: https://github.com/eellison
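For context, a hedged usage sketch of how this gated lowering would be exercised end to end. The flag name is taken from the follow-up PR notes further down; treat both the flag and the defaults here as assumptions rather than something this thread confirms:

```python
# Hedged sketch: opt in to the persistent+TMA _scaled_mm template and let
# max-autotune choose between it and the other lowering choices.
import torch
import torch._inductor.config as inductor_config

# Assumed gate (named in the related PR notes below); off by default.
inductor_config.triton.enable_persistent_tma_matmul = True

M = K = N = 8192  # K must be >= 32 for the new template to be considered
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
# _scaled_mm wants the second operand column-major, hence the transpose.
B = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).t()
a_scale = torch.tensor(1.0, device="cuda")
b_scale = torch.tensor(1.0, device="cuda")


def scaled(a, b):
    # No bias: the new template does not handle bias (see Limitations above).
    return torch._scaled_mm(a, b, scale_a=a_scale, scale_b=b_scale, out_dtype=torch.bfloat16)


compiled = torch.compile(scaled, mode="max-autotune-no-cudagraphs")
out = compiled(A, B)
```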
…d_mm ghstack-source-id: 84cbb34 Pull Request resolved: pytorch/pytorch#142045
…ed_mm ghstack-source-id: 21d3f28 Pull Request resolved: pytorch/pytorch#142045
…ed_mm ghstack-source-id: 08bb2d5 Pull Request resolved: pytorch/pytorch#142045
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning if (checked by the `use_triton_tma_template` helper):

1. The min. hardware and Triton version requirements are met for the TMA support.
2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).
3. The `config.triton.enable_persistent_tma_matmul` is set to `True`.

Additional notes:

1. As added in this PR, the TMA uses are not compatible with prolog / epilogue fusion. To this end, in the new Triton template we currently support: TMA-based loads of A/B, but no prologue fusion; epilogue fusion, but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.
2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Due to this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous.
3. The transposed layouts of A and / or B are supported by passing the constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on the kernel perf, as decided at the Triton compilation time).
4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in triton-lang/triton#5290) in the new Triton template, which should allow lifting 2 and 3 above.
5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment the config in a follow-up.
6. This PR is rebased onto and unifies with two related PRs landed previously: #142045 (some infra unification with the persistent+TMA template for _scaled_mm) and #134532 (add possibility to disable prolog fusion for selected choices).
7. The current Triton TMA API only supports 1D and 2D descriptors (even after triton-lang/triton#5290, see [here](https://github.com/triton-lang/triton/blob/9829ce87ccb333a2b264b3a80b39a534bfa865ac/python/triton/language/core.py#L1957)). For now, this blocks adding a persistent+TMA template for `torch.bmm`.

Pull Request resolved: #142101
Approved by: https://github.com/drisspg, https://github.com/eellison
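To illustrate the three eligibility conditions listed above, here is a rough stand-in for the kind of check `use_triton_tma_template` performs. This is written for this summary, not Inductor's actual implementation, and the sm90 hardware requirement is an assumption:

```python
import torch

TMA_ALIGNMENT_BYTES = 16  # the Triton TMA API expects 16-byte-aligned, contiguous inputs


def could_use_tma_template(a: torch.Tensor, b: torch.Tensor, flag_enabled: bool) -> bool:
    """Illustrative stand-in for the use_triton_tma_template conditions above."""

    def tensor_ok(t: torch.Tensor) -> bool:
        # Contiguous layout (a transposed-but-contiguous operand is handled
        # via constexpr flags per note 3) and a 16-byte-aligned base pointer.
        layout_ok = t.is_contiguous() or t.t().is_contiguous()
        return layout_ok and t.data_ptr() % TMA_ALIGNMENT_BYTES == 0

    # Assumption: TMA needs Hopper-class (sm90+) hardware and a recent Triton.
    hardware_ok = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
    return flag_enabled and hardware_ok and tensor_ok(a) and tensor_ok(b)
```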
@@ -216,6 +216,19 @@ def filtered_configs(
    else mm_kernel_configs
)

persistent_mm_kernel_configs = [
    {"config": (128, 128, 64, 3, 8), "cond": True},
What's the 4th arg, the one taking the value 3?
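For reference, a hedged reading: assuming the field order matches the existing `mm_kernel_configs` in this file, i.e. `(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)`, the 4th value (3) would be the software-pipelining stage count. A small sketch of how such an entry could map to a `triton.Config` under that assumption:

```python
import triton

# Assumed field order, mirroring the neighbouring mm_kernel_configs:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
persistent_mm_kernel_configs = [
    {"config": (128, 128, 64, 3, 8), "cond": True},
]


def to_triton_config(entry: dict) -> triton.Config:
    block_m, block_n, block_k, num_stages, num_warps = entry["config"]
    return triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
        num_stages=num_stages,  # the value 3 asked about: pipelining depth
        num_warps=num_warps,
    )
```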
Stack from ghstack (oldest at bottom):
Summary
This PR adds an alternative Triton lowering for `_scaled_mm`. It uses an updated mm template that utilizes persistent scheduling + TMAs on the A and B matrices.
Limitations:
This implementation does not work with bias values: pytorch/torch/_inductor/kernel/mm_scaled.py, line 106 in 0602676.
Testing
We don't have any tests exercising this code in CI/CD, but I updated the relevant tests in test_fp8 and they are all green:

Follow Ups
Some profiling code I was using
Code I am using to iterate with (the full script is reproduced in the merge commit message above).
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov