CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[CINN][BUAA] Add infer_symbol_shape for dot, diag #67161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
common::errors::InvalidArgument("ShapeError: The dimension of input " | ||
"tensors X(%u) and Y(%u) are different", | ||
x_rank, | ||
y_rank)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分逻辑修改为 equalcstr,并留个注释强制要求相等
// } | ||
bool DotOpInferSymbolicShape(pir::Operation *op, | ||
pir::InferSymbolicShapeContext *infer_context) { | ||
auto x_dims = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,x_shape_or_data
"ShapeError: The dimensions of input tensor X (%u) should be 1 or 2", | ||
x_rank)); | ||
|
||
auto y_dims = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
// } | ||
bool DiagOpInferSymbolicShape(pir::Operation *op, | ||
pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &x_dims = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand(0));
const auto &x_dims = x_shape_or_data.shape();
这样可以嘛,还是说x_dims
本身这个变量名不符合规范,这个变量名是对应的inferMeta
函数里的变量名,使用x_shape
嘛,但是这里的话x_shape
使用的是x_shape。
|
||
if (x_dims.size() <= 1) { | ||
int64_t size_ = | ||
(x_dims.size() == 1UL ? x_dims[0].dyn_cast<int64_t>() : 1L) + offset; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要先判断一下下标为0、1的DimExpr里的类型再get,如果不是int类型跳过下面的比较分支,使用新符号
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不好意思,使用新符号这一块能详细展开一下或者举个例子嘛。
我现在仅仅是是将对应infermeta函数的逻辑移植过来了,所以对于这个新符号可能不知道怎么添加。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是像这个PR里面为SegmentPool的一样吗,即831行上下处?
false, | ||
common::errors::InvalidArgument( | ||
"diag only support 1D/2D matrix, but input has %u dims", | ||
x_dims.size())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是在干嘛
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_THROW(common::errors::InvalidArgument(
"The input tensor X's dimensions of DiagV2Op should be either 1 or "
"2, but received %d.",
x_dims.size()));
这一行对应的inferMeta逻辑是这样的。
还是说我该直接使用PADDLE_THROW?
} | ||
// Dot OP require both inputs should have the same shape | ||
|
||
auto x_shape_cut = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以优化为删除原vector的最后一个元素,避免大批拷贝构造
|
||
if (x_shape.size() <= 1) { | ||
symbol::DimExpr size_ = | ||
(x_shape.size() == 1UL ? x_shape[0] : symbol::DimExpr(1)) + offset; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的offset是不是没取绝对值,-1表示向下偏移
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const auto x_shape = x_shape_or_data.shape(); | ||
const int offset_data = op->attribute<pir::Int32Attribute>("offset").data(); | ||
auto offset = symbol::DimExpr(offset_data); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在588行构造匿名对象吧
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const auto x_shape = x_shape_or_data.shape(); | ||
const int offset_data = op->attribute<pir::Int32Attribute>("offset").data(); | ||
auto offset = symbol::DimExpr(offset_data); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不需要的变量直接删掉吧
PR Category
CINN
PR Types
Improvements
Description
添加
dot
,diag
, 算子符号推导接口实现。