【CINN】Enable AutoLayoutPass flag in train process #71891
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
@@ -822,7 +822,14 @@ def pass_fn(forward_program, backward_program, program_name_attr):
# TODO(xiongkun) who to transfer the pruning program?
infer_program = self.origin_runnable_program.clone()
if auto_layout_is_enabled():
# TODO(liujinnan) When CINN can perfectly handle Layout conversion, remove the judgment of whether to enable CINN.
if auto_layout_is_enabled() and not cinn_is_enabled(
Note that `cinn_is_enabled` will no longer be used here; use `self._backend.is_cinn()` directly.
Done
paddle.set_flags(
{"FLAGS_cudnn_batchnorm_spatial_persistent": True}
)
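Does this flag affect runtime or a compile-time pass? Could `paddle.base.framework.flag_guard` be used to limit its effect to a local scope? Otherwise the effect will leak beyond dynamic-to-static conversion.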
It affects runtime. A guard has been added.
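The scoping concern in this thread can be illustrated with a minimal sketch. It does not use Paddle: `FLAGS` is a hypothetical in-memory stand-in for a process-wide flag registry, and the `flag_guard` below is an illustrative re-implementation of the pattern, not `paddle.base.framework.flag_guard` itself.

```python
from contextlib import contextmanager

NAME = "FLAGS_cudnn_batchnorm_spatial_persistent"

# Hypothetical stand-in for a process-wide flag registry (not Paddle's).
FLAGS = {NAME: False}


@contextmanager
def flag_guard(name, value):
    """Set FLAGS[name] inside the with-block, then restore the old value."""
    original = FLAGS[name]
    FLAGS[name] = value
    try:
        yield
    finally:
        # Restore even if the guarded body raises.
        FLAGS[name] = original


def run_guarded():
    # Stand-in for the guarded run_program call: report the flag seen inside.
    with flag_guard(NAME, True):
        return FLAGS[NAME]


inside = run_guarded()
outside = FLAGS[NAME]
print(inside, outside)
```

With a bare `set_flags`-style assignment the new value would stay in effect for everything that runs afterwards, whereas the guard confines it to the `with` block.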
bn_flag = os.getenv(
"FLAGS_cudnn_batchnorm_spatial_persistent", True
)
Use `paddle.get_flags("FLAGS_cudnn_batchnorm_spatial_persistent")["FLAGS_cudnn_batchnorm_spatial_persistent"]`. Using `os.getenv` to read a C++ flag is not recommended; see #71817 for reference.
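One concrete reason to be careful with `os.getenv` here, sketched in plain Python: it returns a string when the variable is set, and any non-empty string is truthy. `parse_bool_env` is a hypothetical helper written for this illustration, and the spellings it accepts are an assumption; Paddle's own flag parsing may differ.

```python
import os

NAME = "FLAGS_cudnn_batchnorm_spatial_persistent"


def parse_bool_env(name, default):
    """Hypothetical helper: interpret an environment variable as a boolean."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    # Assumed set of "true" spellings, for illustration only.
    return raw.strip().lower() in ("1", "true", "yes", "on")


os.environ[NAME] = "false"
raw_value = os.getenv(NAME, True)    # the string "false", not a bool
naive = bool(raw_value)              # True, because the string is non-empty
parsed = parse_bool_env(NAME, True)  # False

del os.environ[NAME]
fallback = os.getenv(NAME, True)     # the default object is returned unchanged

print(repr(raw_value), naive, parsed, fallback)
```

Passing the raw `os.getenv` result on as a flag value would therefore treat a user's explicit "false" as enabled unless it is parsed first.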
Wait, isn't this `if` wrong 😂? If the condition is not hit, `run_program` doesn't run at all?
As I understand it, it should be:
original_bn_flag = paddle.get_flags("FLAGS_cudnn_batchnorm_spatial_persistent")["FLAGS_cudnn_batchnorm_spatial_persistent"]
bn_flag = True if auto_layout_is_enabled() and not self._backend.is_cinn() else original_bn_flag
with flag_guard("FLAGS_cudnn_batchnorm_spatial_persistent", bn_flag):
_C_ops.run_program(...)
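The shape of this suggestion (decide the value first, then always run inside the guard) can be sketched without Paddle. Everything here is a stand-in: `FLAGS` is a hypothetical dict, `unittest.mock.patch.dict` plays the role of `flag_guard`, and `run_program` simply returns the flag value it observes instead of calling `_C_ops.run_program`.

```python
from unittest.mock import patch

NAME = "FLAGS_cudnn_batchnorm_spatial_persistent"

# Hypothetical stand-in for the global flag registry.
FLAGS = {NAME: False}


def resolve_bn_flag(auto_layout_enabled, backend_is_cinn, original):
    """Force True only for auto layout without CINN; otherwise keep the original."""
    if auto_layout_enabled and not backend_is_cinn:
        return True
    return original


def run_program(auto_layout_enabled, backend_is_cinn):
    bn_flag = resolve_bn_flag(auto_layout_enabled, backend_is_cinn, FLAGS[NAME])
    # The guarded body runs in every branch; only the flag value differs.
    with patch.dict(FLAGS, {NAME: bn_flag}):
        return FLAGS[NAME]


seen = {
    "layout, no cinn": run_program(True, False),
    "layout, cinn": run_program(True, True),
    "no layout": run_program(False, False),
}
print(seen, FLAGS[NAME])
```

Because the condition only selects a value, the program is executed in all three cases, which is the behavior the original `if` failed to preserve.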
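After pushing, I also noticed that the `if` was indeed wrong, and I have fixed it 😂. `original_bn_flag` also needs to distinguish between a value set by the user and the default value: if the user explicitly specifies `false`, the value inside the `guard` should match the user's expectation and also be `false`. That is why `os.getenv` was used, to tell whether the value came from the user.
- The newer implementation therefore follows the user's setting when there is one; otherwise, if `layout` is enabled and `cinn` is not, it sets the flag to `True`; in all other cases it uses `paddle.get_flags` to obtain the default value.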
user_bn_flag = os.getenv(
"FLAGS_cudnn_batchnorm_spatial_persistent", None
)
if (
user_bn_flag is None
and auto_layout_is_enabled()
and not self._backend.is_cinn()
):
bn_flag = True
else:
bn_flag = (
user_bn_flag
if user_bn_flag is not None
else paddle.get_flags(
"FLAGS_cudnn_batchnorm_spatial_persistent"
)["FLAGS_cudnn_batchnorm_spatial_persistent"]
)
with paddle.base.framework.flag_guard(
"FLAGS_cudnn_batchnorm_spatial_persistent", bn_flag
):
- The code above feels redundant, overly complex, and inelegant, so I tried the following approach instead:
guard_creators = []
if auto_layout_is_enabled() and not self._backend.is_cinn():
# AutoLayoutPass may change layout of bn to NHWC, if not enable `FLAGS_cudnn_batchnorm_spatial_persistent`, it will revert to NCHW. So if the user does not set this Flag, we set it to True.
guard_creators.append(
lambda: paddle.base.framework.flag_guard(
"FLAGS_cudnn_batchnorm_spatial_persistent",
os.getenv("FLAGS_cudnn_batchnorm_spatial_persistent", True),
)
)
with compose_guards(*guard_creators)():
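For readers unfamiliar with the `compose_guards(*guard_creators)()` call shape used above, here is a minimal sketch of how such a helper could be built on `contextlib.ExitStack`. This is an assumption about the pattern, not Paddle's actual implementation; `traced` and `auto_layout_on` are hypothetical names used only to make the ordering visible.

```python
from contextlib import ExitStack, contextmanager


def compose_guards(*guard_creators):
    """Return a zero-argument callable that yields one combined context manager."""

    @contextmanager
    def composed():
        with ExitStack() as stack:
            # Enter each guard in order; ExitStack exits them in reverse order.
            for creator in guard_creators:
                stack.enter_context(creator())
            yield

    return composed


events = []


@contextmanager
def traced(name):
    """Hypothetical guard that records when it is entered and exited."""
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


guard_creators = []
auto_layout_on = True  # hypothetical condition standing in for the real check
if auto_layout_on:
    guard_creators.append(lambda: traced("bn_flag"))
guard_creators.append(lambda: traced("other"))

with compose_guards(*guard_creators)():
    events.append("run")

# With no creators the composed guard is a no-op, so the body still runs.
with compose_guards()():
    events.append("run without guards")

print(events)
```

Storing creators (lambdas) rather than already-built guards means each guard is constructed only when the `with` block is entered, and an empty list still yields a valid context manager.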
thx
Sorry to inform you that 8da3223's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
/re-run sot
/re-run all-failed
@@ -745,6 +745,36 @@ def is_api_in_module_helper(obj, module_prefix):
return m is not None and m.__name__.startswith(module_prefix)
def auto_layout_guard(backend, guard_creators):
Why is this called `guard`? Perhaps it should be called `add_auto_layout_guard`?
thx, this will be fixed in the next PR.
@contextmanager
def train_guards(backend):
`run_program_op` runs in both train and eval modes, so is `train_guards` really an appropriate name? Perhaps it should be called `runtime_guard`?
thx, this will be fixed in the next PR.
* enable flag
* use guard
* fix bug
* polish code
* add paddle.is_compiled_with_cuda()
* remove error env flag set
* fix invalid use of getenv
* enable flag in cinn
* polish code
* add infer flag
* polish code
PR Category
CINN
PR Types
New features
Description
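- Controlled by `enable_auto_layout_pass_in_inference`; once all tests pass, this flag will be removed and control will be unified with training under the `enable_auto_layout_pass` flag.
- `FLAGS_cudnn_batchnorm_spatial_persistent`: to avoid affecting the original logic, a guard must be added.
- `FLAGS_cudnn_batchnorm_spatial_persistent`, but users can still manually set it to false to turn this flag off.

Pcard-67164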