[CINN] support forOp with vectorize #68918
Your PR was submitted successfully. Thank you for contributing to this open-source project!
d423a95
to
e3c4d0d
Compare
9e8e008
to
2fd754c
Compare
3a35b71
to
cc97942
Compare
* float4 temp_1
* float4 temp_2 = b[i * 4 + j]
* temp_1[0] = temp_2[0]
* temp_1[1] = temp_2[1]
* temp_1[2] = temp_2[2]
* temp_1[3] = temp_2[3]
* temp_0_ptr[0] = temp_1
Why can't this be written directly as temp_0_ptr[0] = temp_2?
temp_1[0] = temp_2[0]
temp_1[1] = temp_2[1]
temp_1[2] = temp_2[2]
temp_1[3] = temp_2[3]
corresponds to:
temp_1.x = temp_2.x
temp_1.y = temp_2.y
temp_1.z = temp_2.z
temp_1.w = temp_2.w
This example is not a good one. It should be changed to an operation, such as add, that has no vector form:
temp_1.x = temp_2.x + temp_3.x
temp_1.y = temp_2.y + temp_3.y
temp_1.z = temp_2.z + temp_3.z
temp_1.w = temp_2.w + temp_3.w
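The per-lane add described in this reply can be sketched in plain C++. This is only an illustration under stated assumptions: `Float4` is a hypothetical stand-in for CUDA's `float4`, and `AddVectorized` is an invented name, not code generated by CINN. Loads and stores move four floats at once, while the add itself is written lane by lane.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical stand-in for CUDA's float4; not a CINN type.
struct Float4 {
  float x, y, z, w;
};

// Computes a[i] = b[i] + c[i], four elements per iteration.
// Loads and stores move a whole Float4; the add is done lane by lane
// because Float4 has no vector operator+.
void AddVectorized(const float* b, const float* c, float* a, std::size_t n) {
  assert(n % 4 == 0 && "sketch assumes the length is a multiple of 4");
  for (std::size_t i = 0; i < n; i += 4) {
    Float4 temp_1, temp_2, temp_3;
    std::memcpy(&temp_2, b + i, sizeof(Float4));  // vector load
    std::memcpy(&temp_3, c + i, sizeof(Float4));  // vector load
    temp_1.x = temp_2.x + temp_3.x;
    temp_1.y = temp_2.y + temp_3.y;
    temp_1.z = temp_2.z + temp_3.z;
    temp_1.w = temp_2.w + temp_3.w;
    std::memcpy(a + i, &temp_1, sizeof(Float4));  // vector store
  }
}
```

A plain copy needs no per-lane step at all, which is why the reviewer's suggestion of `temp_0_ptr[0] = temp_2` is valid for the original example.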
#include <stack>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/adt/map_expr.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
Are some of these headers unused? For example, stack and map_expr.
Done
for (auto var : indices) {
  std::unordered_set<std::string> index_symbols = CollectIndexSymbols(&var);
  if (index_symbols.count(loop_var_->name)) return false;
}
If var is not the current loop_var_ but is still an iteration variable, can it be counted as a Scalar?
The vectorize axis is normally the innermost loop. If var is not the current loop_var_, the index it produces inside that loop is a scalar value.
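The check under discussion can be sketched on a toy expression tree: an access counts as scalar when none of its index expressions mentions the vectorized loop variable, even if they mention outer iteration variables. `IndexExpr`, `Var`, `Const`, `BinaryOp`, `CollectSymbols` and `IsScalarAccess` are hypothetical names for this illustration and are not CINN's `ir::Expr` API.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical miniature index expression: a variable leaf, a constant leaf,
// or an operation over sub-expressions. Not CINN's ir::Expr.
struct IndexExpr {
  std::string var_name;  // non-empty only for variable leaves
  std::vector<std::shared_ptr<IndexExpr>> operands;
};
using IndexExprPtr = std::shared_ptr<IndexExpr>;

IndexExprPtr Var(const std::string& name) {
  auto e = std::make_shared<IndexExpr>();
  e->var_name = name;
  return e;
}

IndexExprPtr Const() { return std::make_shared<IndexExpr>(); }

IndexExprPtr BinaryOp(IndexExprPtr lhs, IndexExprPtr rhs) {
  auto e = std::make_shared<IndexExpr>();
  e->operands.push_back(std::move(lhs));
  e->operands.push_back(std::move(rhs));
  return e;
}

// Gathers the names of all variables that appear in an index expression.
void CollectSymbols(const IndexExprPtr& e,
                    std::unordered_set<std::string>* out) {
  if (!e->var_name.empty()) out->insert(e->var_name);
  for (const auto& operand : e->operands) CollectSymbols(operand, out);
}

// An access is scalar with respect to loop_var when no index mentions it.
// Outer iteration variables do not change within the vectorized loop.
bool IsScalarAccess(const std::vector<IndexExprPtr>& indices,
                    const std::string& loop_var) {
  for (const auto& index : indices) {
    std::unordered_set<std::string> symbols;
    CollectSymbols(index, &symbols);
    if (symbols.count(loop_var)) return false;
  }
  return true;
}
```

With `j` as the vectorized variable, an index such as `i * 4 + j` is not scalar, while an index that uses only the outer variable `i` is.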
void Visit(const ir::For *op, ir::Expr *expr) override {
  auto *forloop = expr->As<ir::For>();
  if (op->is_vectorized()) {
Where is the op's vectorized information set?
It is controlled through the schedule configuration: sch->vectorize(loops[vectorize_axis], vectorize_factor).
@@ -386,6 +387,8 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
    // 4.Apply low level pass
    if (i != func_bodies.size() - 1) {
      func = optim::Optimize(func, target_, false);
      optim::VectorizeForTrans(&(func->body));
      optim::Simplify(&(func->body));
Does Simplify need to be called here?
Yes, it is needed.
For example, after the VectorizeForTrans pass, the unrolled code can look like this:
vectorized_var_1[(0 + 0)] = ((1.00000000f * vectorized_var[(0 + 0)]) + 1.00000000f)
After Simplify it becomes:
vectorized_var_1[0] = ((1.00000000f * vectorized_var[0]) + 1.00000000f)
Codegen cannot handle the form that has not been simplified.
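The effect on the index can be illustrated with a toy constant folder. This is a minimal sketch, not CINN's `optim::Simplify`; `Node`, `MakeConst`, `MakeVar`, `MakeAdd`, `Fold` and `ToString` are hypothetical names, and only integer addition is handled.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical miniature expression node; not a CINN IR type.
struct Node {
  enum class Kind { kConst, kVar, kAdd };
  Kind kind = Kind::kConst;
  int value = 0;                   // used by kConst
  std::string name;                // used by kVar
  std::shared_ptr<Node> lhs, rhs;  // used by kAdd
};
using NodePtr = std::shared_ptr<Node>;

NodePtr MakeConst(int v) {
  auto n = std::make_shared<Node>();
  n->kind = Node::Kind::kConst;
  n->value = v;
  return n;
}

NodePtr MakeVar(const std::string& name) {
  auto n = std::make_shared<Node>();
  n->kind = Node::Kind::kVar;
  n->name = name;
  return n;
}

NodePtr MakeAdd(NodePtr lhs, NodePtr rhs) {
  auto n = std::make_shared<Node>();
  n->kind = Node::Kind::kAdd;
  n->lhs = lhs;
  n->rhs = rhs;
  return n;
}

// Folds constant additions and drops additions of zero, bottom-up.
NodePtr Fold(const NodePtr& n) {
  if (n->kind != Node::Kind::kAdd) return n;
  NodePtr lhs = Fold(n->lhs);
  NodePtr rhs = Fold(n->rhs);
  const bool lhs_const = lhs->kind == Node::Kind::kConst;
  const bool rhs_const = rhs->kind == Node::Kind::kConst;
  if (lhs_const && rhs_const) return MakeConst(lhs->value + rhs->value);
  if (lhs_const && lhs->value == 0) return rhs;
  if (rhs_const && rhs->value == 0) return lhs;
  return MakeAdd(lhs, rhs);
}

// Prints additions with explicit parentheses, as in the IR dump above.
std::string ToString(const NodePtr& n) {
  switch (n->kind) {
    case Node::Kind::kConst:
      return std::to_string(n->value);
    case Node::Kind::kVar:
      return n->name;
    case Node::Kind::kAdd:
      return "(" + ToString(n->lhs) + " + " + ToString(n->rhs) + ")";
  }
  return "";
}
```

Folding `(0 + 0)` yields the literal `0`, which is the kind of constant lane index the reply says codegen expects.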
@@ -386,6 +387,8 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
    // 4.Apply low level pass
    if (i != func_bodies.size() - 1) {
      func = optim::Optimize(func, target_, false);
      optim::VectorizeForTrans(&(func->body));
Is it more appropriate to put Vectorize in Schedule or in Optimize?
Putting it in Optimize is more appropriate.
Done
4ff76fe
to
636ce18
Compare
std::unordered_set<std::string> CollectIndexSymbols(Expr *x) {
  struct Mutator : public ir::IRMutator<Expr *> {
    void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
    void Visit(const ir::_Var_ *op, Expr *expr) override {
      auto *node = expr->As<ir::_Var_>();
      PADDLE_ENFORCE_NOT_NULL(node,
                              ::common::errors::InvalidArgument(
                                  "Sorry, but the node expr is nullptr"));
      symbols_.insert(op->name);
    }

    std::unordered_set<std::string> GetSymbols() { return symbols_; }

   private:
    std::unordered_set<std::string> symbols_;
  };

  Mutator mutator;
  mutator(x);
  return std::move(mutator.GetSymbols());
}
ir::ir_utils::CollectIRNodes should work here, so there is no need to write your own Visitor.
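CollectIRNodes only returns the matching exprs, so I would still need an extra for loop over them to extract each var's name string. Both approaches inherit from IRVisitorRequireReImpl, so using CollectIRNodes would probably be more roundabout.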
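The trade-off in this exchange can be shown on a toy tree. The sketch below does not use CINN's real `CollectIRNodes` signature: `Expr`, `CollectNodes`, `SymbolsViaCollector` and `SymbolsViaVisitor` are hypothetical names. A generic predicate-based collector returns nodes and needs a second loop to pull out names, whereas a dedicated traversal records names directly.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical miniature expression tree; not CINN's ir::Expr.
struct Expr {
  std::string var_name;  // non-empty only for variable leaves
  std::vector<std::shared_ptr<Expr>> children;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr MakeVar(const std::string& name) {
  auto e = std::make_shared<Expr>();
  e->var_name = name;
  return e;
}

ExprPtr MakeOp(std::vector<ExprPtr> children) {
  auto e = std::make_shared<Expr>();
  e->children = std::move(children);
  return e;
}

// Generic collector: returns every node for which pred is true.
std::vector<const Expr*> CollectNodes(
    const Expr& root, const std::function<bool(const Expr&)>& pred) {
  std::vector<const Expr*> found;
  std::function<void(const Expr&)> walk = [&](const Expr& e) {
    if (pred(e)) found.push_back(&e);
    for (const auto& child : e.children) walk(*child);
  };
  walk(root);
  return found;
}

// Option A: generic collection, then an extra loop to extract the names.
std::unordered_set<std::string> SymbolsViaCollector(const Expr& root) {
  std::unordered_set<std::string> names;
  auto is_var = [](const Expr& e) { return !e.var_name.empty(); };
  for (const Expr* node : CollectNodes(root, is_var)) {
    names.insert(node->var_name);
  }
  return names;
}

// Option B: a dedicated traversal that records names as it goes.
void VisitSymbols(const Expr& e, std::unordered_set<std::string>* names) {
  if (!e.var_name.empty()) names->insert(e.var_name);
  for (const auto& child : e.children) VisitSymbols(*child, names);
}

std::unordered_set<std::string> SymbolsViaVisitor(const Expr& root) {
  std::unordered_set<std::string> names;
  VisitSymbols(root, &names);
  return names;
}
```

Both options produce the same set of names. The choice comes down to reusing a shared utility versus skipping the intermediate list of nodes.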
c7fdbb9
to
ea54de4
Compare
ea54de4
to
3c9a165
Compare
PR Category
CINN
PR Types
Improvements
Description
Pcard-88155
support forOp vectorize