[ONNX] Support from dynamic_shapes to dynamic_axes when torch.onnx.export(fallback=True) is triggered #139532
Conversation
```python
# It does not specify input names if it's a tuple
return dynamic_shapes

sig = _signature(model)
```
Will this call ever raise?
Yes, when dynamic_shapes is a dict.
Technically, dynamic_shapes has exactly the same structure as (args, kwargs): every tensor is replaced by a dictionary of dynamic dimensions. But only a subset of that structure is supported by TorchScript, so I would expect the function to fail if the inputs are not flat, unless the model is wrapped into a module that flattens them. I don't think we should do that for users before calling the fallback, but an error message suggesting it would be great.
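For concreteness, a minimal illustration of that structural correspondence (the model and shapes here are made up, not from this PR):

```python
import torch
from torch.export import Dim


class M(torch.nn.Module):
    def forward(self, x, ys):
        return x + ys[0] + ys[1]


args = (
    torch.randn(4, 3),
    (torch.randn(4, 3), torch.randn(4, 3)),
)

# dynamic_shapes mirrors the (args, kwargs) tree one-to-one: each tensor
# leaf is replaced by a dict mapping axis index -> Dim.
batch = Dim("batch")
dynamic_shapes = (
    {0: batch},                # spec for x
    ({0: batch}, {0: batch}),  # spec for ys, nested like the input tuple
)

ep = torch.export.export(M(), args, dynamic_shapes=dynamic_shapes)
```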
Failing makes sense, and it's easier to implement. I have tried the following example:
```python
import torch
import onnx


class Foo(torch.nn.Module):
    def forward(self, x):
        (a0, a1), (b0, b1), (c0, c1, c2) = x
        return a0 + a1 + b0 + b1 + c0 + c1 + c2


f = Foo()
inputs = (
    (1, 2),
    (torch.randn(4, 4), torch.randn(4, 4)),
    (torch.randn(4, 4), torch.randn(4, 4), torch.randn(4, 4)),
)
input_names = ["a", "b", "c", "d", "e", "f", "g"]
dynamic_axes = {
    "c": {0: "c_dim_0", 1: "c_dim_1"},
    "e": {0: "e_dim_0", 1: "e_dim_1"},
    "f": {0: "f_dim_0", 1: "f_dim_1"},
}
torch.onnx.export(
    f,
    (inputs,),
    "nested.onnx",
    dynamic_axes=dynamic_axes,
    input_names=input_names,
    verbose=True,
)
onnx_model = onnx.load("nested.onnx")
print(onnx_model.graph.input)
```
I think this is the only way nested dynamic_axes works, but it's also awkward that the model's forward input structure is not the same as input_names. Users need to flatten ahead of time in dynamic_axes.
The function has been changed to only work when the input is not nested (len(onnx_inputs) == len(torch_inputs)). When it's nested, users are expected to provide dynamic_shapes with the correct names (matching model.forward), or as a tuple (honestly, I think users would use a tuple when the inputs are nested).
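A minimal sketch of what such a guard could look like, assuming `torch.utils._pytree` for flattening; the helper name and error message are hypothetical, not the actual torch.onnx code:

```python
from __future__ import annotations

from typing import Any, Sequence

import torch.utils._pytree as pytree


def check_flat_inputs(args: tuple[Any, ...], input_names: Sequence[str]) -> None:
    # Hypothetical guard: only attempt the conversion when the torch inputs
    # flatten to exactly one leaf per ONNX input name.
    leaves = pytree.tree_leaves(args)
    if len(leaves) != len(input_names):
        raise ValueError(
            f"Nested inputs detected: {len(leaves)} leaves vs "
            f"{len(input_names)} input names. Provide dynamic_shapes keyed by "
            "the model.forward argument names, or as a tuple."
        )
```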
Could you add this to the function docstring? Just something that makes it clear when this function works and when it doesn't.
Done
Overall, I think there are a few questions that we need to clarify before adding this:

(1) When args is nested, what is our expectation on input_names? For example:

```python
def forward(self, input: tuple[torch.Tensor, torch.Tensor]):  # two elements in the tuple
    ...

# one-to-one mapping to forward:
input_names = ["input"]
# or flattened to map to ONNX graph inputs:
input_names = ["input_a", "input_b"]
```

(2) Given the nested `Foo` example above (flattened `input_names` and `dynamic_axes`): do we want to do this flattening for users? I guess users could hint us by specifying …
TorchScript does actually support dict inputs, but I think it just flattens the dictionary (I may recall incorrectly).
Oh, I was thinking ONNX does not support dict inputs. I will make a change.
Is it true that when a user renames an input, they can provide either the old name (from the forward function) or the new name in dynamic_shapes? What if they conflict?
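For concreteness, the two spellings in question (names made up):

```python
# Suppose forward(self, x) is exported with input_names=["input_x"].
# dynamic_shapes could be keyed by the old forward-argument name...
dynamic_shapes = {"x": {0: "batch"}}
# ...or by the renamed ONNX-facing name:
dynamic_shapes = {"input_x": {0: "batch"}}
# If a user mixes both spellings, which one should win?
```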
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: new commits were pushed while merging. Please rerun the merge command.
I found a bug. Hold the merge for now.
Thanks!
```python
    *,
    dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any],
    input_names: Sequence[str],
) -> dict[str, Any] | tuple[Any] | list[Any]:
```
Could you add a docstring to clarify when this function will or will not work?
Done
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
…port(fallback=True) is triggered (pytorch#139532)

Fixes pytorch#139320

### Summary:

#### (1) Add `_rename_dynamic_shapes_with_model_inputs` for dynamic_shapes to play along with input_names

* Use the model forward signature to rename dynamic_shapes when dynamic_shapes is not nested and directly uses customized names. This solves the issue that torch.export.export expects dynamic_shapes to use only the model input names.
* If dynamic_shapes is nested, we do nothing.

#### (2) Add `_from_dynamic_shapes_to_dynamic_axes` for the fallback

* We flatten dynamic_shapes with leaves defined by `_pytree.tree_leaves()`.
* ~~If dynamic_shapes is not nested and defined in a dict, we can use the keys as input_names, since they should already be renamed by `_rename_dynamic_shapes_with_model_inputs`.~~
* If dynamic_shapes is provided, input_names is required to assign the names, because dynamic_axes needs them.

Pull Request resolved: pytorch#139532
Approved by: https://github.com/justinchuby
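As a rough illustration of the fallback conversion, here is a minimal sketch assuming flat inputs and string dimension names; the function name and error handling are simplified stand-ins, not the actual torch.onnx implementation:

```python
from __future__ import annotations

from typing import Any, Sequence

import torch.utils._pytree as _pytree


def from_dynamic_shapes_to_dynamic_axes(
    dynamic_shapes: tuple[Any, ...] | list[Any],
    input_names: Sequence[str],
) -> dict[str, dict[int, str]]:
    # Treat each per-tensor {axis: dim_name} dict as a leaf instead of
    # flattening into its values.
    leaves = _pytree.tree_leaves(
        dynamic_shapes, is_leaf=lambda x: isinstance(x, dict)
    )
    if len(leaves) != len(input_names):
        raise ValueError(
            "input_names must provide one name per dynamic_shapes leaf"
        )
    dynamic_axes: dict[str, dict[int, str]] = {}
    for name, axes in zip(input_names, leaves):
        if axes is None:
            continue  # static input: nothing to record
        dynamic_axes[name] = dict(axes)
    return dynamic_axes


# Example: two flat tensor inputs, the first with a dynamic batch axis.
axes = from_dynamic_shapes_to_dynamic_axes(
    ({0: "batch"}, None), input_names=["x", "y"]
)
assert axes == {"x": {0: "batch"}}
```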