CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【Infer Symbolic Shape No.232】Add infer_symbol_shape for StridedSlice #69911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
ExprVec ends = slice_utils::GetExprVecFromData(ends_shape_data); | ||
ExprVec strides = slice_utils::GetExprVecFromData(strides_shape_data); | ||
|
||
std::vector<int64_t> axes_vec = details::GetVectorAttr(op, "axes"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以先尝试添加一个 Int32 的特化模版
规范下 Description |
TestStridedSliceAPI 和 TestStridedSliceTensorArray 的单测问题交由 @kevincheng2 修复 |
return slice_dims; | ||
} | ||
|
||
inline ShapeOrData StridedSliceRawInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
StridedSliceRawInferSymbolicShape这个函数为什么放到infer_sym_slice_utils.h文件中?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SliceRawInferSymbolicShape 也放在utils.h里了,放一起容易维护
symbol::TensorShapeOrDataDimExprs(out_dims)}; | ||
}; | ||
|
||
// When `pd.slice` is operating on a tensor which is produced by a `pd.shape` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该是 strided_slice 吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以放下个PR修改
std::vector<int64_t>{}, | ||
std::vector<int64_t>{}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
确认一下,infer_flags_raw 和 decrease_axis 默认都是空有问题吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是因为
stride_slice
并未注册这两个参数,导致无法通过GetVectorAttr
得到的StridedSliceRawInferSymbolicShape
复用了slice
的GetDecreasedDims
所以对decrease_axis
这里构造了空向量,而infer_flags
这个参数是一个无效参数,在slice
和stride_slice
均没有用到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
原PR:#69290 |
APPROVAL 流水线发现了一些报错信息规范的问题,可以与注释一起提pr修改,连带检查下slice的几个工具函数里是不是报错信息也不规范 |
PR Category
CINN
PR Types
New features
Description
TestStridedSliceAPI
和TestStridedSliceTensorArray
无法通过starts
、ends
和strides
为Inputs
的单测因无法在运行前得到,所以禁用