[PIR][oneDNN] Add fc_onednn_enable_pass #63518
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
Force-pushed from ed9db35 to c6b1790
@@ -95,7 +95,7 @@
   extra_args : str mkldnn_data_type="float32"

 - op : fc
-  extra_args : bool ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE=true, bool use_quantizer=false, str mkldnn_data_type="float32", float scale_in=1.0, float[] scale_weights={1.0f}, float scale_out=1.0, bool force_fp32_output=false
+  extra_args : bool ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE=true, bool use_quantizer=false, str mkldnn_data_type="float32", float scale_in=1.0, float[] scale_weights={1.0f}, float scale_out=1.0, bool force_fp32_output=false, str fuse_activation = "", float fuse_alpha = 0.0, float fuse_beta = 0.0, float fused_output_scale = 1.0f, int[] fused_reshape2_shape = {}
The ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE attribute is not used under PIR; please remove it both here in the yaml and in the pass.
OK, I saw this attr already existed before my change, so I didn't touch it~
Also, since fc has been completely moved out of the fluid ops, does op_compat.yaml still need this attr at all?
{"in_num_col_dims", pat.Attr("in_num_col_dims")},
{"activation_type", pat.Attr("activation_type")},
{"padding_weights", pat.Attr("padding_weights")},
{"ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE", res.BoolAttr(true)},
Remove this ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE as well.
@@ -635,6 +635,7 @@ const std::vector<std::string> kPirMkldnnPasses{
     "matmul_transpose_reshape_fuse_pass",
     "matmul_elementwise_add_fuse_pass",
     "matmul_activation_fuse_pass",
+    "fc_onednn_enable_pass",
Shouldn't fc_fuse_pass also be added to kPirMkldnnPasses?
I'd like to ask about the order in which passes are applied here. With the onednn backend, are only the passes in kPirMkldnnPasses run? If both the cpu and onednn passes are run, there's no need to add it; if only the onednn passes are run, then it indeed needs to be added~
> I'd like to ask about the order in which passes are applied here. With the onednn backend, are only the passes in kPirMkldnnPasses run? If both the cpu and onednn passes are run, there's no need to add it; if only the onednn passes are run, then it indeed needs to be added~

I just checked the code, and they are executed separately~ But in that case, shouldn't the mkldnn pass list also include some of the cpu passes, such as the conv_bn-related ones? In the fluid era, enabling onednn on cpu would run the passes from both sides, so back then there was no need to add passes like fc_fuse_pass (unless they had to be executed twice).
Right, cpu and mkldnn are separate, so different pass lists need to be maintained.
std::string name() const override { return "FcOneDNNEnablePattern"; }

uint32_t benefit() const override { return benefit_; }
Since there is only one Pattern, the default implementation of the benefit interface is sufficient, so there is no need to pass a redundant benefit object through the constructor. Also, fc_name_ and fused_fc_name can be hard-coded in the pattern rather than passed in through the constructor, since as far as I can see their values are fixed.
I tried reducing these variables here before, but it raised an issue that required me to explicitly declare this class. I wasn't sure whether that would have other side effects, so I left it as is.
> I tried reducing these variables here before, but it raised an issue that required me to explicitly declare this class. I wasn't sure whether that would have other side effects, so I left it as is.

I've just changed it and it no longer reports an error~
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
  auto act_type = match_ctx.Attr<std::string>("activation_type");
  if (!(act_type == "" || act_type == "relu")) return false;
  return true;
});
I don't quite understand this constraint, could you explain it?
The fc op has two activation-related attrs. One is activation_type, which pd_op.fc also has and which can only be "relu" or empty; the other is onednn.fc's attr fuse_activation, which is used to fuse other activations.
Force-pushed from c6b1790 to e68bcba
Can we wait until #61925 is merged before pushing this PR forward? We have renamed fc_fuse_pass to matmul_add_act_fuse_pass.
Force-pushed from e68bcba to 1d5df8d
#63649 Please refer to this PR's changes and resolve the conflicts, thanks~
Force-pushed from 3097ca6 to 3bad205
@yuanlehome Hi yuanle, could you help take a look at this CINN CI? All PRs have been failing it since last night... it shouldn't be an issue with this PR itself, thx~
PR Category
Others
PR Types
New features
Description
Based on the new pass mechanism, this PR adds the pass "fc_onednn_enable_pass" for PIR.
The new pass is equivalent to "fc_mkldnn_pass" in /paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc.
Note:
Since multiple attributes with vector types may exist, we changed op_gen.py slightly to avoid duplicate definitions.