CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
add backtrack for GetShapeOrDataForValue #63367
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -14,6 +14,7 @@ | |||
|
|||
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" | |||
#include <string> | |||
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle/pir
目录里的文件不能引用fluid
的东西,这里得想办法处理下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已将该接口移动至pir下
24be6f6
to
1ad324e
Compare
@@ -35,7 +35,9 @@ class IR_API ShapeConstraintIRAnalysis { | |||
|
|||
bool HasShapeOrDataForValue(Value val) const; | |||
|
|||
const symbol::ShapeOrDataDimExprs& GetShapeOrDataForValue(Value val) const; | |||
void GetAndSetShapeOrDataForValueFromDefiningOp(Value val); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数命名是不可以再优化下,比如GetOrInferShapeOrDataForValue
之类的命名?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -37,10 +37,14 @@ DimExprs4ValueT MakeDimExprs4Value( | |||
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager(); | |||
pass_manager->AddPass(pir::CreateShapeOptimizationPass()); | |||
pass_manager->Run(program); | |||
const auto* shape_analysis = | |||
&pir::ShapeAnalysisManager::Instance().Get(program); | |||
auto* shape_analysis = &pir::ShapeAnalysisManager::Instance().Get(program); | |||
return | |||
[shape_analysis](pir::Value value) -> const symbol::ShapeOrDataDimExprs& { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
能把返回值类型改为 std::optional< const symbol::ShapeOrDataDimExprs *>
吗?因为我们没有清楚定义啥叫empty ShapeOrDataDimExprs。
这个MakeDimExprs4Value返回值类型就定义为Opt DimExprs4ValueT
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块的修改和本身逻辑相耦合,这里只是为了暂时解决bug,已增加TODO待后续PR改进
static symbol::ShapeOrDataDimExprs empty{ | ||
symbol::TensorShapeOrDataDimExprs{}}; | ||
return empty; | ||
} | ||
if (!HasShapeOrDataForValue(val)) { | ||
// backtrack to infer shape from defining op | ||
GetAndSetShapeOrDataForValueFromDefiningOp(val); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该需要回溯多个op。
op_a -> tensor_a -> op_b -> tensor_b -> op_c -> tensor_c
当我直接防卫tensor_c对应的符号的时候,这条路径上的op_a、op_b、op_c都得至少推导一次。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我们回溯逻辑的基础建立在每个op进行推导时会调用GetShapeOrDataForValue,即调用到op_c的InferSymbolicShape函数时会隐含调用GetShapeOrDataForValue来获取tensor_b的shape,如果这时也无法get到则会递归调用op_b的InferSymbolicShape函数,回溯逻辑隐含在这种调用模式之中。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前已将递归回溯逻辑改为迭代实现,规避递归引起的风险,后续PR将进一步完善两种不同函数(带infer功能get和不带infer功能get)的封装
a670ef5
to
c37672c
Compare
} | ||
std::unordered_set<Operation*> subgraph_ops; | ||
std::vector<Operation*> start_ops; | ||
const auto& GetNextVisitOpForBuild = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VisitNotInferedInputOp
const auto& GetPrevVisitOpForInfer = | ||
[&](Operation* op, const std::function<void(Operation*)>& Visit) { | ||
for (auto& operand : op->operands_source()) { | ||
if (operand.impl() && subgraph_ops.count(operand.defining_op())) { | ||
Visit(operand.defining_op()); | ||
} | ||
} | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数应该不需要写。直接就是VisitNotInferedInputOp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
或者换一个名字:VisitSubgraphInputOp
} | ||
} | ||
}; | ||
const auto& GetNextVisitOpForInfer = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VisitSubgraphOutputOp
加锁的逻辑没有梳理清楚,建议暂时都去掉,在最终版本上加锁。 |
PR Category
CINN
PR Types
Improvements
Description
Pcard-67164
This PR adds backtrack logic for the function GetShapeOrDataForValue. With this enhancement, it becomes easier to add passes without having to worry about updating shapes.
Besides, some related changes include:
This PR is only a phased PR, belonging to an upgrade of the dynamic shape update mechanism. The subsequent steps include:
Refactoring the shape analysis class and carefully designing its interfaces.
Integrating the new class with the existing code, removing redundant logic and contradictory usages of the original passes.