CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[PIR] Refine conditional_block op translator #59723
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PIR] Refine conditional_block op translator #59723
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
… dev/reconstruct_if_translate
paddle/fluid/framework/new_executor/instruction/select_input_instruction.h
Show resolved
Hide resolved
paddle/fluid/framework/new_executor/instruction/select_input_instruction.h
Show resolved
Hide resolved
phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", | ||
(*this)->region(0).size())); | ||
auto &true_last_op = (*this)->region(0).front().back(); | ||
PADDLE_ENFORCE_EQ(true, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这儿应该考虑一下if没有返回值的情况。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, tks~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, commet 单独提PR fix
private: | ||
void copy_tensor(const phi::DenseTensor &lod_tensor, | ||
phi::DenseTensor *out) const { | ||
if (!lod_tensor.IsInitialized()) return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么直接return了,而不是Throw Error?
@@ -1863,6 +1863,110 @@ struct FillConstantTranscriber : public OpTranscriber { | |||
} | |||
}; | |||
|
|||
static std::vector<int64_t> ParseCompatibleShapes( | |||
const std::vector<int64_t>& dim1, const std::vector<int64_t>& dim2) { | |||
IR_ENFORCE(dim1.size() == dim2.size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是不是可以用PADDLE_ENFORCE?另外想了解下,IR_ENFORCE 抛出的异常显示的python callstack 与框架目前的栈有什么差异,会被pybind层Exception正确映射么?如果能的话,那这里就无所谓了
auto op_info = this->LoopkUpOpInfo(ctx, op_desc); | ||
|
||
std::vector<pir::Value> op_inputs = {}; | ||
auto Mask_name = op_desc.Input("Mask")[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto Mask_name = op_desc.Input("Mask")[0]; | |
auto mask_name = op_desc.Input("Mask")[0]; |
这里命名首字母不需要大写?
op_desc.Type(), | ||
Mask_name); | ||
op_inputs.push_back(param_map->at(Mask_name).value); | ||
for (auto in_name : Input_name) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (auto in_name : Input_name) { | |
for (auto& in_name : Input_name) { |
|
||
OpOutputMapping arg_to_idx; | ||
OpOutputTypeList op_output_types; | ||
auto Out_name = op_desc.Output("Out")[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto Out_name = op_desc.Output("Out")[0]; | |
auto& Out_name = op_desc.Output("Out")[0]; |
array_op.out().set_type(type); | ||
return array_op.operation(); | ||
} | ||
return nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里直接Throw NotImplement 会不会更好一些?这样下游不用检查这个函数的返回是否有效
for (auto input_name : input_names) { | ||
auto cond_op_cond = op->Input("Cond")[0]; | ||
auto& cond_op_inputs = op->Input("Input"); | ||
for (auto input_name : cond_op_inputs) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (auto input_name : cond_op_inputs) { | |
for (auto& input_name : cond_op_inputs) { |
tensor1.dtype(), | ||
tensor2.dtype()); | ||
IR_ENFORCE(tensor1.data_layout() == tensor2.data_layout(), | ||
"The 1st input data_layout %s should be equal to 2ed input " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"The 1st input data_layout %s should be equal to 2ed input " | |
"The 1st input data_layout %s should be equal to 2nd input " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
false_branch_inter_ = | ||
new PirInterpreter(place, | ||
{}, | ||
&false_branch_block, | ||
&if_op.false_block(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议改用make_unique
PirInterpreter* true_branch_inter_ = nullptr; | ||
|
||
PirInterpreter* false_branch_inter_; | ||
PirInterpreter* false_branch_inter_ = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
换用unique_ptr表明所有权更好吧
pir::Block* block) override { | ||
VLOG(10) << "[op select_input] start transcribing"; | ||
auto op_info = this->LoopkUpOpInfo(ctx, op_desc); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议在这里添加一句
this->InsertSliceOperationForInput(ctx, param_map, op_desc, input_infos, block); |
保证输入的类型正确
pir::Value mask() { return operand_source(0); } | ||
pir::OpResult out() { return result(0); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
少了input的相关接口
PR types
New features
PR changes
Others
Description
为了保证cond 算子翻译的成功率,本 PR对翻译策略进行优化,采用的原则是:
本PR主要内容包括:
Pcard-67164