CARVIEW |
Select Language
HTTP/2 200
date: Sat, 19 Jul 2025 15:15:02 GMT
content-type: text/html; charset=utf-8
vary: X-PJAX, X-PJAX-Container, Turbo-Visit, Turbo-Frame, X-Requested-With,Accept-Encoding, Accept, X-Requested-With
x-repository-download: git clone https://github.com/pytorch/pytorch.git
etag: W/"2d221e445332855242c7d805ecdaa2b6"
cache-control: max-age=0, private, must-revalidate
strict-transport-security: max-age=31536000; includeSubdomains; preload
x-frame-options: deny
x-content-type-options: nosniff
x-xss-protection: 0
referrer-policy: no-referrer-when-downgrade
content-security-policy: default-src 'none'; base-uri 'self'; child-src github.githubassets.com github.com/assets-cdn/worker/ github.com/assets/ gist.github.com/assets-cdn/worker/; connect-src 'self' uploads.github.com www.githubstatus.com collector.github.com raw.githubusercontent.com api.github.com github-cloud.s3.amazonaws.com github-production-repository-file-5c1aeb.s3.amazonaws.com github-production-upload-manifest-file-7fdce7.s3.amazonaws.com github-production-user-asset-6210df.s3.amazonaws.com *.rel.tunnels.api.visualstudio.com wss://*.rel.tunnels.api.visualstudio.com objects-origin.githubusercontent.com copilot-proxy.githubusercontent.com proxy.individual.githubcopilot.com proxy.business.githubcopilot.com proxy.enterprise.githubcopilot.com *.actions.githubusercontent.com wss://*.actions.githubusercontent.com productionresultssa0.blob.core.windows.net/ productionresultssa1.blob.core.windows.net/ productionresultssa2.blob.core.windows.net/ productionresultssa3.blob.core.windows.net/ productionresultssa4.blob.core.windows.net/ productionresultssa5.blob.core.windows.net/ productionresultssa6.blob.core.windows.net/ productionresultssa7.blob.core.windows.net/ productionresultssa8.blob.core.windows.net/ productionresultssa9.blob.core.windows.net/ productionresultssa10.blob.core.windows.net/ productionresultssa11.blob.core.windows.net/ productionresultssa12.blob.core.windows.net/ productionresultssa13.blob.core.windows.net/ productionresultssa14.blob.core.windows.net/ productionresultssa15.blob.core.windows.net/ productionresultssa16.blob.core.windows.net/ productionresultssa17.blob.core.windows.net/ productionresultssa18.blob.core.windows.net/ productionresultssa19.blob.core.windows.net/ github-production-repository-image-32fea6.s3.amazonaws.com github-production-release-asset-2e65be.s3.amazonaws.com insights.github.com wss://alive.github.com api.githubcopilot.com api.individual.githubcopilot.com api.business.githubcopilot.com api.enterprise.githubcopilot.com; font-src github.githubassets.com; form-action 'self' github.com gist.github.com copilot-workspace.githubnext.com objects-origin.githubusercontent.com; frame-ancestors 'none'; frame-src viewscreen.githubusercontent.com notebooks.githubusercontent.com; img-src 'self' data: blob: github.githubassets.com media.githubusercontent.com camo.githubusercontent.com identicons.github.com avatars.githubusercontent.com private-avatars.githubusercontent.com github-cloud.s3.amazonaws.com objects.githubusercontent.com release-assets.githubusercontent.com secured-user-images.githubusercontent.com/ user-images.githubusercontent.com/ private-user-images.githubusercontent.com opengraph.githubassets.com copilotprodattachments.blob.core.windows.net/github-production-copilot-attachments/ github-production-user-asset-6210df.s3.amazonaws.com customer-stories-feed.github.com spotlights-feed.github.com objects-origin.githubusercontent.com *.githubusercontent.com; manifest-src 'self'; media-src github.com user-images.githubusercontent.com/ secured-user-images.githubusercontent.com/ private-user-images.githubusercontent.com github-production-user-asset-6210df.s3.amazonaws.com gist.github.com; script-src github.githubassets.com; style-src 'unsafe-inline' github.githubassets.com; upgrade-insecure-requests; worker-src github.githubassets.com github.com/assets-cdn/worker/ github.com/assets/ gist.github.com/assets-cdn/worker/
server: github.com
content-encoding: gzip
accept-ranges: bytes
set-cookie: _gh_sess=jct%2ByfssZ5IkTgY9AwDvI4aaM1aZO3j7f6ES15mNvXqD8figQQPEzI9vWudySv8kPNRc5eTXhUVXVTNiQqo2eY8RUJHj9Neat9rOk2eAAXnyv7LqqvhKaGD0pQk6zRbNYeP%2BkvHqjcIH%2FbXFLq2PTAWZDYlkUj5DvFVSfaEhf9TLVeojnM2axzFAmX8R1nJIq1ndLyVE8dXtHHBxCMpj3jHQPXdjK0lcSGZKElvkp8v9csSzLdITw%2BCqZCfTHNwFs%2B6lgzJdWIDMGUDBD6HD4A%3D%3D--o4pY8FcC35gEyDh%2B--UfUy6gl1LOoa8XkLV7sDTA%3D%3D; Path=/; HttpOnly; Secure; SameSite=Lax
set-cookie: _octo=GH1.1.1749360994.1752938101; Path=/; Domain=github.com; Expires=Sun, 19 Jul 2026 15:15:01 GMT; Secure; SameSite=Lax
set-cookie: logged_in=no; Path=/; Domain=github.com; Expires=Sun, 19 Jul 2026 15:15:01 GMT; HttpOnly; Secure; SameSite=Lax
x-github-request-id: C49C:10A6BA:5F6AB6:778932:687BB675
[Inductor] Remove stride-0 dimensions from more complex block pointer… · pytorch/pytorch@86631ec · GitHub
Copy file name to clipboardExpand all lines: test/inductor/test_torchinductor.py
Copy file name to clipboardExpand all lines: test/inductor/test_torchinductor_strided_blocks.py
Copy file name to clipboardExpand all lines: torch/_inductor/codegen/simd.py
Skip to content
Navigation Menu
{{ message }}
-
Notifications
You must be signed in to change notification settings - Fork 24.7k
Commit 86631ec
[Inductor] Remove stride-0 dimensions from more complex block pointers (#135557)
Related issue: #125077
### Feature
Inductor tries to remove dimensions with stride 0 from block pointers. Rather than loading with stride 0, it's more efficient to load a smaller block pointer, then use `tl.broadcast_to` to broadcast it up to the desired size. This already worked for simpler block pointers, but it was disabled for more complex block pointers which used `tl.reshape` to change the dimensionality after loading.
This PR generalizes the approach to work for all block pointers. The idea is to first reshape, adding singleton dimensions, then broadcast those singletons up to something larger, then reshape again to the final output shape. For readability, we emit this code only if it actually does something. Simpler loads will just have `tl.load`.
Here's an example of a complicated kernel that uses `reshape` -> `load` -> `reshape`. (The first reshape is actually the slice `[None,None,:]`).
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x1 = (xindex // 8)
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tmp2.to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```
Before this PR, we would have stride-0 dimensions:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x1 = (xindex // 8)
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr1, shape=[8, 1, 8], strides=[8, 0, 0], block_shape=[((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))], order=[2, 1, 0], offsets=[(xoffset // 8), 0, xoffset % 8]), boundary_check=[0], eviction_policy='evict_last'), [XBLOCK])
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```
Here's a simpler example where we use 2D tiling. In this case we don't actually need the broadcast. The broadcast is implied via a slice adding a new singleton dimension. This code is not changed by this PR, but it's important to know that we don't accidentally insert unnecessary broadcasts.
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 8
xnumel = 8
yoffset = tl.program_id(1) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x1 = xindex
y0 = yindex
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tmp2.to(tl.float32), boundary_check=[0, 1])
''', device_str='cuda')
```
### Test Plan
Added a new expecttest to check the emitted code for broadcast addition. Looking at the test, we can see that stride 0 dimensions are removed. (This test generated the example kernels in the previous section.)
This change also removed a stride-0 dimension in an existing block pointer test. I updated the expected code accordingly.
Bonus: I noticed that the test parametrization for `config.prefer_nd_tiling` wasn't working as intended. It ended up always setting this option to `True`. Fixed it so we get the intended test coverage.
Pull Request resolved: #135557
Approved by: https://github.com/shunting314, https://github.com/jansel
Co-authored-by: Yueming Hao <yhao@meta.com>1 parent 2c5f5e3 commit 86631ecCopy full SHA for 86631ec
File tree
Expand file treeCollapse file tree
4 files changed
+229
-78
lines changedFilter options
- test/inductor
- torch/_inductor/codegen
Expand file treeCollapse file tree
4 files changed
+229
-78
lines changedtest/inductor/test_torchinductor.py
Copy file name to clipboardExpand all lines: test/inductor/test_torchinductor.py+1-1Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change | |
---|---|---|---|
| |||
12036 | 12036 |
| |
12037 | 12037 |
| |
12038 | 12038 |
| |
12039 |
| - | |
| 12039 | + | |
12040 | 12040 |
| |
12041 | 12041 |
| |
12042 | 12042 |
| |
|
test/inductor/test_torchinductor_strided_blocks.py
Copy file name to clipboardExpand all lines: test/inductor/test_torchinductor_strided_blocks.py+55-3Lines changed: 55 additions & 3 deletions
Original file line number | Diff line number | Diff line change | |
---|---|---|---|
| |||
105 | 105 |
| |
106 | 106 |
| |
107 | 107 |
| |
108 |
| - | |
| 108 | + | |
109 | 109 |
| |
110 | 110 |
| |
111 | 111 |
| |
| |||
176 | 176 |
| |
177 | 177 |
| |
178 | 178 |
| |
179 |
| - | |
| 179 | + | |
180 | 180 |
| |
181 | 181 |
| |
182 | 182 |
| |
| |||
230 | 230 |
| |
231 | 231 |
| |
232 | 232 |
| |
233 |
| - | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
234 | 286 |
| |
235 | 287 |
| |
236 | 288 |
| |
|
torch/_inductor/codegen/simd.py
Copy file name to clipboardExpand all lines: torch/_inductor/codegen/simd.py+15-1Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change | |
---|---|---|---|
| |||
17 | 17 |
| |
18 | 18 |
| |
19 | 19 |
| |
| 20 | + | |
20 | 21 |
| |
21 | 22 |
| |
22 | 23 |
| |
| |||
29 | 30 |
| |
30 | 31 |
| |
31 | 32 |
| |
32 |
| - | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
33 | 39 |
| |
34 | 40 |
| |
35 | 41 |
| |
| |||
41 | 47 |
| |
42 | 48 |
| |
43 | 49 |
| |
| 50 | + | |
44 | 51 |
| |
45 | 52 |
| |
46 | 53 |
| |
| |||
106 | 113 |
| |
107 | 114 |
| |
108 | 115 |
| |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
109 | 123 |
| |
110 | 124 |
| |
111 | 125 |
| |
|
You can’t perform that action at this time.
0 commit comments