unflatten with specialized graphs per submodule call #137013
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137013
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 964a4f1 with merge base 2b329d3 (FLAKY - The following job failed but was likely due to flakiness present on trunk)
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D63642479
58d66af to 015c00c Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
015c00c to 2c069f4 Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
2c069f4 to 57e6e93 Compare
Summary: Pull Request resolved: pytorch#137013
Test Plan: added test
Differential Revision: D63642479
test/export/test_export.py
Outdated
@@ -5974,6 +5974,166 @@ def forward(self, x):
        self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))

    def test_unflatten_multiple_graphs(self):
One clarifying question: what would module swapping look like from the user side? If they don't make distinctions between specialized graphs, it seems like they'd have to manually switch out n, n@1, and potentially p, p@1 if they're also aliasing? Or if they make distinctions for aliasing/specializations, then some subset of those.
I don't have context for what swapping looks like today with aliasing - is it one or multiple swaps? - but the main point is: does this strictly introduce more work for swapping, and could we introduce some unflattener API for this?
Yeah, they will have to update all f@i variants to the new f. Maybe we should build some convenience APIs to grab these fqn variants.
torch/export/unflatten.py
Outdated
for k, seen_module in self.seen_modules[self.module_id][:-1]:
    num_calls[k] = num_calls.get(k, 0) + 1
    seen_child_fqn = _call_name(k, num_calls[k])
    if _check_graph_equivalence(seen_module, self.module):
Does this mean that differently specialized graphs (e.g. N(x, False), N(x, True)) won't share state? As in, if we do attribute swaps on foo.bar, it won't have the same change for foo.bar@1 if the computation is different.
I think they will share state because the same params / buffer objects have been assigned to all variants. See assign_attr.
The intended effect is that these variants are like different methods on the same "moral" instance.
Oh, I meant if someone were to modify params based on the original FQNs, like foo.bar.attr = foo.bar.attr.bfloat16(), then foo.bar@1.attr won't see the same change? I've debugged some internal FSDP sharding pipelines that do this to modify parameter dtype at runtime.
I think this falls under the same category of convenience APIs though, so no big deal.
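(For concreteness, a small sketch of the distinction being discussed, using hypothetical stand-in modules: because the same parameter object is bound to every variant, in-place mutation is visible everywhere, but rebinding the attribute on one variant is not.)

```python
import torch

# Stand-ins for two call-specialized variants that were assigned the
# *same* Parameter object (as assign_attr does for shared state).
shared = torch.nn.Parameter(torch.ones(2))
bar, bar_at_1 = torch.nn.Module(), torch.nn.Module()
bar.attr = shared
bar_at_1.attr = shared

# In-place mutation goes through the shared object: both variants see it.
with torch.no_grad():
    bar.attr.mul_(2)
assert torch.equal(bar_at_1.attr, torch.full((2,), 2.0))

# Rebinding the attribute (e.g. a dtype change) only affects one variant.
bar.attr = torch.nn.Parameter(bar.attr.detach().bfloat16())
assert bar_at_1.attr.dtype == torch.float32  # bar_at_1 still holds the old param
```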
Yeah, I don't think it will see those changes unless they do it for all variants at the same time, so yes, we need that API. Any suggestions for what that API should look like?
For module swapping we can probably just patch the logic into Angela's _swap_modules() method in unflatten.py? For attributes, probably something similar: def _swap_attributes(ep: ExportedProgram, attrs_to_swap: Dict[str, Union[Any, Callable[[Any], Any]]]):
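A minimal sketch of what such a helper could look like; for simplicity it operates on the unflattened nn.Module rather than the ExportedProgram in the suggested signature, assumes the @-suffixed call-name variants are reachable via named_modules(), and its name, signature, and matching logic are hypothetical rather than an existing API:

```python
from typing import Any, Callable, Dict, Union

import torch


def _swap_attributes(
    unflattened: torch.nn.Module,
    attrs_to_swap: Dict[str, Union[Any, Callable[[Any], Any]]],
) -> None:
    # Apply each swap to the module fqn *and* all of its call-specialized
    # variants (foo, foo@1, foo@2, ...), so every specialization sees the change.
    for fqn, new_value in attrs_to_swap.items():
        module_path, _, attr_name = fqn.rpartition(".")
        for name, submod in unflattened.named_modules():
            # Match the original fqn as well as its "@k" call-name variants.
            if name != module_path and not name.startswith(module_path + "@"):
                continue
            old = getattr(submod, attr_name)
            value = new_value(old) if callable(new_value) else new_value
            # Re-wrap plain tensors as Parameters when replacing a Parameter,
            # since nn.Module.__setattr__ rejects a bare Tensor in that slot.
            if (
                isinstance(old, torch.nn.Parameter)
                and isinstance(value, torch.Tensor)
                and not isinstance(value, torch.nn.Parameter)
            ):
                value = torch.nn.Parameter(value, requires_grad=old.requires_grad)
            setattr(submod, attr_name, value)


# e.g. change the dtype of foo.bar.attr across foo.bar, foo.bar@1, ...:
# _swap_attributes(unflat, {"foo.bar.attr": lambda t: t.bfloat16()})
```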
57e6e93 to 65214dd Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
65214dd to cef6e26 Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
cef6e26 to b335a3b Compare
78ed1fe to 2de7009 Compare
Summary: Pull Request resolved: pytorch#137013
Test Plan: added test
Reviewed By: pianpwk
Differential Revision: D63642479
2de7009 to d1a438f Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
d1a438f to 69b5b90 Compare
69b5b90 to e9c333f Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
e9c333f to 8728f9e Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
8728f9e to f63d2a6 Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
f63d2a6 to dd1f753 Compare
Summary: Pull Request resolved: pytorch#137013
Test Plan: added test
Reviewed By: pianpwk
Differential Revision: D63642479
dd1f753 to 964a4f1 Compare
This pull request was exported from Phabricator. Differential Revision: D63642479
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Previously we were making a fairly restrictive assumption when unflattening an exported program: for any submodule, we would assert that the graph of every call to that submodule must be the same. This assertion is load-bearing, i.e., if we simply remove the assertion then we can get incorrect results, as shown by the following example.
However, this goes against the spirit of specializing graphs when exporting: we should expect that for every call to a submodule we might generate a different graph. The goal of this PR is to fix unflattening to handle multiple specialized graphs corresponding to multiple calls to the same submodule.
The idea is simple: for every call to a child module foo, we will create potentially different child modules foo, foo@1, foo@2, etc., and use those names as targets of call_module instructions in the parent graph. An immediate consequence of this is that the list of fqns in an unflattened module may not be the same as in the exported module. Note that all these variants share the same parameters / buffers, so that multiple calls to the same submodule can share state as expected.

However, as described so far this scheme may end up with needlessly many submodules. Thus, between calls to the same submodule, if graphs are equal then we optimize away the extra submodules and reuse call names as much as possible. Moreover, when submodules are shared across fqns, we also try to de-duplicate graphs corresponding to their calls as much as possible. Note that no matter what, information about which submodule was called is still preserved, so that if a submodule has to be swapped with another, one can still find all calls to the former submodule and replace them with calls to the latter.
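As a rough sketch of the user-side shape of this, following the N(x, True) / N(x, False) pattern discussed in the review; the exact variant names printed (e.g. n@1) are illustrative:

```python
import torch
from torch.export import export, unflatten


class N(torch.nn.Module):
    def forward(self, x, b):
        # The Python branch on b is specialized away at export time, so the
        # two call sites below can trace to different graphs for this submodule.
        if b:
            return x + 1
        return x - 1


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = N()

    def forward(self, x):
        x = self.n(x, True)   # first call: unflattens to target "n"
        x = self.n(x, False)  # second call: may unflatten to a variant like "n@1"
        return x


inp = (torch.randn(3),)
ep = export(M(), inp)
unflat = unflatten(ep)
# The unflattened module may now contain call-specialized variants as
# call_module targets, e.g. ['', 'n', 'n@1'].
print([name for name, _ in unflat.named_modules()])
torch.testing.assert_close(unflat(*inp), M()(*inp))
```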
A note on the choice of naming scheme for call names: instead of generating "sibling" modules foo@1, foo@2, etc. for foo, we had considered generating "children" modules foo._1, foo._2, etc. of foo. However, this can cause spurious cycles when de-duplicating graphs. E.g., suppose that foo is an alias for bar._1 and foo._1 is an alias for bar; then we must either introduce a cycle or drop the opportunity to optimize. Another idea would be to make foo a dummy module that contains foo._0 corresponding to the first call, but this necessitates too many changes to existing tests and hurts the common case.

Differential Revision: D63642479
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o