CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[buaa]【Infer Symbolic Shape No.143 No.167】Add roipool & gumble_softmax & dequantize_linear& quantize_linear #67275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Outdated
Show resolved
Hide resolved
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Outdated
Show resolved
Hide resolved
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Outdated
Show resolved
Hide resolved
const auto &four = symbol::DimExpr(4); | ||
infer_context->AddEqualCstr(rois_dim[1], four); | ||
|
||
if (rois_num.size() > 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里判断的应该 boxes_num 吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -114,6 +114,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Transpose_) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unbind) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unique) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniqueConsecutive) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(UnchangedCheckAxis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删掉这行,没有这个OP,UnchangedCheckAxis 是个辅助函数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已解决
return UnchangedCheckAxisOpInferSymbolicShape(op, infer_context); | ||
} | ||
|
||
bool UnchangedCheckAxisOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数实现的复用性不强,直接合并到GumbelSoftmaxOp的推导接口里面吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已解决
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &x_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
std::vector<symbol::DimExpr> x_dims = x_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已解决
infer_context->GetShapeOrDataForValue(op->operand_source(2)); | ||
const auto &rois_num_shape = rois_num_shape_or_data.shape(); | ||
|
||
int pooled_height = | ||
op->attribute<pir::Int32Attribute>("pooled_height").data(); | ||
int pooled_width = op->attribute<pir::Int32Attribute>("pooled_width").data(); | ||
int output_channels = | ||
op->attribute<pir::Int32Attribute>("output_channels").data(); | ||
PADDLE_ENFORCE_EQ( | ||
x_shape.size(), | ||
4, | ||
phi::errors::InvalidArgument( | ||
"The input data should be a four-dimensional tensor with [N,C,H,W], " | ||
"but received input data with %d dimension", | ||
x_shape.size())); | ||
PADDLE_ENFORCE_EQ(rois_shape.size(), | ||
2, | ||
phi::errors::InvalidArgument( | ||
"rois should be a 2-D LoDTensor with shape (num_rois, " | ||
"4) given as [[x1, y1, x2, y2], ...], but received " | ||
"rois is %d-dimensional LoDTensor", | ||
rois_shape.size())); | ||
const auto &four = symbol::DimExpr(4); | ||
infer_context->AddEqualCstr(rois_shape[1], four); | ||
|
||
if (!rois_num_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) { | ||
PADDLE_ENFORCE_EQ( | ||
rois_num_shape.size(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
调用rois_num_shape_or_data.shape()之前就应该判断是否为Null
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是仿照infermeta写的,已修改
PADDLE_ENFORCE_GE(axis, | ||
-static_cast<int>(rank), | ||
common::errors::InvalidArgument( | ||
"Attr(axis) value should be in range [-R, R-1], " | ||
"R is the rank of Input(X).")); | ||
PADDLE_ENFORCE_LT(axis, | ||
static_cast<int>(rank), | ||
common::errors::InvalidArgument( | ||
"Attr(axis) value should be in range [-R, R-1], " | ||
"R is the rank of Input(X).")); | ||
} else if (rank == 0) { | ||
PADDLE_ENFORCE_GE(axis, | ||
-1, | ||
common::errors::InvalidArgument( | ||
"Attr(axis) value should be in range [-1, " | ||
"0] when input is 0D Tensor ")); | ||
PADDLE_ENFORCE_LE(axis, | ||
0, | ||
common::errors::InvalidArgument( | ||
"Attr(axis) value should be in range [-1, " | ||
"0] when input is 0D Tensor ")); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bool in_range = xx
PADDLE_ENFORCE_EQ(in_range, true)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Outdated
Show resolved
Hide resolved
|
||
out_dims[0] = rois_num_shape[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rois_shape[0]
@@ -1489,7 +1489,116 @@ bool MemoryEfficientAttentionOpInferSymbolicShape( | |||
|
|||
return true; | |||
} | |||
bool RoiPoolOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for RoiPool
return true; | ||
} | ||
|
||
bool QuantizeLinearOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for QuantizeLinear
"Attr(axis) value should be in range [-R, R-1], " | ||
"R is the rank of Input(X).")); | ||
} else if (rank == 0) { | ||
PADDLE_ENFORCE_EQ(axis >= -1 || axis <= 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该是 && 吧
|
||
if (rank > 0) { | ||
PADDLE_ENFORCE_EQ( | ||
axis >= -static_cast<int>(rank) || axis < static_cast<int>(rank), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同下
PR Category
CINN
PR Types
improvements
Description
添加roipool & gumble_softmax算子符号推导接口