CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【Infer Symbolic Shape No.60,61,62】[BUAA] Add index_add,index_put,index_select op #67267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -997,7 +997,65 @@ bool TopPSamplingOpInferSymbolicShape( | |||
// // pass | |||
// return true; | |||
// } | |||
bool IndexSelectOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注意排版,上面添个空格
const auto &index_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
|
||
std::vector<symbol::DimExpr> x_dims = x_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,x_shape
input_rank - 1, | ||
dim)); | ||
|
||
PADDLE_ENFORCE_EQ(index_rank == 1 || (index_rank == 2 && index_dims[1] == 1), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DimExpr 的约束用 Addequalcstr
index_dims, | ||
index_dims.size())); | ||
|
||
PADDLE_ENFORCE_EQ(index_dims[0] != 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删掉这个约束
@@ -83,6 +83,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gelu_) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Imag) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Increment) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Increment_) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexAdd) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IndexAdd中能添加 equal constrain ,单独实现一下接口,别放在same 类别下了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
|
||
PADDLE_ENFORCE_NE(index_shape.shape()[0].dyn_cast<int64_t>(), | ||
0, | ||
common::errors::InvalidArgument( | ||
"The length of Input(Index) can't be 0.")); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有NE的约束,这里不需要转为int64_t再判断了,直接删掉就好
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
auto x_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
auto index_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
auto add_value_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(2)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x_shape_or_data,后面的命名都注意一下,取出来的值可能包含data
|
||
// Set the shape for the output | ||
infer_context->SetShapeOrDataForValue(op->result(0), x_shape); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(x_shape_or_data.shape())}
不要把data也传进去了
|
||
infer_context->SetShapeOrDataForValue(op->result(0), x_shape_or_data); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,把shape取出来单独设置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
CINN
PR Types
Improvements
Description
#66444
添加index_add,index_put,index_select算子符号推导接口
index_add_ 无单测文件(index_add有)
index_put 无OpTest单测实现
index_select 有OpTest单测但流水线未执行