[Inductor][CPP] Fix issue in CPP GEMM Template Prune Tensor #141798
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141798

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures) As of commit 8eb1f98 with merge base b7a45db:

FLAKY - The following job failed but was likely due to flakiness present on trunk.
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @jianan-gu

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):

**Summary**

When addressing issue #134998, we verify whether any node in the current graph shares the same storage as the node we intend to prune. The implementation assumed that when the `GraphLowering` is created in the post-grad phase there are no `submodules`, so every `get_attr` node corresponds to a `torch.Tensor`. This assumption breaks when `FlexAttention` is enabled: in that scenario, `submodules` appear as `get_attr` nodes in the post-grad graph. For example:

```
V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs] class sdpa_score30(torch.nn.Module):
V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs]     def forward(self, arg0_1: "bf16[][]cpu", arg1_1: "i32[][]cpu", arg2_1: "i32[][]cpu", arg3_1: "i32[][]cpu", arg4_1: "i32[][]cpu"):
V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs]         return arg0_1
V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1] sdpa_score30 = self.sdpa_score30
V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1] sdpa_mask30 = self.sdpa_mask30
V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1] flex_attention_30 = torch.ops.higher_order.flex_attention(add_276, index_put_60, index_put_61, sdpa_score30, (_frozen_param293, _frozen_param295, _frozen_param296, _frozen_param297, _frozen_param298, _frozen_param299, _frozen_param300, _frozen_param301, 64, 64, sdpa_mask30), 0.08838834764831843, {'SKIP_MASK_SCORE': True, 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': False}, (), (_frozen_param294,)); add_276 = sdpa_score30 = sdpa_mask30 = None
V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1] getitem_60: "bf16[1, 32, 1, 128]" = flex_attention_30[0]; flex_attention_30 = None
```

We added an extra check in the implementation so that only `get_attr` nodes that refer to a `torch.Tensor` are compared (sketched below). It is difficult to reproduce this issue using pure higher-order operators; adding a unit test after #141453 lands would be more straightforward.

Pull Request resolved: pytorch#141798
Approved by: https://github.com/jgong5

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov
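A minimal, hypothetical sketch of the guard described above, assuming the post-grad graph is available as a plain `torch.fx.GraphModule`; the helper names `shares_storage` and `is_prunable` are illustrative and not the actual Inductor functions. The point it shows: before comparing storages during pruning, any `get_attr` target that is not a `torch.Tensor` (such as FlexAttention's `sdpa_score*`/`sdpa_mask*` submodules) is simply skipped.

```python
# Illustrative sketch only -- not the actual Inductor code path. It assumes
# pruning decides aliasing by comparing the untyped storages of `get_attr`
# tensors in the post-grad graph.
import operator

import torch
from torch.fx import GraphModule


def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Two tensors alias each other when they are views over the same untyped storage.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()


def is_prunable(gm: GraphModule, candidate: torch.Tensor) -> bool:
    """Return True if no other `get_attr` tensor in `gm` aliases `candidate`."""
    for node in gm.graph.nodes:
        if node.op != "get_attr":
            continue
        # Resolve dotted targets such as "sub.weight".
        attr = operator.attrgetter(node.target)(gm)
        # The extra check from this PR: FlexAttention's score/mask submodules
        # also appear as `get_attr` nodes, and a submodule has no storage to compare.
        if not isinstance(attr, torch.Tensor):
            continue
        if attr is not candidate and shares_storage(attr, candidate):
            return False
    return True
```

Without the `isinstance` guard, the loop would try to take the storage of a `torch.nn.Module` (e.g. `sdpa_score30`) and fail; with it, submodules are skipped and only real tensors participate in the comparison.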