CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[Prim][CINN] Use cpp flag store prim states to ensure consistency #71837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Prim][CINN] Use cpp flag store prim states to ensure consistency #71837
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR addresses a bug fix in the prim flag handling for CINN by unifying the management of prim states using C++ flags to ensure consistency. Key changes include updating internal API functions (set_prim* functions) to accept an optional print_flag parameter, replacing usages of the legacy check_and_set_prim_all_enabled with a new __check_and_set_prim_all_enabled, and adjusting test cases (including subprocess tests and deprecated tests) to reflect these refactorings.
Reviewed Changes
Copilot reviewed 5 out of 13 changed files in this pull request and generated no comments.
Show a summary per file
File | Description |
---|---|
test/prim/pir_prim/test_pir_prim_flags.py | Added comprehensive subprocess tests to verify prim flag states based on environment variables and API calls. |
python/paddle/base/core.py | Refactored internal prim flag functions, adding a new __check_and_set_prim_all_enabled and modifying API signatures to include an optional print_flag parameter. |
python/paddle/jit/dy2static/utils.py | Commented out calls to the legacy check_and_set_prim_all_enabled to avoid unintended side effects. |
test/deprecated/prim/prim/flags/test_prim_flags_case_deprecated.py | Updated deprecated tests to use the new __check_and_set_prim_all_enabled function. |
test/deprecated/prim/prim/flags/test_prim_flags_deprecated.py (second file) | Reduced redundant prim flag tests and focused on eager prim enabling. |
Files not reviewed (8)
- paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_util.cc: Language not supported
- paddle/common/flags.cc: Language not supported
- paddle/fluid/prim/utils/static/static_global_utils.cc: Language not supported
- paddle/fluid/prim/utils/static/static_global_utils.h: Language not supported
- paddle/fluid/prim/utils/utils.cc: Language not supported
- paddle/fluid/prim/utils/utils.h: Language not supported
- paddle/fluid/pybind/pybind.cc: Language not supported
- test/prim/pir_prim/CMakeLists.txt: Language not supported
Comments suppressed due to low confidence (3)
test/deprecated/prim/prim/flags/test_prim_flags_case_deprecated.py:57
- Using a private function '__check_and_set_prim_all_enabled' directly in tests may lead to brittleness; consider using a public API or introducing a dedicated testing wrapper for flag synchronization.
core.__check_and_set_prim_all_enabled()
python/paddle/jit/dy2static/utils.py:679
- Ensure that deprecating 'check_and_set_prim_all_enabled' is fully reflected in the API usage; if removal is intended, verify that commenting out this call does not impact the overall prim flag consistency.
# core.check_and_set_prim_all_enabled(True)
python/paddle/base/core.py:587
- In '__check_and_set_prim_all_enabled', the sequential handling of FLAGS_prim_all then FLAGS_prim_forward and FLAGS_prim_backward may lead to conflicting flag states if both group and individual flags are set; please confirm that this ordering is the intended behavior.
prim_bwd_env = os.getenv("FLAGS_prim_backward")
40e173b
to
802dc5a
Compare
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
prim 相关 flag 的数据统一使用 C++ flag,避免数据不同步导致出问题
FLAGS_prim_all
,这是冗余数据,它应该只是FLAGS_prim_forward && FLAGS_prim_backward
,使用时应该重新计算,而不是考虑什么时候怎么同步之类的,不过只是这个 C++ FLAG 被移除,环境变量仍然生效StaticCompositeContext::enable_fwd_prim_
和StaticCompositeContext::enable_bwd_prim_
,使用 flag 可以确保 CINN 那边获取的是一致的FLAGS_prim_all
向 C++ flag 同步一次,确保FLAGS_prim_all
环境变量生效,为了能够支持FLAGS_prim_forward=False && FLAGS_prim_all=True
这种 case,同步时也需要把FLAGS_prim_forward
和FLAGS_prim_backward
同步一下,但,只应该同步这一次,其他地方不允许使用这个 API(__check_and_set_prim_all_enabled
),否则将会导致通过 API 设置的效果失效!本 PR 暂时禁掉了如下两个单测,之前因为机制问题导致对比的两者跑了相同模式,进而隐藏了的精度问题,本 PR 暂时禁用,不阻塞合入
test/ir/pir/cinn/test_cinn_group_norm.py
之前跑的都是拆解过的test/ir/pir/cinn/symbolic/test_sub_graph_chatglm2_5_st.py
之前跑的都是未拆解过的Ecosystem updates
Related links
CINN 切默认系列 PR
build_strategy.build_cinn_pass
and replace it withbackend
option #71815FLAGS_use_cinn
viaos.getenv
#71817