Fix param and buffer mapping for state_dict when there are state_dict hooks #137609
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137609.
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit 0468fa1 with merge base 93bbc8a. FLAKY - the following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D64080561
torch/export/_trace.py
Outdated
for spec in sig.output_specs:
    if spec.kind in (
        OutputKind.BUFFER_MUTATION,
        OutputKind.GRADIENT_TO_PARAMETER,
    ):
        spec.target = param_buffer_table[spec.target]
        spec.target = state_dict_keys_map.get(spec.target, spec.target)
It looks like for the FQN we store the post-processed state_dict key, not necessarily the path where the attribute exists. From the discussion I'm not 100% sure what the semantics are, and whether this breaks unflattening. Do you know if export() + unflatten() works for the test case you added?
cc: @angelayi
Discussed in chat; summarizing the discussion here:
- we can modify the verifier to match against the state_dict without the hooks
- add a test for unflattening
- add a context manager so that, in export, the module's state_dict hooks are removed. The verifiers then match against the state_dict without hooks; effectively, we ignore state_dict hooks in export.
One caveat is then that for any model with a state_dict hook, one won't be able to interchange between exported_program.state_dict() and mod.state_dict(): exported_program.state_dict() still works with itself, but you can't load the model's state_dict into the exported_program, or vice versa (sketched below).
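A minimal sketch of that caveat, assuming the exported program's keys now come from named_parameters()/named_buffers() (the module and hook below are made up for illustration):

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.wrapper = torch.nn.Linear(2, 2)
        # state_dict hook that strips the "wrapper." prefix from state_dict() keys.
        self._register_state_dict_hook(self._strip_wrapper)

    @staticmethod
    def _strip_wrapper(module, state_dict, prefix, local_metadata):
        for key in list(state_dict.keys()):
            if key.startswith("wrapper."):
                state_dict[key[len("wrapper."):]] = state_dict.pop(key)

    def forward(self, x):
        return self.wrapper(x)

m = M()
ep = torch.export.export(m, (torch.randn(2, 2),))
print(sorted(m.state_dict()))   # hooked keys, e.g. ['bias', 'weight']
print(sorted(ep.state_dict))    # runtime hierarchy, e.g. ['wrapper.bias', 'wrapper.weight']
# The key sets differ, so one can't be loaded directly into the other.
```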
torch/export/_trace.py
Outdated
state_dict_keys_map = {}

if mod._state_dict_hooks or mod._state_dict_pre_hooks:
    state_dict_without_hook = mod.state_dict(keep_vars=True, use_hooks=False)
I think we could add a context manager around this call that removes/puts back the state dict hooks on mod upon enter/exit, that way we don't have to make the nn.Module changes
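Roughly what such a context manager could look like (a sketch of this suggestion, not the exact code that landed; the name is made up and submodule hooks aren't handled):

```python
from contextlib import contextmanager

import torch

@contextmanager
def _disable_state_dict_hooks(mod: torch.nn.Module):
    # Temporarily detach the module's state_dict hooks so mod.state_dict()
    # returns the raw parameter/buffer mapping, then restore them on exit.
    saved_hooks = dict(mod._state_dict_hooks)
    saved_pre_hooks = dict(mod._state_dict_pre_hooks)
    mod._state_dict_hooks.clear()
    mod._state_dict_pre_hooks.clear()
    try:
        yield mod
    finally:
        mod._state_dict_hooks.update(saved_hooks)
        mod._state_dict_pre_hooks.update(saved_pre_hooks)
```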
eda7f9a to e52f3f3
This pull request was exported from Phabricator. Differential Revision: D64080561
e52f3f3 to 22a5116
This pull request was exported from Phabricator. Differential Revision: D64080561
22a5116 to 73783e2
This pull request was exported from Phabricator. Differential Revision: D64080561
torch/export/exported_program.py
Outdated
    state_dict.
    """
    with _disabled_load_state_dict_hooks(mod):
        return mod.state_dict()
Hmm what's the intended usage for this function?
Hmm what's the intended usage for this function?
This is for weight swapping.
At time T1: you're running export, getting an exported program.
At time T2: some serving service serves an artifact from the exported program.
At time T3: a recurring training job just finished and updates the model state that is stored.
At time T4: the serving service is going to pick up the same compiled artifact with the new state that was just updated.
This is used to store the new model state at time T3. Something like (spelled out a bit more below):
ep = export(model)
d = exported_program_state_dict(model)
# update ep's state_dict with d.
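A fuller sketch of the flow (Tiny and the final update step are illustrative, and exported_program_state_dict is the helper being discussed here):

```python
import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

model = Tiny()

# T1: export once; the lowered/serialized artifact is derived from ep.
ep = torch.export.export(model, (torch.randn(2, 2),))

# T3: after retraining, capture the new weights with FQNs that match ep
# (i.e. without state_dict hooks applied).
new_state = exported_program_state_dict(model)

# T4: the serving side swaps the new weights into the already-compiled artifact.
ep.state_dict.update(new_state)
```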
torch/export/exported_program.py
Outdated
@@ -681,6 +681,28 @@ def _decompose_exported_program(
    return exported_program


@contextmanager
Ah, what I had in mind with this was to wrap it around some broad chunk of code in export (maybe _export_func), so that anyone working on export who doesn't know about this issue can just call state_dict(). That way we also don't have to do the manual construction from named_parameters + named_buffers. But I'm happy with the chunk of code in _trace.py.
test/export/test_export.py
Outdated
if new_key != key:
    del state_dict[key]


class CustomModule(torch.nn.Module):
maybe this test case could match what we were seeing before with more modules? Like if a user is trying to remove a layer out of the state dict
maybe this test case could match what we were seeing before with more modules? Like if a user is trying to remove a layer out of the state dict
fixed now.
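For reference, a hedged sketch of the shape such a test module could take (mirroring the "remove a layer from the state_dict" situation; illustrative, not the exact test added here):

```python
import torch

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(4)

    def forward(self, x):
        return self.norm(x)

class Wrapper(torch.nn.Module):
    # Exists in the runtime hierarchy but should be invisible in state_dict().
    def __init__(self):
        super().__init__()
        self.layer = Layer()

    def forward(self, x):
        return self.layer(x)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([Wrapper() for _ in range(2)])
        # Hook rewrites "layers.N.layer.norm.*" -> "layers.N.norm.*".
        self._register_state_dict_hook(self._drop_wrapper)

    @staticmethod
    def _drop_wrapper(module, state_dict, prefix, local_metadata):
        for key in list(state_dict.keys()):
            new_key = key.replace(".layer.", ".")
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

m = Model()
ep = torch.export.export(m, (torch.randn(2, 4),))
# Eager (hooked) keys:    layers.0.norm.weight, ...
# ExportedProgram keys:   layers.0.layer.norm.weight, ...  (runtime hierarchy)
```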
torch/export/exported_program.py
Outdated
mod._state_dict_pre_hooks = state_dict_pre_hooks


def exported_program_state_dict(mod: torch.nn.Module):
let's keep this private for now, and not in this file... maybe in utils?
let's keep this private for now, and not in this file... maybe in utils?
moved to utils now.
73783e2 to 71664be
This pull request was exported from Phabricator. Differential Revision: D64080561
71664be to 9cf1dd3
This pull request was exported from Phabricator. Differential Revision: D64080561
thanks for pushing this through!
torch/_export/utils.py
Outdated
@contextmanager
def _disabled_load_state_dict_hooks(mod: torch.nn.Module):
nit: rename
def _disabled_load_state_dict_hooks(mod: torch.nn.Module):
to
def _disable_load_state_dict_hooks(mod: torch.nn.Module):
torch/_export/utils.py
Outdated
def _exported_program_state_dict(mod: torch.nn.Module):
    """
    Given a model, returns a state_dict that's consistent with exported program's
    state_dict.
    """
    with _disabled_load_state_dict_hooks(mod):
        return mod.state_dict()
nit: I don't think we need this function, we can just tell ppl to directly use the disable hook?
sure, removed now.
9cf1dd3 to b88f5ac
This pull request was exported from Phabricator. Differential Revision: D64080561
b88f5ac to 0468fa1
This pull request was exported from Phabricator. Differential Revision: D64080561
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Resolve #137540
Summary:
We might get different state_dict and named_parameters result when the module has registered custom state_dict_hooks.
For exported_program's state_dict, we want the state_dict to reflect the actual module hierarchy at runtime, and it might be different from the model's state_dict() output if the model has state_dict hooks.
To do weight swapping, one needs to either re-export or turn off the hooks when saving the model's state_dict().
Previously, ExportedProgram used nn.Module's state_dict() method to populate its own state_dict, but that doesn't work for some models (e.g. llama3_3_vision) because ExportedProgram's state_dict and an nn.Module's state_dict have some subtle semantic differences.
nn.Module's state_dict is about how the state should be serialized, and it reflects the structure of the original user model code. In contrast, export specializes on a “run” of a model, and its state_dict needs to reflect the runtime module hierarchy.
One example where these two are different is TorchTune's Llama3_2_vision text decoder. Here, a FusionLayer is added as a local optimization and it is not part of the "static model definition". At runtime, we have mod.layers[3].layer.sa_norm.scale.
But in nn.Module's state_dict, the authors of the model added a state_dict hook to remove the "layer" in mod.state_dict() to reflect the static model definition, so we have mod.state_dict()["layers.3.sa_norm.scale"].
In this Diff, we change ExportedProgram to populate its state_dict using named_parameters() and named_buffers() instead. So in ExportedProgram's state_dict, we have "layers.3.layer.sa_norm.scale", which reflects the runtime module hierarchy.
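A rough sketch of what populating from named_parameters() and named_buffers() means here (the function name is made up; this is not the exact code in this Diff):

```python
import torch

def _state_dict_from_runtime_hierarchy(mod: torch.nn.Module) -> dict:
    # Keys come from the runtime module hierarchy and ignore state_dict hooks,
    # e.g. "layers.3.layer.sa_norm.scale" rather than "layers.3.sa_norm.scale".
    state_dict = {}
    state_dict.update(dict(mod.named_parameters(remove_duplicate=False)))
    state_dict.update(dict(mod.named_buffers(remove_duplicate=False)))
    return state_dict
```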
Now one problem this presents is weight swapping. Since ExportedProgram's state and the model's state are not the same anymore, the weight swapping procedure also needs to change slightly.
In internal Ads and RecSys models deployment, weight swapping is where they have one model that is currently deployed and serving traffic, and they want to swap out the weights with newly trained model weights without having to redo the whole exporting/lowering process and create a new artifact. So they would move the deployed model’s pointer to the state dict over to the new state dict. Because of this, it’s previously a requirement that the FQNs are matching between the exported and the eager model’s state dict.
The new ExportedProgram's state dict still supports weight swapping, but the state_dict to be swapped needs to be obtained from torch.export.exported_program instead of model.state_dict() if the model has state_dict hooks.
The new requirement is that the FQNs match between the exported program's state_dict and the state_dict obtained under the _disabled_load_state_dict_hooks(M) context manager. One benefit of having this new API is that we are now in full control within export of gathering and updating the model state. If a model doesn't have any state_dict hooks, one can still use model.state_dict() for weight swapping, so it's BC.
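For example, to capture a hook-free state_dict for swapping (per the summary above; M is the eager model, the helper is private, and its exact name/location may differ after the review nits above):

```python
from torch._export.utils import _disabled_load_state_dict_hooks

with _disabled_load_state_dict_hooks(M):
    state_dict = M.state_dict()  # FQNs match the ExportedProgram's state_dict
```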
Test Plan:
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_for_training_with_state_dict_hooks

Differential Revision: D64080561