[PIR] remove xshape for reshape op #66089
Conversation
Your PR was submitted successfully. Thank you for contributing to the open-source project!
```yaml
  spmd_rule : ReshapeInferSpmdDynamic
  kernel :
    func : reshape
    func : reshape_infer
```
After the dynamic/static unification of kernels, this should be unified to `reshape`.
LGTM
```diff
@@ -731,7 +731,7 @@ std::shared_ptr<OpStrategy> StrategyForReshapeSymbolic(
       << ", output_shapes: " << utils::Join(output_shapes[0], ", ");

   std::string tensor_name;
-  if (pack_args.size() == 4) {
+  if (pack_args.size() == 4 || pack_args.size() == 3) {
```
Can the `== 4` case still occur here?
No, it can't. I'll remove it in a follow-up.
```cpp
  if (op->isa<paddle::dialect::ReshapeOp>()) {
    match_ctx->BindIrValue(tensors[0]->name(), op->result(0));
    return;
  }
```
Is adding a special case here really appropriate?
```diff
   auto xshape_dims = xshape.dims();
   auto x_dims = common::slice_ddim(xshape_dims, 1, xshape_dims.size());
-  auto grad_x_tmp = reshape<T>(grad_out, common::vectorize(x_dims));
+  auto grad_x_tmp = reshape<T>(grad_out, common::vectorize(xshape_dims));
```
Could this cause problems if xshape has dynamic dimensions?
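For context on the slicing in the diff above: xshape follows the convention of prepending a leading 0 to x's original dims, so the backward pass drops that leading entry to recover x's shape. A minimal NumPy sketch of that recovery (function names are hypothetical, for illustration only):

```python
import numpy as np

def reshape_grad(grad_out, xshape_dims):
    # xshape stores [0] + list(x.shape); dropping the leading 0 recovers
    # x's original shape (mirrors common::slice_ddim(xshape_dims, 1, n)).
    x_dims = xshape_dims[1:]
    return grad_out.reshape(x_dims)

x = np.arange(6).reshape(2, 3)       # original input
grad_out = np.ones((3, 2))           # gradient w.r.t. the reshaped output
xshape_dims = (0,) + x.shape         # (0, 2, 3), the xshape convention
grad_x = reshape_grad(grad_out, xshape_dims)
print(grad_x.shape)                  # (2, 3)
```

The reviewer's question applies because with dynamic (unknown) dims, `xshape_dims[1:]` may contain symbolic entries rather than concrete integers.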
```diff
@@ -962,6 +962,14 @@ void KernelWithXShapeInferMeta(const MetaTensor& xshape,
   dx->share_lod(xshape);
 }

+void ReshapeGradInferMeta(const MetaTensor& xshape,
```
Shouldn't `xshape` here actually be `x`?
I'll rename the variable in a separate PR.
```yaml
- op : reshape_grad (reshape2_grad)
  inputs:
    x : XShape
```
The backward op shouldn't need this configuration.
LGTM
* remove xshape for reshape op
* remove xshape for reshape op
* ignore dynamic
* update
* update
* update
* update
* fix PT and prim bug
* update
* update
* update
* update
* update
* fix fusion pass bug
* fix fusion pass bug
* refine codestyle
* update
* update
* update
* update
PR Category
Operator Mechanism
PR Types
Not User Facing
Description
pcard-67164
pd_op.reshape (and the related series of ops) has an xshape output that records the meta information of the input x, for use as a backward input.

cinn_op.reshape has no corresponding xshape output, which caused a series of downstream execution errors. Simply dropping xshape cannot handle the inplace case, because the meta information of the original x would be lost; adding an xshape output to cinn_op would introduce other problems and is inelegant.

After reworking inplace handling (predecessor PR: https://github.com/PaddlePaddle/Paddle/pull/65491), this PR removes the xshape output of the reshape op.
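A rough sketch of the inplace hazard the description mentions (illustrative only; plain Python lists stand in for tensor meta info): with an inplace reshape, x's recorded dims become the output dims, so without a preserved copy the backward pass cannot recover the original shape. xshape existed precisely to keep that copy:

```python
import numpy as np

# Original shape of x and the reshape target (illustrative values).
x_dims = [2, 3]
out_dims = [3, 2]

# xshape keeps a copy of the original dims, with the conventional leading 0.
xshape_dims = [0] + x_dims

# Inplace reshape: x's meta now reflects the *output* shape, so reading
# x's dims after the op no longer yields the original shape.
x_dims = out_dims

# Backward still recovers the original shape from the preserved xshape copy.
grad_out = np.ones(out_dims)
grad_x = grad_out.reshape(xshape_dims[1:])
print(grad_x.shape)  # (2, 3)
```

The predecessor PR's inplace rework is what makes it safe to drop this preserved copy entirely.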