[Auto Parallel] Add spmd rule No.11 for instance_norm and instance_norm_grad ops. #72938
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
What method do you use to determine the layout?
The instance_norm kernel (\Paddle\paddle\phi\kernels\cpu\instance_norm_kernel.cc) has no layout-related input parameters, whereas the layer_norm kernel is given the normalization dimension via "begin_norm_axis". However, for instance_norm it is guaranteed that the first two dimensions are N and C ("NC", "NCL", "NCHW", or "NCDHW"), so only these two dimensions can be sharded, and all other dimensions can be set to Replicated. I made these changes in the commit above.
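To make the rule concrete, here is a minimal sketch of the dims-mapping transformation described above. The helper name is hypothetical (this is not the PR's actual code): the N and C axes keep their source mapping, and every trailing spatial axis is forced to -1, i.e. Replicated.

#include <cstdint>
#include <vector>

// Hypothetical helper illustrating the instance_norm sharding rule:
// dims 0 (N) and 1 (C) may remain sharded, while every spatial axis
// (L / H / W / D) is forced to -1, i.e. Replicated.
std::vector<int64_t> InferInstanceNormXDimsMapping(
    const std::vector<int64_t>& src_dims_mapping) {
  std::vector<int64_t> dst_dims_mapping = src_dims_mapping;
  for (size_t i = 2; i < dst_dims_mapping.size(); ++i) {
    dst_dims_mapping[i] = -1;  // spatial axes cannot be sharded
  }
  return dst_dims_mapping;
}

// Example: an NCHW input mapped as [0, -1, 1, -1] becomes [0, -1, -1, -1].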
VLOG(4) << "X" | ||
<< " shape: [" << str_join(x_shape) << "] " | ||
<< "src_dims_mapping: [" << str_join(x_dist_attr_src.dims_mapping()) | ||
<< "] " | ||
<< "dst_dims_mapping: [" << str_join(x_dims_mapping) << "]"; | ||
VLOG(4) << "Scale" | ||
<< " shape: [" << str_join(scale_shape) << "] " | ||
<< "src_dims_mapping: [" << str_join(scale_dims_mapping) << "] " | ||
<< "dst_dims_mapping: [" | ||
<< str_join(scale_dist_attr_dst.dims_mapping()) << "]"; | ||
VLOG(4) << "Bias" | ||
<< " shape: [" << str_join(bias_shape) << "] " | ||
<< "src_dims_mapping: [" << str_join(bias_dims_mapping) << "] " | ||
<< "dst_dims_mapping: [" | ||
<< str_join(bias_dist_attr_dst.dims_mapping()) << "]"; | ||
VLOG(4) << "Out dims mapping: [" << str_join(y_dist_attr.dims_mapping()) | ||
<< "]"; | ||
VLOG(4) << "Saved_Mean dims mapping: [" | ||
<< str_join(saved_mean_dist_attr.dims_mapping()) << "]"; | ||
VLOG(4) << "Saved_Variance dims mapping: [" | ||
<< str_join(saved_variance_dist_attr.dims_mapping()) << "]"; | ||
VLOG(4) << std::endl; |
Shall we use the macros LOG_SPMD_INPUT or LOG_SPMD_OUTPUT to simplify the logging code?
Thanks, I will use them to simplify the logging code in the next commit.
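For reference, the simplified forward logging might look like the sketch below. This fragment assumes it sits inside the forward SPMD rule where the variables above are in scope, and that the macros accept the tensor's base name; the exact expansions should be taken from Paddle's spmd_rules utilities rather than from this sketch.

// Hedged sketch: the macros are assumed to reproduce the same
// shape / src_dims_mapping / dst_dims_mapping pattern as the
// hand-written VLOG(4) blocks above.
LOG_SPMD_INPUT(x);
LOG_SPMD_INPUT(scale);
LOG_SPMD_INPUT(bias);
LOG_SPMD_OUTPUT(y_dist_attr);
LOG_SPMD_OUTPUT(saved_mean_dist_attr);
LOG_SPMD_OUTPUT(saved_variance_dist_attr);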
VLOG(4) << "Einsum Notation: " << x_axes << "," << scale_axes << "," | ||
<< saved_mean_axes << "," << saved_variance_axes << "," << y_grad_axes | ||
<< "-->" << x_grad_axes << "," << scale_grad_axes << "," | ||
<< bias_grad_axes; | ||
VLOG(4) << "X" | ||
<< " shape: [" << str_join(x_shape) << "] " | ||
<< "src_dims_mapping: [" << str_join(x_dist_attr_src.dims_mapping()) | ||
<< "] " | ||
<< "dst_dims_mapping: [" << str_join(x_dims_mapping) << "]"; | ||
VLOG(4) << "Scale" | ||
<< " shape: [" << str_join(scale_shape) << "] " | ||
<< "src_dims_mapping: [" << str_join(scale_dims_mapping) << "] " | ||
<< "dst_dims_mapping: [" | ||
<< str_join(scale_dist_attr_dst.dims_mapping()) << "]"; | ||
VLOG(4) << "Saved_mean" | ||
<< " shape: [" << str_join(saved_mean_shape) << "] " | ||
<< "src_dims_mapping: [" << str_join(saved_mean_dims_mapping) << "] " | ||
<< "dst_dims_mapping: [" | ||
<< str_join(saved_mean_dist_attr.dims_mapping()) << "]"; | ||
VLOG(4) << "Saved_variance" | ||
<< " shape: [" << str_join(saved_variance_shape) << "] " | ||
<< "src_dims_mapping: [" << str_join(saved_variance_dims_mapping) | ||
<< "] " | ||
<< "dst_dims_mapping: [" | ||
<< str_join(saved_variance_dist_attr.dims_mapping()) << "]"; | ||
VLOG(4) << "Y_grad" | ||
<< " shape: [" << str_join(y_grad_shape) << "] " | ||
<< "src_dims_mapping: [" << str_join(y_grad_dims_mapping) << "] " | ||
<< "dst_dims_mapping: [" | ||
<< str_join(y_grad_dist_attr_dst.dims_mapping()) << "]"; | ||
VLOG(4) << "x_grad dims mapping: [" | ||
<< str_join(x_grad_dist_attr.dims_mapping()) << "]"; | ||
VLOG(4) << "Scale_grad dims mapping: [" | ||
<< str_join(scale_grad_dist_attr.dims_mapping()) << "]"; | ||
VLOG(4) << "Bias_grad dims mapping: [" | ||
<< str_join(bias_grad_dist_attr.dims_mapping()) << "]"; | ||
VLOG(4) << std::endl; |
Shall we use the macros LOG_SPMD_INPUT or LOG_SPMD_OUTPUT to simplify the logging code?
Thanks, I will use them to simplify the logging code in the next commit.
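The same macros could shorten the backward rule's logging as well. Again, this is only a sketch with the argument conventions assumed, mirroring the forward case:

// Hedged sketch for the backward rule, same assumptions as above.
LOG_SPMD_INPUT(x);
LOG_SPMD_INPUT(scale);
LOG_SPMD_INPUT(saved_mean);
LOG_SPMD_INPUT(saved_variance);
LOG_SPMD_INPUT(y_grad);
LOG_SPMD_OUTPUT(x_grad_dist_attr);
LOG_SPMD_OUTPUT(scale_grad_dist_attr);
LOG_SPMD_OUTPUT(bias_grad_dist_attr);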
VLOG(4) << "test forward done."; | ||
|
||
// Test backward. | ||
// [-1,0, 1, -1], [-1], [-1,-1], [-1,-1], [-1,0, 1, -1] |
It would also be best to annotate the expected output.
Thanks, I have added the annotation for the expected output.
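For illustration, here is a hedged sketch of what such an annotated check could look like in a GoogleTest-style C++ test. The names and the expected values are assumptions derived only from the rule described in this PR (spatial axes forced to Replicated), not from the actual test file:

#include <cstdint>
#include <vector>
#include "gtest/gtest.h"

// Backward-test inputs (from the comment above):
//   x: [-1, 0, 1, -1], scale: [-1], saved_mean: [-1, -1],
//   saved_variance: [-1, -1], y_grad: [-1, 0, 1, -1]
// If the rule forces the spatial axes of x / y_grad to Replicated, the
// annotated expectation for x_grad (assumed for illustration) would be:
TEST(InstanceNormGradSpmdSketch, AnnotatedExpectedOutput) {
  // Placeholder standing in for the rule's real result, e.g.
  // x_grad_dist_attr.dims_mapping() in the actual test (name assumed).
  const std::vector<int64_t> x_grad_inferred = {-1, 0, -1, -1};
  EXPECT_EQ(x_grad_inferred, (std::vector<int64_t>{-1, 0, -1, -1}));
}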
Force-pushed from d20d811 to 9919b02.
/re-run all-failed
.pre-commit-config.yaml
Outdated
@@ -56,7 +56,7 @@ repos:
       args: [--force-exclude]
   # For Python files
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 25.1.0
+    rev: 24.4.2
Why was this changed?
I made a local modification while running pre-commit and accidentally pushed it to the PR. I will revert it and submit a new commit.
/re-run all-failed
LGTM
Please resolve the conflicts.
PR Category
Auto Parallel
PR Types
New features
Description
[Open-source task] Develop operator sharding (SPMD) inference rules to support auto parallel for more models and reduce the distributed-development cost for more users.
For a tensor with [N, C, H, W] dimensions, the H and W dimensions are forced to be Replicated.
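As a concrete, runnable walk-through of that rule (all mappings illustrative and derived only from the behavior stated above):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Illustrative forward case for an NCHW tensor:
  // N sharded on mesh dim 0, H sharded on mesh dim 1.
  std::vector<int64_t> x_src = {0, -1, 1, -1};
  std::vector<int64_t> x_dst = x_src;
  for (size_t i = 2; i < x_dst.size(); ++i) x_dst[i] = -1;  // replicate H, W

  assert((x_dst == std::vector<int64_t>{0, -1, -1, -1}));
  // y would follow x; saved_mean / saved_variance (shape [N, C]) would
  // map to {0, -1} (assumption consistent with the rule in this PR).
  return 0;
}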