CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【BUAA】【Infer Symbolic Shape】Add multinomial, nanmedian for CINN compiler #67507
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
|
||
bool NanmedianOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
std::vector<int64_t> axis_list; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接用 paddle::dialect::details::GetVectorAttr 获取
} | ||
const auto &x_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
auto &x_dim = x_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- const auto & 2. 命名不规范,x_shape
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
auto &x_dim = x_shape_or_data.shape(); | ||
int64_t x_rank = x_dim.size(); | ||
ExprVec out_dim; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,out_shape
std::vector<int64_t> formatted_axis; | ||
for (size_t i = 0; i < axis_list.size(); i++) { | ||
if (x_rank == 0) { | ||
infer_context->AddGreatThanOneCstr(axis_list[i]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- axis转正数后,直接使用AddequalCstr(axis_list[i], DimExpr(0))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么这里添加了GT one 约束
"which dimension = %d. But received axis = %d.", | ||
x_rank, | ||
axis_list[i])); | ||
PADDLE_ENFORCE_GE(axis_list[i], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 合并这两个断言
x_rank, | ||
axis_list[i])); | ||
} | ||
if (axis_list[i] < 0) axis_list[i] += x_rank; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 把axis 转为正数放在for循环前面
} | ||
|
||
auto median_dim = out_dim; | ||
std::string mode; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
attribute和operand 的取值统一放到前面去
|
||
auto median_dim = out_dim; | ||
std::string mode; | ||
if (attributes.find("mode") != attributes.end()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么添加了这么多取之前的判断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可选参数可能没有传入值
// } | ||
bool MultinomialOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
ExprVec x_dims = [&] { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一般不绕一圈写个lambda函数取shape()
return dims; | ||
}(); | ||
const auto &int_num_samples = | ||
op->attribute<paddle::dialect::ScalarAttribute>("num_samples").data(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"The number of dimensions of the input probability " | ||
"distribution should be > 0, but got %d.", | ||
x_rank)); | ||
PADDLE_ENFORCE_LE(x_rank, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个enforece 合并成一个更简洁
// // pass | ||
// return true; | ||
// } | ||
if (!int_num_samples.FromTensor()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有这个接口,需要判断sclar参数是从value(可以理解为tensor)还是做为attribute传进来的
if (op->HasAttribute("num_samples")) { | ||
const auto &int_num_samples = | ||
op->attribute<paddle::dialect::ScalarAttribute>("num_samples").data(); | ||
out_dims[x_rank - 1] = symbol::DimExpr(int_num_samples.to<int64_t>()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
缺少从tensor获取num_samples的判断,必要时使用新符号,这样写可能最后一维根本没设置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已完成修改
} else if (op->operand_source(1)) { | ||
out_dims[x_rank - 1] = symbol::DimExpr(infer_context->GetNextSymName()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里如果operand_source(1)有data的话就不用上新符号了,避免判断可以直接用接口:
const auto &num_samples_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(1));
const auto &data_vec = details::GetOrCreateExprVecFromData(num_samples_shape_or_data, infer_context);
out_dims[x_rank - 1] = data_vec[0];
这个接口自动判断是否有data然后返回,没有的话直接就上新符号了
for (size_t i = 0; i < axis_list.size(); i++) { | ||
if (axis_list[i] < 0) { | ||
axis_list[i] += x_rank; | ||
infer_context->AddEqualCstr(axis_list[i], symbol::DimExpr(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个cstr应该放在1730 行,1730的cstr需要删掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
CINN
PR Types
Others
Description
multinomial, nanmedian均存在已有op测试,但multinomial未进行check_output测试,开启check_output测试时static_checker报错

需要完善check_output_customized检查机制