[Dy2St] Accept `dataclass` as input argument #73205
Your PR was submitted successfully. Thank you for contributing to this open-source project!
Pull Request Overview
Adds support for using Python `dataclass` instances as inputs to Paddle's dynamic-to-static (`dy2static`) API by implementing custom hashing logic and verifying it with a new unit test.
- Introduce `_is_dataclass_instance` and extend `make_hashable` to handle dataclass objects.
- Add import of `is_dataclass` from the standard library.
- Add a new test to verify that a dataclass wrapping a `Tensor` can be passed through `paddle.jit.to_static`.
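The instance check described in the first bullet can be sketched as follows. This is an illustration only: the PR's actual `_is_dataclass_instance` body is not shown here, so the helper below uses a hypothetical name (`is_dataclass_instance`) and follows the idiom from the standard-library `dataclasses` documentation. `Sample` is a made-up class.

```python
from dataclasses import dataclass, is_dataclass


def is_dataclass_instance(obj):
    # is_dataclass() returns True for dataclass types AND their instances,
    # so classes have to be excluded explicitly.
    return is_dataclass(obj) and not isinstance(obj, type)


@dataclass
class Sample:  # hypothetical input container, not part of the PR
    value: int


print(is_dataclass_instance(Sample(1)))      # an instance
print(is_dataclass_instance(Sample))         # the class itself
print(is_dataclass_instance({"value": 1}))   # an ordinary dict
```

Only the first call reports `True`; the class object and the plain dict are rejected.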
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
test/dygraph_to_static/test_dataclass_as_input.py | New unit test to ensure dataclass input is accepted and matches output |
python/paddle/jit/dy2static/utils.py | Import `is_dataclass`; add detection and hashing logic for dataclass instances |
python/paddle/jit/dy2static/utils.py
Outdated
    return tuple(
        map(
            make_hashable,
            [
                getattr(x, field_name)
                for field_name in x.__dataclass_fields__
            ],
        )
Two different dataclass types with identical field values will produce the same hash, which may cause unintended cache collisions. Consider including the dataclass's type (e.g. `type(x)` or its qualified name) as part of the returned tuple to distinguish between types.
Suggested change:

-    return tuple(
-        map(
-            make_hashable,
-            [
-                getattr(x, field_name)
-                for field_name in x.__dataclass_fields__
-            ],
-        )
+    return (
+        type(x).__name__,
+        tuple(
+            map(
+                make_hashable,
+                [
+                    getattr(x, field_name)
+                    for field_name in x.__dataclass_fields__
+                ],
+            )
+        ),
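The collision described in this comment can be reproduced in plain Python. The sketch below is not Paddle code: `Point`, `Size`, `key_without_type`, and `key_with_type` are hypothetical names chosen for illustration. It uses the type object itself rather than `type(x).__name__`, because two classes with the same name in different modules would still share a name-based key.

```python
from dataclasses import dataclass, fields


@dataclass
class Point:  # hypothetical dataclass
    x: int
    y: int


@dataclass
class Size:  # hypothetical dataclass with the same field layout
    x: int
    y: int


def key_without_type(obj):
    # Key built from field values only, as in the original snippet.
    return tuple(getattr(obj, f.name) for f in fields(obj))


def key_with_type(obj):
    # Key that also carries the class, so distinct types never compare equal.
    return (type(obj), key_without_type(obj))


p, s = Point(1, 2), Size(1, 2)
print(key_without_type(p) == key_without_type(s))  # values alone collide
print(key_with_type(p) == key_with_type(s))        # type keeps them apart
```

Class objects are hashable, so the typed key can still be used in a dict or set.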
python/paddle/jit/dy2static/utils.py
Outdated
            make_hashable,
            [
                getattr(x, field_name)
                for field_name in x.__dataclass_fields__
Relying on the internal `__dataclass_fields__` dict risks future incompatibilities and implicitly depends on dict ordering. Use `dataclasses.fields(x)` to get a stable, documented list of fields in definition order.
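A minimal sketch of the `dataclasses.fields` approach, assuming a simplified stand-in for `make_hashable`. The helper `to_hashable` and the class `Config` are hypothetical; the real Paddle function handles more input kinds and is not reproduced here.

```python
from dataclasses import dataclass, fields, is_dataclass


def to_hashable(x):
    """Hypothetical, simplified stand-in for a make_hashable-style helper."""
    if is_dataclass(x) and not isinstance(x, type):
        # fields() returns Field objects in definition order.
        values = tuple(to_hashable(getattr(x, f.name)) for f in fields(x))
        return (type(x), values)
    if isinstance(x, (list, tuple)):
        # Lists are unhashable, so convert sequences to tuples recursively.
        return tuple(to_hashable(v) for v in x)
    return x


@dataclass
class Config:  # hypothetical dataclass with an unhashable list field
    name: str
    shape: list


key = to_hashable(Config("a", [2, 3]))
print([f.name for f in fields(Config)])
print(hash(key) == hash(to_hashable(Config("a", [2, 3]))))
```

Equal instances produce equal keys, which is the property a cache lookup needs. Note that this sketch maps a list and a tuple with the same elements to the same key.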
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             develop    #73205   +/-   ##
===========================================
  Coverage           ?   100.00%
===========================================
  Files              ?         1
  Lines              ?         5
  Branches           ?         0
===========================================
  Hits               ?         5
  Misses             ?         0
  Partials           ?         0
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
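Previously, when AST dynamic-to-static computed a hash from the inputs, it hashed each input in turn. A `dataclass` instance is itself `unhashable`, so its hashing has to be handled separately.

SOT does not have this problem, but it has others: it cannot recognize the data inside the dataclass and holds on to the internal Tensor, which carries a risk of OOM.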
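The unhashability mentioned above is standard Python behavior: a dataclass declared with the default `eq=True` and without `frozen=True` has its `__hash__` set to `None`. The sketch below shows this outside Paddle; `Mutable`, `Frozen`, and `is_hashable` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class Mutable:  # default options: eq=True, frozen=False
    value: int


@dataclass(frozen=True)
class Frozen:  # eq=True and frozen=True generate a __hash__
    value: int


def is_hashable(obj):
    # hash() raises TypeError when the object's class sets __hash__ to None.
    try:
        hash(obj)
    except TypeError:
        return False
    return True


print(is_hashable(Mutable(1)))
print(is_hashable(Frozen(1)))
```

An input cache that calls `hash` on every argument would therefore fail on an instance like `Mutable(1)` unless the instance is first converted to a hashable key.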