CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[pir]Adding FusionOp and CinnFusionLoweringPass #60769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pir]Adding FusionOp and CinnFusionLoweringPass #60769
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
7e0aca7
to
077af2f
Compare
ded95b3
to
b445b18
Compare
auto group = RebuildGroup(fusion_op); | ||
// Because the group is rebuilt, the order of group.output_values generated | ||
// by BuildCUDAJITInfo may not be same with the order bound in the yield op, | ||
// so a mapping is required. | ||
std::unordered_map<::pir::Value, size_t> value2id; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的output可以直接做到按位对齐吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的output可以直接做到按位对齐吗?
要对齐的话就得记录下旧group里面output_ops和ops的顺序,感觉不太好搞
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Group 内 op list 是有序的,理论上输出的value也可以按照在 op list 中出现的顺序来排列,如果这样处理的话会有什么问题?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Group 内 op list 是有序的,理论上输出的value也可以按照在 op list 中出现的顺序来排列,如果这样处理的话会有什么问题?
这里我再看下,理论上是可以的
argument.AddAttribute( | ||
"op_pattern_kind", | ||
pir::Int32Attribute::get(pir::IrContext::Instance(), op_pattern_kind)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_pattern_kind 这个信息是不是必要的?可以通过op list推出来吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_pattern_kind 这个信息是不是必要的?可以通过op list推出来吗
这个可以通过op list推出来,但这样就得重复计算一遍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里重复计算的开销估计不大,我们可以优先考虑设计的合理性
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里重复计算的开销估计不大,我们可以优先考虑设计的合理性
好滴
// FusionOp represents a subgraphs that can be fused. Every GroupOp | ||
// can be lowered to at least one FusionOp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
解释里再强调下一个FusionOp对应会生成一个Kernel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
解释里再强调下一个FusionOp对应会生成一个Kernel
好滴
@@ -0,0 +1,196 @@ | |||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2023->2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2023->2024
好滴
std::shared_ptr<pir::ShapeConstraintIRAnalysis> shape_analysis_{nullptr}; | ||
}; | ||
|
||
class CinnFusionLoweringPass : public pir::PatternRewritePass { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉叫LowerCinnFusionOpPass
更匹配些
原来的CINNGroupLoweringPass
名字应该也需要调整下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉叫
LowerCinnFusionOpPass
更匹配些 原来的CINNGroupLoweringPass
名字应该也需要调整下
好滴
if (FLAGS_cinn_enable_map_expr) { | ||
cinn::adt::TryGenerateMapExprFromGroup(group); | ||
} | ||
|
||
auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里加个TODO吧,用一个新的Group数据结构替换现有的这个Group
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里加个TODO吧,用一个新的Group数据结构替换现有的这个Group
好滴
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
void FusionOp::Build(pir::Builder& builder, | ||
pir::OperationArgument& argument, | ||
const std::vector<pir::Type>& output_types, | ||
const int op_pattern_kind) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_pattern_kind
这个参数没有使用了,可以去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_pattern_kind
这个参数没有使用了,可以去掉
好滴
@@ -0,0 +1,29 @@ | |||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2023->2024
@@ -289,9 +300,9 @@ class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> { | |||
std::shared_ptr<pir::ShapeConstraintIRAnalysis> shape_analysis_{nullptr}; | |||
}; | |||
|
|||
class CinnGroupLoweringPass : public pir::PatternRewritePass { | |||
class LowerCinnGroupOpPass : public pir::PatternRewritePass { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里叫lower稍微有点不直观,这个pass主要是在做op group的融合,看是不是可以用 merge
或者 fuse
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里叫lower稍微有点不直观,这个pass主要是在做op group的融合,看是不是可以用
merge
或者fuse
?
好滴,我改成MergeCinnGroupOpPass
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件名与pass命名上要对应
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件名与pass命名上要对应
好滴
*/ | ||
|
||
std::vector<pir::Value> GetBlockOutsideOutput( | ||
const std::vector<pir::Operation*> op_list) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const std::vector<pir::Operation*> op_list) { | |
const std::vector<pir::Operation*>& op_list) { |
for (size_t i = 0; i < group->output_values.size(); ++i) { | ||
vec_types.push_back(group->output_values[i].type()); | ||
auto fusion_op = rewriter.Build<cinn::dialect::FusionOp>( | ||
output_types, group->op_pattern_kind); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FusionOp 的构建还需要 op_pattern_kind信息么?我看其定义并没有存这个信息。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FusionOp 的构建还需要 op_pattern_kind信息么?我看其定义并没有存这个信息。
是的,这里已经改了,还没push上来~
@@ -289,11 +300,11 @@ class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> { | |||
std::shared_ptr<pir::ShapeConstraintIRAnalysis> shape_analysis_{nullptr}; | |||
}; | |||
|
|||
class CinnGroupLoweringPass : public pir::PatternRewritePass { | |||
class MergeCinnGroupOpPass : public pir::PatternRewritePass { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从逻辑上来看,此Pass的效果是将一个GroupOp替换为了FusionOp,这里的Pass名称是否可以调整下以更容易理解。比如CinnGroupToFusionPass ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从逻辑上来看,此Pass的效果是将一个GroupOp替换为了FusionOp,这里的Pass名称是否可以调整下以更容易理解。比如CinnGroupToFusionPass ?
好滴,这里我再斟酌下
using cinn::hlir::framework::pir::CompatibleInfo; | ||
|
||
std::vector<pir::Value> GetBlockOutsideInput( | ||
const std::vector<pir::Operation*> op_list) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const std::vector<pir::Operation*> op_list) { | |
const std::vector<pir::Operation*>& op_list) { |
f904762
to
54cab46
Compare
pir::Block* fusion_block = fusion_op.block(); | ||
|
||
for (auto op : group->ops) { | ||
op->MoveTo(fusion_block, fusion_block->end()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move没法处理recompute的情形
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move没法处理recompute的情形
好滴
argument.output_types = output_types; | ||
} | ||
|
||
pir::Block* FusionOp::block() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
严格分离读写。
const pir::Block* FusionOp::block() const;
pir::Block* FusionOp::mut_block() const;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
严格分离读写。
const pir::Block* FusionOp::block() const; pir::Block* FusionOp::mut_block() const;
好滴
return ®ion.front(); | ||
} | ||
|
||
std::vector<pir::Operation*> FusionOp::ops() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
小写的方法名一般只用于getter/setter。如果其中会产生大结构体,就应该写成“动词+名词”形式的普通函数名,比如GetOperators()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
小写的方法名一般只用于getter/setter。如果其中会产生大结构体,就应该写成“动词+名词”形式的普通函数名,比如GetOperators()
好滴
pir::Block* fusion_block = fusion_op.block(); | ||
|
||
for (auto op : group->ops) { | ||
op->MoveTo(fusion_block, fusion_block->end()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前Group融合时,recompute处理后会出现同一个op同时被多个group持有的情况,move的话确实会有问题,这里可以先使用group的clone
54cab46
to
42f521b
Compare
4ec1b09
to
4526757
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
Others
Description
In the previous code, CinnGroupLoweringPass took on too many tasks, including dividing the GroupOp into subgraphs that can be fused, merging the subgraphs, and constructing the JitKernelOp. This PR moves the merging of subgraphs and constructing the JitKernelOp to CinnFusionLoweringPass for future development.