[Inductor] move block pointer analysis to a new module #141733
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141733

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 8c88053 with merge base 0f3f801:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
The CI failure seems unrelated. Will confirm before merging.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #134277 and #142317.

Sub-PRs containing refactors from this one:

- #141733
- #141738
- #141751 (based off the former)
- #142249
- #142020
- #143135

These refactor PRs should land before the main one.

# Feature

*Note: to minimize risk, multi-dimensional reductions are gated by the flag `config.triton.tile_reductions`, which defaults to False.*

Instead of having a single reduction dimension called `"r"`, we can now support 2D reductions with `"r0_"` and `"r1_"` dimensions. 2D reductions generate two nested loops, with different block pointer advancements in each loop body. Most of the implementation is generic to ND reductions, but for now the tiling algorithm sets a hard limit at 2D.

Here's an example of a 2D persistent reduction kernel:

```
@triton.jit
def triton_per_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 15
    R0_BLOCK: tl.constexpr = 16
    r1_numel = 15
    R1_BLOCK: tl.constexpr = 16
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :, None]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    r1_index = tl.arange(0, R1_BLOCK)[None, None, :]
    r1_offset = 0
    r1_mask = r1_index < r1_numel
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    roffset = r1_offset + (r0_offset*r1_numel)
    rindex = r1_index + (r0_index*r1_numel)
    r0_0 = r0_index
    r1_1 = r1_index
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[15, 15], strides=[30, 1], block_shape=[R0_BLOCK, R1_BLOCK], order=[1, 0], offsets=[r0_offset, r1_offset]), boundary_check=[0, 1], padding_option='zero')[None, :, :]
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
    tmp3 = tl.where(r0_mask & r1_mask, tmp1, 0)
    tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])
    tmp5 = tl.sum(tmp4, 1)[:, None, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1, 1], 0, tl.int32)), tmp5, None)
''', device_str='cuda')
```

There are a few main differences between this kernel and what Inductor would generate without this PR:

- Instead of an `r`/`RBLOCK` dimension, we have two reduction dimensions: `r0_`/`R0_BLOCK` and `r1_`/`R1_BLOCK`.
- There are special size and indexing variables for reductions, which don't directly correspond to any kernel dimension (`rindex`, `rnumel`, `RBLOCK`, and `roffset`). These collapse N-D reduction sizes and indices into 1D, which simplifies the codegen for reductions, since they sometimes want to access linear indices instead of N-dimensional ones. Doing things this way allows us to generate N-D loads and stores, but access the data as if it were 1D, minimizing the blast radius of this PR. Although this makes the code more verbose, it shouldn't have a perf impact because the Triton compiler eliminates dead code. (See the sketch below for how the collapsed indices relate to the N-D ones.)
- We generate the line `tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])` before performing the actual reduction. This reshapes the N reduction dimensions into 1D, which allows us to reduce over all N dimensions at once, simplifying the codegen and letting the Triton compiler decide the order of processing under the hood.
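As an aside (not part of the PR), the relationship between the N-D and collapsed 1D reduction indices can be checked with a few lines of NumPy; the sizes below mirror the persistent kernel above:

```python
import numpy as np

# Sizes mirroring the persistent reduction kernel above.
r0_numel, r1_numel = 15, 15

# N-D reduction indices, one broadcastable axis per reduction dimension.
r0_index = np.arange(r0_numel)[:, None]  # shape [r0_numel, 1]
r1_index = np.arange(r1_numel)[None, :]  # shape [1, r1_numel]

# The collapsed 1D index, exactly as computed in the generated kernel.
rindex = r1_index + r0_index * r1_numel  # shape [r0_numel, r1_numel]

# Flattening the index grid yields the contiguous range 0..rnumel-1,
# i.e. the same linearization that tl.reshape applies before tl.sum.
assert (rindex.reshape(-1) == np.arange(r0_numel * r1_numel)).all()
```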
Here's an example of a looped reduction:

```
@triton.jit
def triton_red_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr, R1_BLOCK : tl.constexpr):
    xnumel = 3
    r0_numel = 43
    r1_numel = 129
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :, None]
    r1_base = tl.arange(0, R1_BLOCK)[None, None, :]
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    rbase = r1_base + (r0_base*r1_numel)
    x0 = xindex
    block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[3, 43, 129], strides=[11094, 258, 1], block_shape=[XBLOCK, R0_BLOCK, R1_BLOCK], order=[2, 1, 0], offsets=[xoffset, 0, 0])
    _tmp2 = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        for r1_offset in range(0, r1_numel, R1_BLOCK):
            r1_index = r1_offset + r1_base
            r1_mask = r1_index < r1_numel
            roffset = r1_offset + (r0_offset*r1_numel)
            rindex = r1_index + (r0_index*r1_numel)
            r0_1 = r0_index
            r1_2 = r1_index
            tmp0 = tl.load(block_ptr0, boundary_check=[0, 1, 2], padding_option='zero', eviction_policy='evict_first')
            tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
            tmp3 = _tmp2 + tmp1
            _tmp2 = tl.where(r0_mask & r1_mask & xmask, tmp3, _tmp2)
            block_ptr0 = tl.advance(block_ptr0, [0, 0, R1_BLOCK])
        block_ptr0 = tl.advance(block_ptr0, [0, R0_BLOCK, (-1)*R1_BLOCK*((128 + R1_BLOCK) // R1_BLOCK)])
    tmp4 = tl.reshape(_tmp2, [XBLOCK, RBLOCK])
    tmp2 = tl.sum(tmp4, 1)[:, None, None]
    tl.store(tl.make_block_ptr(out_ptr0, shape=[3], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.reshape(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```

In addition to the aforementioned changes to the persistent reduction, multidimensional looped reductions have a few more lines of code:

- They calculate indices inside the loop using `r0_base` and `r1_base`. For compatibility with existing codegen, these are collapsed to the 1D variant `rbase`.
- Block pointer advancements are more nuanced for multidimensional loops. At the end of each loop body, we emit a `tl.advance` line which not only increments the pointer in its own dimension, but also undoes the cumulative increments of the previous loop level. (In the kernel above, `(128 + R1_BLOCK) // R1_BLOCK` is `ceil(r1_numel / R1_BLOCK)`, the number of inner-loop iterations, so the negative term rewinds the `r1_` axis back to zero; the simulation sketch after this section makes the arithmetic concrete.) This is equivalent to the usual practice in nested loops of starting with a fresh iteration variable at each level. Implementing this required refactoring the way we generate pointer advancements into a new `self.pointer_advancements` field of the kernel, which categorizes advancements by dimension.

The biggest difficulty in implementing this feature was that we represented tiling with a tuple like `(5, 2)`. In the existing codebase, the compiler can infer that the reduction dimension of `(5, 2)` is `2`, since reductions are always the last dimension. This became cumbersome now that we have to support multiple reduction dimensions, so I refactored tiling into a dict like `{"x": 5, "r0_": 2, "r1_": 4}`. This required quite a few code changes, but I don't think it makes the underlying logic much more complex. It will also make it easier to eventually support simultaneous pointwise and reduction tiling, like `{"x": 5, "y": 5, "r0_": 2, "r1_": 4}`. (This is not supported today, but we might want to do it eventually.) The existing tiling algorithm generalized naturally to support reductions.
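The rewind arithmetic in the outer `tl.advance` can be sanity-checked with a small plain-Python simulation (illustrative only, not from the PR; the block sizes are hypothetical picks, while the numels and strides match the looped kernel above):

```python
# Sizes from the looped kernel above; block sizes are hypothetical picks.
r0_numel, r1_numel = 43, 129
R0_BLOCK, R1_BLOCK = 8, 16
r0_stride, r1_stride = 258, 1  # element strides of the r0_ and r1_ axes

offset = 0  # net offset of block_ptr0 relative to its starting position
for r0_offset in range(0, r0_numel, R0_BLOCK):
    for r1_offset in range(0, r1_numel, R1_BLOCK):
        # Before each load, the pointer must sit at (r0_offset, r1_offset).
        assert offset == r0_offset * r0_stride + r1_offset * r1_stride
        # Inner advance: step the r1_ axis by one block.
        offset += R1_BLOCK * r1_stride
    # Outer advance: step the r0_ axis and rewind the r1_ axis to zero.
    # (128 + R1_BLOCK) // R1_BLOCK == ceil(r1_numel / R1_BLOCK) is the number
    # of inner-loop iterations, so the negative term undoes all inner advances.
    offset += R0_BLOCK * r0_stride - R1_BLOCK * ((128 + R1_BLOCK) // R1_BLOCK) * r1_stride
```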
For pointwise kernels, we tile the pointwise dimensions (`"x"`, `"y"`) as is. For reduction kernels, we never tile the `"x"` dimension, and only tile the reduction dimensions (`"r0_"`, `"r1_"`). Thus we only ever tile pointwise OR reduction dimensions, but not both. In principle it seems possible to support both, but it would likely require changes to the kernel fusion and autotuning logic. I thought it best to keep this PR as minimal as possible since it already touched a lot of different files.

Unfortunately, these changes weren't enough to get block pointers in some seemingly simple test cases. In some tests for `argmax` and `var_mean`, we already collapse reduction dimensions into 1D and generate modular indexing expressions prior to tiling, so it's not trivial to figure out how to expand the collapsed reduction dimension back to a shape that would simplify the indexing. To address these cases, this PR adds a new feature to the `config.prefer_nd_tiling` option, which analyzes reads and writes in the kernel, using the same mod-div pattern matching logic that generates block pointers later on. By matching this pattern, we can solve for the tiling splits which *would* simplify the indexing expression, and then use that tiling to eliminate the modular indexing and emit a block pointer. This tiling mode is still off by default, but it's important for certain applications where we need to get as many block pointers as possible.

# Test plan

This touches pretty much anything that uses the Triton and Halide backends, so the existing CI provides good coverage. However, 2D reductions are gated behind a few feature flags like `config.prefer_nd_tiling` and `config.tile_reductions`, so this really only checks that the PR doesn't break 1D reductions.

In addition to existing CI tests, this PR also adds some new tests that specifically stress 2D reductions:

- `test_2d_reduction_odd_shapes`: test 2D reductions with a variety of ops and sizes. This covers the typical persistent and looped reductions.
- `test_2d_reduce_no_x_dim`: test 2D reductions with no x dimension.
- `test_2d_welford_reduction`: test 2D Welford reductions with block pointers.
- `test_welford_non_block_pointer`: test a 2D Welford reduction when block pointer analysis fails.
- `test_reduction_multiple_discontiguous_dims`: test reducing over more than one discontiguous dimension. We won't get a block pointer for this case, since that would require 3D tiling, but we're currently limited to 2D.
- `test_2d_reduction_multi_kernel`: test multi-kernel autotuning on a 2D softmax kernel.
- `test_enable_tiled_reductions`: test that `config.triton.tile_reductions` enables/disables this feature.

Pull Request resolved: #137243
Approved by: https://github.com/jansel
Co-authored-by: Yueming Hao <yhao@meta.com>
Co-authored-by: Jason Ansel <jansel@meta.com>
# Summary
Preparatory refactor for #137243. This refactors the ModularIndexing block pointer analysis into its own module. That way, we can call it from other places besides Triton codegen. In the parent PR, we will use this to find tiling splits that simplify the indexing.
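For intuition, here is a toy, self-contained sketch of the kind of mod-div pattern matching this analysis performs. It uses plain sympy `floor`/`Mod` rather than Inductor's actual `ModularIndexing`/`FloorDiv` ops, and `match_mod_div` is a made-up name rather than the module's API; it recovers the split size that would simplify a flattened index:

```python
import sympy
from sympy import Mod, Symbol, floor

def match_mod_div(index: sympy.Expr, var: sympy.Symbol):
    """Toy matcher: recognize index == s0*floor(var/n) + s1*Mod(var, n)
    and return (n, s0, s1), else None. A sketch, not Inductor's real logic."""
    div_part = mod_part = None
    for term in sympy.Add.make_args(sympy.expand(index)):
        coeff, rest = term.as_coeff_Mul()
        if isinstance(rest, floor):
            div_part = (coeff, rest)
        elif isinstance(rest, Mod):
            mod_part = (coeff, rest)
    if div_part is None or mod_part is None:
        return None
    s1, (base, n) = mod_part[0], mod_part[1].args
    s0, floor_arg = div_part[0], div_part[1].args[0]
    # Both terms must split the same variable at the same size n.
    if base == var and floor_arg == var / n:
        return n, s0, s1
    return None

r = Symbol("r", integer=True, nonnegative=True)
# A flattened access pattern like those left behind when reduction dims are
# collapsed to 1D: index = 258*(r // 129) + (r % 129).
index = 258 * floor(r / 129) + Mod(r, 129)
print(match_mod_div(index, r))  # -> (129, 258, 1): split r at size 129
```

A tiling that splits `r` at 129 makes both terms contiguous strided accesses, which is exactly the condition under which the analysis can emit a block pointer instead of modular indexing.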
# Test plan
Tested by the existing CI.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov