CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【Infer Symbolic Shape No.139,140,141,142】[BUAA] Add generate_proposals,grid_sample,gru,gru_unit op #67413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
// } | ||
bool GenerateProposalsOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
return GenerateProposalsV2OpInferSymbolicShape(op, infer_context); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yaml定义里没有GenerateProposalsV2这个op,直接用op名就行
// If bias is used, check its dimensions | ||
if (op->num_operands() > 3) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改为判断是否为null value
auto input_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto &,下同
std::vector<symbol::DimExpr> rpn_rois_shape = {out_unknown_1, | ||
symbol::DimExpr(4)}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
应该是batch_size能确定吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的, @gongshaotian 说的kernel代码之后还对变量进行了一次resize,使用的参数就是运行时确定的了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kernel实现是可能做了点特殊处理,但至少两个out_unknown应该是同一个值
if (!is_test) { | ||
symbol::TensorShapeOrDataDimExprs batch_gate_shape(input_shape); | ||
infer_context->SetShapeOrDataForValue(op->result(0), batch_gate_shape); | ||
|
||
symbol::TensorShapeOrDataDimExprs batch_reset_hidden_prev_shape( | ||
{input_shape[0], frame_size}); | ||
infer_context->SetShapeOrDataForValue(op->result(1), | ||
batch_reset_hidden_prev_shape); | ||
|
||
symbol::TensorShapeOrDataDimExprs batch_hidden_shape( | ||
{input_shape[0], frame_size}); | ||
infer_context->SetShapeOrDataForValue(op->result(2), batch_hidden_shape); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非is_test直接用静态shape赋值接口上新符号吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不太明白此处需要如何修改,是不创建临时变量直接写进SetShapeOrDataForValue函数参数里边吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
每个输出都需要设置symbolic shape,非is_test模式也需要,用infer_context->SetSymbolForValueByStaticShape(xxx)
吧
} else { | ||
infer_context->SetSymbolForValueByStaticShape(op->result(0)); | ||
infer_context->SetSymbolForValueByStaticShape(op->result(1)); | ||
infer_context->SetSymbolForValueByStaticShape(op->result(2)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
恭喜解决了第一个Kernel复用Meta推导结果的算子🎉
bool GenerateProposalsV2OpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
symbol::DimExpr out_unknown_1 = infer_context->GetNextSymName(); | ||
symbol::DimExpr out_unknown_2 = infer_context->GetNextSymName(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从kernel看这两个符号表示的维度是相等的,使用一个即可
// // pass | ||
// return true; | ||
// } | ||
bool GridSampleOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for GridSampleOpInferSymbolicShape
return true; | ||
} | ||
|
||
bool GruUnitOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for GruUnitOpInferSymbolicShape
// // pass | ||
// return true; | ||
// } | ||
bool GenerateProposalsOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for GenerateProposalsOp
infer_context->AddEqualCstr(bias_shape[1], frame_size * 3); | ||
} | ||
|
||
if (!is_test) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
冲突了 |
PR Category
CINN
PR Types
Improvements
Description
#66444
添加generate_proposals,grid_sample,gru,gru_unit算子符号推导接口
gru,gru_unit有op_test但未开启check_pir
generate_proposals关闭check_symbol_infer