CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【Infer Symbolic Shape No.85】[buaa]Add multi_dot op #67328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
const symbol::TensorListShapeOrDataDimExprs &input_values = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)) | ||
.dyn_cast<symbol::TensorListShapeOrDataDimExprs>(); | ||
const auto input_num = input_values.size(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
能不用auto就不用auto,直接用int
const auto input_num = input_values.size(); | ||
PADDLE_ENFORCE_GT( | ||
input_num, | ||
static_cast<size_t>(1), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上面直接用int,省掉这个cast
common::errors::InvalidArgument( | ||
"The number of input tensors in multi_dot op should > 1")); | ||
|
||
const auto n = input_num; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,int类型没必要用auto,影响可读性
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么要copy一份这个int值
bool is_vector = false; | ||
std::vector<symbol::DimExpr> out_dim; | ||
|
||
auto first_dim = input_values[0].shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,first_value_shape
auto first_dim = input_values[0].shape(); | ||
PADDLE_ENFORCE_LT( | ||
first_dim.size(), | ||
static_cast<size_t>(3), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非必要cast
// If the first tensor is 1D of size n view it as a row vector (1, n) | ||
|
||
if (first_dim.size() == 1) { | ||
first_dim = std::vector<symbol::DimExpr>{static_cast<symbol::DimExpr>(1), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个cast是在干嘛
is_vector = true; | ||
} | ||
|
||
auto last_dim = input_values[n - 1].shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,非必要auto
auto last_dim = input_values[n - 1].shape(); | ||
PADDLE_ENFORCE_LT( | ||
last_dim.size(), | ||
static_cast<size_t>(3), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
// If the last tensor is 1D of size n view it as a column vector (n, 1) | ||
if (last_dim.size() == 1) { | ||
last_dim = std::vector<symbol::DimExpr>{last_dim[0], | ||
static_cast<symbol::DimExpr>(1)}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,这里无需cast
: std::vector<symbol::DimExpr>{first_dim[0], last_dim[1]}; | ||
} | ||
|
||
auto width = first_dim.at(1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非必要auto,这里为什么突然从[] 替换成了at()
for (size_t i = 1; i < n - 1; ++i) { | ||
auto &input_dim = input_values[i].shape(); | ||
PADDLE_ENFORCE_EQ(input_dim.size(), | ||
static_cast<size_t>(2), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
Sorry to inform you that 9a0ca88's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
PR Category
CINN
PR Types
improvements
Description
添加multi dot算子符号推导