CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[BUAA][Infer Symbolic Shape]add full_batch_size_like&temporal_shift&squared_l2_norm #67503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
Sorry to inform you that 7981320's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
std::transform(shape.begin(), shape.end(), shape_exprs.begin(), [](int dim) { | ||
return symbol::DimExpr(dim); | ||
}); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里直接用for循环+emplace_back实现吧
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const symbol::ShapeOrDataDimExprs &x_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const auto &x_dims = x_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
推荐x_shape
bool SquaredL2NormOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
auto dtype = infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
std::vector<symbol::DimExpr> batch_dims; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,batch_shape
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const symbol::ShapeOrDataDimExprs &x_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const auto &x_dims = x_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上x_shape,命名需要刻意区分开shape和data
common::errors::InvalidArgument( | ||
"Attr(shift_ratio) should be greater than 0, but received %f", | ||
shift_ratio)); | ||
PADDLE_ENFORCE_LT( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shift_ratio 的enforce合并成一个吧
补充描述下单测情况 |
@@ -33,6 +33,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsStrided) | |||
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(AllReduce_) | |||
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Barrier) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BipartiteMatch) | |||
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchSizeLike) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
多添加了个op声明
PR Category
CINN
PR Types
Improvements
Description
优化了标题中的三个算子,
其中FullBatchSizeLikeOpInferSymbolicShape需要调用BatchSizeLikeOpInferSymbolicShape
full_batch_size_like缺少单测
temporal_shift和squared_l2_norm有单测