[CINN] Add InferSymbolicShape Interface for add_n_array op #69698
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
@@ -2167,6 +2194,12 @@ bool ArrayWrite_Op::InferSymbolicShape(
      out(),
      symbol::ShapeOrDataDimExprs{
          symbol::RankedTensorArrayShapeOrDataDimExprs(x_shape)});
  // update array's shape as x's shape.
  // TODO(ooooo): Do not change if shape is set by custom, similar to infer_meta
  infer_context->SetShapeOrDataForValue(
In the unit test I enabled myself, setting this probably has no effect: when symbolic shapes are looked up backward from a later Value, this call chain is never produced by tracing back through the defining op, so the array shape should never get updated:

x0 = pd_op.create_array(xxx)
x1 = pd_op.array_write_(x0,t1,?)
x2 = array_read(x0,?)

Verifying the inference result starting from x2:
main_program
{
(%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,name:"d0",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[10],stop_gradient:[true]} : () -> builtin.tensor<10xf32>
(%1) = "pd_op.data" () {dtype:(pd_op.DataType)float32,name:"d1",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[10],stop_gradient:[true]} : () -> builtin.tensor<10xf32>
(%2) = "pd_op.full" () {dtype:(pd_op.DataType)int64,place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[1],stop_gradient:[true],value:(Double)0} : () -> builtin.tensor<1xi64>
(%3) = "pd_op.create_array" () {dtype:(pd_op.DataType)float32,stop_gradient:[true]} : () -> pd_op.tensor_array<f32>
(%4) = "pd_op.cast" (%0) {dtype:(pd_op.DataType)float32,stop_gradient:[true]} : (builtin.tensor<10xf32>) -> builtin.tensor<10xf32>
(%5) = "pd_op.array_write_" (%3, %4, %2) {} : (pd_op.tensor_array<f32>, builtin.tensor<10xf32>, builtin.tensor<1xi64>) -> pd_op.tensor_array<f32>
(%6) = "pd_op.increment_" (%2) {stop_gradient:[true],value:(Float)1} : (builtin.tensor<1xi64>) -> builtin.tensor<1xi64>
(%7) = "pd_op.cast" (%1) {dtype:(pd_op.DataType)float32,stop_gradient:[true]} : (builtin.tensor<10xf32>) -> builtin.tensor<10xf32>
(%8) = "pd_op.array_write_" (%3, %7, %2) {} : (pd_op.tensor_array<f32>, builtin.tensor<10xf32>, builtin.tensor<1xi64>) -> pd_op.tensor_array<f32>
(%9) = "pd_op.create_array_like" (%3) {stop_gradient:[true],val:(Float)0} : (pd_op.tensor_array<f32>) -> pd_op.tensor_array<f32>
(%10) = "pd_op.array_read" (%9, %2) {stop_gradient:[true]} : (pd_op.tensor_array<f32>, builtin.tensor<1xi64>) -> builtin.tensor<-1xf32>
(%11) = "pd_op.cast" (%0) {dtype:(pd_op.DataType)float32,stop_gradient:[true]} : (builtin.tensor<10xf32>) -> builtin.tensor<10xf32>
(%12) = "pd_op.array_write_" (%9, %11, %2) {} : (pd_op.tensor_array<f32>, builtin.tensor<10xf32>, builtin.tensor<1xi64>) -> pd_op.tensor_array<f32>
(%13) = "pd_op.increment_" (%2) {stop_gradient:[true],value:(Float)-1} : (builtin.tensor<1xi64>) -> builtin.tensor<1xi64>
(%14) = "pd_op.cast" (%1) {dtype:(pd_op.DataType)float32,stop_gradient:[true]} : (builtin.tensor<10xf32>) -> builtin.tensor<10xf32>
(%15) = "pd_op.array_write_" (%9, %14, %2) {} : (pd_op.tensor_array<f32>, builtin.tensor<10xf32>, builtin.tensor<1xi64>) -> pd_op.tensor_array<f32>
(%16) = "builtin.combine" (%3, %9) {stop_gradient:[true]} : (pd_op.tensor_array<f32>, pd_op.tensor_array<f32>) -> vec[pd_op.tensor_array<f32>,pd_op.tensor_array<f32>]
(%17) = "pd_op.add_n_array" (%16) {stop_gradient:[true]} : (vec[pd_op.tensor_array<f32>,pd_op.tensor_array<f32>]) -> pd_op.tensor_array<f32>
(%18) = "pd_op.array_read" (%17, %2) {stop_gradient:[true]} : (pd_op.tensor_array<f32>, builtin.tensor<1xi64>) -> builtin.tensor<-1xf32>
(%19) = "pd_op.increment_" (%2) {stop_gradient:[true],value:(Float)1} : (builtin.tensor<1xi64>) -> builtin.tensor<1xi64>
(%20) = "pd_op.array_read" (%17, %2) {stop_gradient:[true]} : (pd_op.tensor_array<f32>, builtin.tensor<1xi64>) -> builtin.tensor<-1xf32>
}
  return true;
  infer_context->SetShapeOrDataForValue(
      op->result(0),
      symbol::ShapeOrDataDimExprs{symbol::NullShapeOrDataDimExpr()});
Is setting NullShapeOrDataDimExpr here semantically reasonable?
None of the existing semantics fit well here. Adding a new RankedTensorArrayListShapeOrDataDimExprs symbolic representation for this Value would be the reasonable choice, but the only current use case is the paired combine op + add_n_array pattern, so no new representation was added; instead, add_n_array's inference depends directly on the input Values of the combine op. What gets set here does not really matter: it is currently never used as an inference dependency of any other Value. Setting it seems only to guarantee that the inference pass does not throw an exception because a non-fake result Value of an op implementing the symbolic inference interface has no shape_or_data.
@@ -1495,6 +1516,7 @@ std::vector<pir::Type> CreateArrayOp::InferMeta(

bool CreateArrayOp::InferSymbolicShape(
    pir::InferSymbolicShapeContext *infer_context) {
  // TODO(ooooo): Try to use output type's dims to decide.
The semantics of symbol::RankedTensorArrayShapeOrDataDimExprs(std::vector<symbol::DimExpr>{}) below are also ambiguous: it is meant to say the TensorArray currently holds no elements, but it can equally be read as an array holding 0-D tensors. There is a TODO about this in infer_meta, which sets {0}, so the two are also slightly inconsistent.
Usually an empty TensorArray is created first and a later array_write_ resets the dims in infer_meta, so we could try updating from the output type's dims. Local compilation is too slow; that can go in a separate PR.
infer_context->SetShapeOrDataForValue(
    op->result(0),
    symbol::ShapeOrDataDimExprs{symbol::NullShapeOrDataDimExpr()});
return true;
Right, none of the available semantics make sense when set here; let's not set anything for now (this cannot handle the case where a combine_op's TensorArray-typed output is consumed by another op).
…addlePaddle#69698)" This reverts commit 16dbf47.
PR Category
CINN
PR Types
Others
Description
add_n_array op: under PIR, the input of add_n_array is the output of a CombineOp.
Paddle/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
Lines 240 to 247 in 6a4e651