Refactor optional graph module into CompiledFxGraphConstants #141897
Conversation
đź”— Helpful Links
đź§Ş See artifacts and rendered test results at hud.pytorch.org/pr/141897
Note: Links to docs will display an error until the docs builds have been completed.
âś… You can merge normally! (4 Unrelated Failures) As of commit a6f565d with merge base 920e436:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/functorch/test_aotdispatch.py
Outdated
gm = self.cache.get(key)
if gm is not None:
    gm = make_boxed_func(gm)
return gm, {}

def post_compile(self, gm, inputs, cudagraphs):
This is no longer called after ed's previous refactor
@@ -145,7 +144,7 @@ def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
    gm: Optional[torch.fx.GraphModule],
    constants: Dict[str, torch.Tensor],
this is good
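As a rough sketch of what a call site looks like after this change, the caller can resolve the concrete tensors up front and pass a plain dict instead of threading an `Optional[GraphModule]` through the post-compile helpers. The `unwrap` method name, the wrapper plumbing, and the import path below are assumptions based on the PR description, not the exact code:

```python
from typing import Dict, Sequence

import torch

# Import path assumed; cudagraph_post_compile is the helper modified in the hunk above.
from torch._inductor.compile_fx import cudagraph_post_compile


# Hypothetical call site (names assumed): resolve the constants once via the
# wrapper described in this PR, then hand the plain dict to the helper.
def run_cudagraph_post_compile(
    example_inputs: Sequence[torch.Tensor],
    compiled_graph,        # CompiledFxGraph
    cudagraphs,            # BoxedBool
    constants_wrapper,     # CompiledFxGraphConstants or its gm-aware subclass
) -> None:
    constants: Dict[str, torch.Tensor] = constants_wrapper.unwrap(compiled_graph)
    cudagraph_post_compile(example_inputs, compiled_graph, cudagraphs, constants)
```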
torch/_inductor/output_code.py
Outdated
@@ -252,6 +251,35 @@ def maybe_realign_inputs(
    compiled_graph.current_callable = new_callable


class CompiledFxGraphConstants:
    """Wrapper class that gets constants from a compiled fx graph"""
It would be nice if this parent class explained the subclass inheritance situation and when this one versus the other got used
test/functorch/test_aotdispatch.py
Outdated
@@ -6720,6 +6722,23 @@ def run(f):
    self.assertEqual(out, optout)


@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
So what exactly is this?
oh this is a one off mock ig
I think I'd prefer for this to live in _inductor/output_code.py, as the interface for OutputCode is not settled and likely will change some more as we keep working on it.
torch/_inductor/compile_fx.py
Outdated
config.patch(get_cpp_wrapper_config())
if config.cpp_wrapper
else contextlib.nullcontext()
):
all this reformatting is very annoying lol
None of the comments are blocking; feel free to address them separately.
@@ -204,6 +205,8 @@ def check_cacheable(gm: torch.fx.GraphModule):
        raise BypassAOTAutogradCache(
            "Cannot cache a graph with compiled autograd enabled"
        )
    if torch._inductor.config.freezing:
In the original freezing diff, I checked to see if the gm actually had any frozen params created. Maybe that's a little better? I believe that when the config option is set, freezing is applied unconditionally currently, but maybe there's a future where it's not?
pytorch/torch/_inductor/output_code.py
Line 96 in fe68f61
def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
Thing is, we can't really find that out unless we run AOTAutograd. So at this stage, the best we can do is look at the config.
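For reference, a minimal sketch of what the config-based bypass amounts to, matching the hunk above; the exception class here is a stand-in for the real one AOTAutogradCache uses, and the message is illustrative:

```python
import torch


class BypassAOTAutogradCache(Exception):
    """Stand-in for the bypass exception raised by the real check_cacheable."""


def check_cacheable(gm: torch.fx.GraphModule) -> None:
    # ... existing bypass checks (compiled autograd, etc.) ...
    if torch._inductor.config.freezing:
        # Freezing swaps in constants that only exist after AOTAutograd runs, so
        # at cache-lookup time we cannot know which constants the final graph
        # will use; bail out of AOTAutogradCache rather than cache something we
        # cannot reconstruct.
        raise BypassAOTAutogradCache("Cannot cache a graph with freezing enabled")
```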
@@ -356,7 +361,7 @@ def load(self, example_inputs, fx_config: _CompileFxKwargs) -> CompiledFxGraph:

    # TODO: How come cudagraphs could be None here?
    # TODO: How come gm is None here?
Can we remove these TODOs now?
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…#141897)

FXGraphCache supports freezing, but AOTAutogradCache does not. This is due to the fact that when freezing is turned on, instead of using the constants from the graph module that was saved on cache miss, we have to take the constants from the AOTAutograd generated graph module. This PR does two things:
- It bypasses AOTAutogradCache when freezing is turned on. We should have always been doing this.
- It refactors the code to be way more clear about the constants we're using and when we're using them.

Basically, there are two possible sets of constants we can grab from the compiled fx graph.
1. If freezing is turned off, we save the constants directly in CompiledFxGraph.
2. If freezing is turned on, we save the *names* of the constants in CompiledFxGraph, and use the runtime GraphModule's actual constant values: we reconstruct them from the saved names + the new graph module from AOTDispatch.

We implement two different classes for doing just this: one that has access to the post aotdispatch gm, which supports freezing, and one that doesn't have it, which does not support freezing. Then we construct the wrappers and unwrap the result as needed. This makes it clear that the gm passed to AOTAutogradCache is *not* part of post compile, only the cache key generated from it is.

The whole flow is pretty confusing, but hopefully this gives us better types and static information for understanding what the different codepaths are doing.

Will add a specific AOTAutogradCache test to confirm we bypass freezing.

Pull Request resolved: pytorch#141897
Approved by: https://github.com/ezyang, https://github.com/masnesral
ghstack-source-id: c34bf8a Pull Request resolved: pytorch/pytorch#141897
Hi @jamesjwu @masnesral, thanks for your PR. We recently met an issue, #143144, which seems related to this PR (or to some related changes before it).
This does not seem to be a correct assumption. As an example, in the Inductor lowering phase, we may re-layout some constants, since a different kernel might be chosen by max-autotune, as in:
In this case, we add a new constant to the CompiledFxGraph, but it may not be in the GraphModule (we also delete the original constant, which is no longer used, to save memory). Looking forward to your suggestions on how to resolve this issue. cc @frost-intel @jgong5
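To make the failure mode concrete, here is a small hypothetical illustration (the re-laid-out constant name is made up): a constant that only exists under a new name in the CompiledFxGraph cannot be resolved by name against the runtime GraphModule.

```python
import torch


class _Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight


gm = torch.fx.symbolic_trace(_Mod())  # stands in for the post-AOTDispatch gm

# Suppose max-autotune re-laid out `weight` and registered it in CompiledFxGraph
# under a fresh (hypothetical) name; that name never exists as an attribute on
# the gm, so a name-based lookup against the gm fails.
relayout_name = "_reorder_constant_weight_0"
assert not hasattr(gm, relayout_name)
```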
Stack from ghstack (oldest at bottom):
FXGraphCache supports freezing, but AOTAutogradCache does not. This is due to the fact that when freezing is turned on, instead of using the constants from the graph module that was saved on cache miss, we have to take the constants from the AOTAutograd generated graph module. This PR does two things:
It bypasses AOTAutogradCache when freezing is turned on. We should have always been doing this.
It refactors the code to be way more clear about the constants we're using and when we're using them.
Basically, there are two possible sets of constants we can grab from the compiled fx graph.
1. If freezing is turned off, we save the constants directly in CompiledFxGraph.
2. If freezing is turned on, we save the *names* of the constants in CompiledFxGraph, and use the runtime GraphModule's actual constant values: we reconstruct them from the saved names + the new graph module from AOTDispatch.
We implement two different classes for doing just this: one that has access to the post aotdispatch gm, which supports freezing, and one that doesn't have it, which does not support freezing. Then we construct the wrappers and unwrap the result as needed.
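A rough sketch of the shape these two classes take under that description; the class split mirrors the PR text, but the method and attribute names below (`unwrap`, `constants`, `constant_names`) are approximations rather than the exact code in torch/_inductor/output_code.py:

```python
from typing import Dict

import torch


class CompiledFxGraphConstants:
    """Freezing off: constants were saved directly on the CompiledFxGraph."""

    def unwrap(self, compiled_graph) -> Dict[str, torch.Tensor]:
        return compiled_graph.constants  # attribute name assumed


class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants):
    """Freezing on: only the constant *names* were saved, so reconstruct the
    values from the post-AOTDispatch GraphModule available at runtime."""

    def __init__(self, gm: torch.fx.GraphModule) -> None:
        self.gm = gm

    def unwrap(self, compiled_graph) -> Dict[str, torch.Tensor]:
        # Look each saved name up on the runtime gm.
        return {
            name: getattr(self.gm, name)
            for name in compiled_graph.constant_names  # field name assumed
        }
```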
This makes it clear that the gm passed to AOTAutogradCache is not part of post compile, only the cache key generated from it is.
The whole flow is pretty confusing, but hopefully this gives us better types and static information for understanding what the different codepaths are doing.
Will add a specific AOTAutogradCache test to confirm we bypass freezing.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov