CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【Infer Symbolic Shape No.15】[BUAA] Add chunk_eval op #67734
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Infer Symbolic Shape No.15】[BUAA] Add chunk_eval op #67734
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
PIR 测试问题似乎不是我改的部分。 |
什么意思,pir的流水线不是过了吗 |
补充描述下单测状态 |
const symbol::ShapeOrDataDimExprs &label_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
|
||
PADDLE_ENFORCE_EQ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个部分应该先用enforce判断两个shape的size相等,然后写个循环,给两个shape()中的每个dim都添加个equal约束
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
const symbol::ShapeOrDataDimExprs &inference_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const symbol::ShapeOrDataDimExprs &label_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有区分清楚 shape_or_data和 shape的概念,这里用shape后缀命名的话直接label_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
bool ChunkEvalOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const symbol::ShapeOrDataDimExprs &inference_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同下,没区分清楚概念导致命名不规范
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
|
||
PADDLE_ENFORCE_EQ((inference_shape.shape().size() == 3 && | ||
inference_shape.shape()[2] == symbol::DimExpr(1)) || | ||
inference_shape.shape().size() == 2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Enforce里面关于dim的约束放到后面吧,这里只对两个size做约束。
后面补充个分支:
if( inference_shape.shape().size() == 3){
Addequalcstr()
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
|
||
if (op->operand_source(2)) { | ||
const symbol::ShapeOrDataDimExprs &seq_length_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(2)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上 .shape()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
已补充 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
CINN
PR Types
Improvements
Description
添加 chunk_eval 算子符号推导接口。