[CINN] enable eliminate common global memory read #65804
Your PR has been submitted successfully. Thank you for your contribution to this open-source project!
auto add = expr.As<ir::Add>();
ir::Expr lhs_expr = add->a();
ir::Expr rhs_expr = add->b();
int lhs_common_factor = ExtractMulNumberFromExpr(lhs_expr);
int rhs_common_factor = ExtractMulNumberFromExpr(rhs_expr);
return cinn::common::AutoSimplify(
    ir::Add::Make(Simplify(lhs_expr, ir::Expr(lhs_common_factor)),
                  Simplify(rhs_expr, ir::Expr(rhs_common_factor))));
Does any unit test reach this branch? There is an obvious correctness problem here.
Removed this kind of simplify; some constraints are now added when converting global to local.
if (IsSymbolicNotEqual(expr1, expr2)) {
  return cinn::common::AutoSimplify(ir::Add::Make(
      ExtractSymbolicFromExpr(expr1), ExtractSymbolicFromExpr(expr2)));
}
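Can you give a case that was actually encountered here?
It looks like the expression computed here does not help with simplification. For example, if the two exprs are S0 and S1, Calculate returns "S0 + S1", and after Simplify the two expressions become "-S1" and "-S0".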
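Yes. The plan is to adopt a more conservative strategy here: the index is handled only when the symbols in the different schedule blocks are all equal. If the two exprs for the same tensor are S0 and S1, no Simplify is performed.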
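The conservative rule described in this reply can be illustrated with a minimal sketch. It is not the PR's code: symbols are modeled as plain strings, and `AllSymbolsEqual` is a hypothetical helper name chosen for illustration.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper (not part of CINN): each entry is the single symbol,
// e.g. "S0", that one schedule block uses to index the same tensor.
// Returns true only when every block uses the same symbol, which is the
// only case the conservative strategy would simplify.
bool AllSymbolsEqual(const std::vector<std::string>& symbols_per_block) {
  if (symbols_per_block.empty()) return false;  // nothing to simplify
  for (const std::string& symbol : symbols_per_block) {
    if (symbol != symbols_per_block.front()) return false;
  }
  return true;
}
```

Under this sketch, a tensor indexed by S0 in one block and S1 in another is left untouched, which avoids producing results such as "-S1" and "-S0".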
if (utils::StartsWith(var->name, "S")) {
  return ir::ir_utils::IRCopy(expr);
}
Var has a member variable, is_symbolic_constant, that identifies symbolic constants. Remember to open a separate PR to fix the logic here.
Done, thanks.
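To show why the reviewer prefers the flag over the name prefix, here is a small sketch. `VarSketch` is a hypothetical stand-in that models only a name and the `is_symbolic_constant` member mentioned in the review; it is not CINN's actual variable class.

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified stand-in for a variable node; only the two
// fields needed for the comparison are modeled.
struct VarSketch {
  std::string name;
  bool is_symbolic_constant;
};

// Fragile check: any variable whose name happens to start with "S"
// is treated as symbolic.
bool IsSymbolicByName(const VarSketch& var) {
  return var.name.rfind("S", 0) == 0;  // true when name begins with "S"
}

// Check suggested in the review: rely on the explicit flag.
bool IsSymbolicByFlag(const VarSketch& var) {
  return var.is_symbolic_constant;
}
```

A variable with an unrelated name that merely begins with "S" is misclassified by the prefix check but not by the flag.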
Polish code for common global var
from paddle.base import core


class GroupNormSubGraph(paddle.nn.Layer):
Keep the class name consistent with the unit test name.
Done, thanks.
};

void TransformLocalIndicesToIters(ir::Expr* expr) {
  TransformLocalIndicesVisitor transformLocalIndicesVisitor;
transformLocalIndicesVisitor
For local variables we generally use lowercase letters separated by underscores.
Done, thanks.
@@ -356,10 +418,117 @@ void EliminateCommonFactorHelper(ir::Expr* expr) {
  eliminate_common_factor_visitor(expr);
}

class TransformLocalIndicesVisitor : public ir::IRMutator<> {
 public:
  TransformLocalIndicesVisitor() {}
The default constructor should not need to be written out; the compiler will generate it.
Done, thanks.
void CopyIndiceItersToLocalBuffer(
    const std::map<std::string, ir::Expr>& name_to_iter,
    std::vector<ir::Expr>* local_buffer_iters) {
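We generally avoid passing pointers as function parameters, because to some degree that implies an in-place operation. This function looks like it could return local_buffer_iters instead.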
After the change the logic is fairly simple, so it is no longer a separate function. Done.
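For readers unfamiliar with the convention raised in this thread, the sketch below shows the return-by-value shape the reviewer describes. It is an illustration only: `ExprSketch` is a hypothetical string stand-in for `ir::Expr`, `CollectIters` is an invented name, and the output order (sorted by map key) is an assumption, since the original function body is not shown in the review.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for ir::Expr so the sketch stays self-contained.
using ExprSketch = std::string;

// Returns the collected iterators instead of writing through an output
// pointer, so the signature no longer suggests an in-place update.
// Assumption: results follow the map's key order.
std::vector<ExprSketch> CollectIters(
    const std::map<std::string, ExprSketch>& name_to_iter) {
  std::vector<ExprSketch> iters;
  iters.reserve(name_to_iter.size());
  for (const auto& name_and_iter : name_to_iter) {
    iters.push_back(name_and_iter.second);
  }
  return iters;  // moved or elided by the compiler, no extra copy
}
```

Call sites then read as `auto iters = CollectIters(name_to_iter);`, which makes the data flow explicit.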
void CopyIndiceItersToLocalBuffer(
    const std::map<std::string, ir::Expr>& name_to_iter,
    std::vector<ir::Expr>* local_buffer_iters) {
  std::map<std::size_t, ir::Expr> name_helper;
Is std::map necessary here? It looks like std::vector would be enough.
Done, thanks.
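The point about `std::map<std::size_t, ...>` versus `std::vector` can be sketched as follows, under the assumption that the keys are unique, dense positions starting at 0 (the review does not state this explicitly). `ExprSketch` and `OrderByPosition` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ir::Expr.
using ExprSketch = std::string;

// Places each (position, expr) pair directly into its vector slot.
// Assumption: positions are unique and cover [0, entries.size()), so the
// vector index can play the role of the map key.
std::vector<ExprSketch> OrderByPosition(
    const std::vector<std::pair<std::size_t, ExprSketch>>& entries) {
  std::vector<ExprSketch> ordered(entries.size());
  for (const auto& entry : entries) {
    // at() throws std::out_of_range if the density assumption is violated.
    ordered.at(entry.first) = entry.second;
  }
  return ordered;
}
```

If the positions were sparse, an ordered map would still be the simpler choice; the vector only wins when the index range is contiguous.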
if (expr.As<ir::Mod>()) {
  auto mod = expr.As<ir::Mod>();
  return cinn::common::AutoSimplify(
      ir::Mod::Make(Simplify(mod->a(), factor), mod->b()));
} else if (expr.As<ir::Div>()) {
  auto div = expr.As<ir::Div>();
  return cinn::common::AutoSimplify(
      ir::Div::Make(Simplify(div->a(), factor), div->b()));
}
With the current approach, does gcd still need to support simplifying mod and div? Which cases would run into this?
It is indeed a redundant operation. Done, thanks.
LGTM
* [CINN] enable eliminate common global memory read
* update common factor of local
* update recursive simplify
* update local index simplify
* fix code style
* add detailed log
* fix symbolic simplify
* fix add expr simplify
* add constrant for expr can not simplify
* add transform from index to iter var
* polish code for common global var
* fix codestyle
* fix codestyle
* enable transform
* add test for pass
* add constraint for select
* polish code
* update test name, fix code
* add local buffer size limit
PR Category
Performance Optimization
PR Types
Others
Description
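This pass replaces global tensors that are read with identical indices with a local tensor.
Because index simplification runs into complicated problems, only the following cases are conservatively supported for now (each additional kind of simplification can be enabled once it is supported, and the constraints can eventually be removed):
1. A single constant or a single variable.
2. Expressions that contain only addition or only multiplication, and whose leaf nodes are each one of the forms in (1).
Fixed the local tensor index problem caused by symbols by adopting a conservative strategy: simplification is applied only when the index of a local tensor in every schedule block is a single symbol.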
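The supported index forms listed above can be sketched as a small predicate. This is one reading of the constraints, not the pass's actual implementation: the `Node` type, the `Kind` enum, and all function names are hypothetical, and "only addition or only multiplication" is interpreted as a tree built from a single operator kind over constant or variable leaves.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical miniature expression tree, used only for this sketch.
enum class Kind { kConstant, kVariable, kAdd, kMul, kOther };

struct Node {
  Kind kind;
  std::vector<std::shared_ptr<Node>> operands;
};

// Case 1: a single constant or a single variable.
bool IsLeaf(const Node& node) {
  return node.kind == Kind::kConstant || node.kind == Kind::kVariable;
}

// True when every node under `node` is either `op` or a case-1 leaf.
bool OnlyOp(const Node& node, Kind op) {
  if (IsLeaf(node)) return true;
  if (node.kind != op) return false;
  for (const auto& child : node.operands) {
    if (!OnlyOp(*child, op)) return false;
  }
  return true;
}

// Case 1, or case 2: a pure-addition or pure-multiplication tree.
bool IsSupportedIndex(const Node& node) {
  return IsLeaf(node) || OnlyOp(node, Kind::kAdd) || OnlyOp(node, Kind::kMul);
}

// Convenience constructor for building small trees.
std::shared_ptr<Node> Make(Kind kind,
                           std::vector<std::shared_ptr<Node>> operands = {}) {
  return std::make_shared<Node>(Node{kind, std::move(operands)});
}
```

Under this reading, `i + 1` and `i * 4` are accepted, while a mixed form such as `(i + 1) * 4` or any expression containing another operator is rejected and stays on the global path.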
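Replacing local expression indices with the corresponding loop variables
The current approach copies the indices over from the global indices before the transformation, for example
From a mapping point of view, if dst is a local buffer, it does not need to follow the original global buffer's indices exactly and can be transformed into
This form is helpful for later analysis and simplification.
Expression forms that are not yet supported will continue to be computed the global way: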