[Auto Parallel] Add spmd rule No.4, 13 for (batch_norm, sync_batch_norm) and their backward ops. #72918
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open source project!
@@ -2614,7 +2785,7 @@ TEST(Topk, Ctor) {
// test forward
// axis = 1
// [0, 1, -1] -> [0, -1, -1], [0, -1, -1]
// [0, -1, -1, 1], [-1], [-1], [-1], [-1] -> [-1, -1, -1, 1], [1], [1], [1], [1], [1]
Shouldn't this annotation be left unchanged?
Thanks for catching this; I will revert my changes in the next commit.
const std::string data_format,
const bool use_global_stats,
const bool trainable_statistics) {
return BatchNormInferSpmdBase(x, mean, variance, scale, bias); |
Do we need the `data_format` parameter in `BatchNormInferSpmdBase`? If a user passes `data_format="NHWC"` or `"NLC"`, will the result still be correct?
Thanks, I have handled all the `data_format` cases in the new commit.
const bool is_test,
const bool use_global_stats,
const bool trainable_statistics) {
return BatchNormGradInferSpmdBase(x, |
Same `data_format` issue as in `BatchNormInferSpmdBase`.
Thanks, I have handled all the `data_format` cases in the new commit.
paddle/phi/ops/yaml/ops.yaml
Outdated
@@ -5056,6 +5056,7 @@
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta :
  func : BatchNormInferMeta
spmd_rule : SyncBatchNormInferSpmd |
The `sync_batch_norm_` operator is used for manual parallelism, and its implementation includes communication rather than being a pure computation operator. Should it have an spmd rule?
Thanks, I think you are right. With `sync_batch_norm_`, different GPUs hold different batches, so their per-device mean and variance must be combined via communication, and the tensors cannot be sharded. I will remove the spmd_rule for `sync_batch_norm_` in the next commit.
VLOG(4) << "Einsum Notation: " << x_axes << "," << mean_axes << ","
        << variance_axes << "," << scale_axes << "," << bias_axes << "-->"
        << out_axes << "," << mean_axes << "," << variance_axes;
VLOG(4) << "X"
        << " shape: [" << str_join(x_shape) << "] "
        << "src_dims_mapping: [" << str_join(x_dist_attr_src.dims_mapping())
        << "] "
        << "dst_dims_mapping: [" << str_join(x_dims_mapping) << "]";
VLOG(4) << "Mean"
        << " shape: [" << str_join(mean_shape) << "] "
        << "src_dims_mapping: [" << str_join(mean_dims_mapping) << "] "
        << "dst_dims_mapping: ["
        << str_join(mean_dist_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Variance"
        << " shape: [" << str_join(variance_shape) << "] "
        << "src_dims_mapping: [" << str_join(variance_dims_mapping) << "] "
        << "dst_dims_mapping: ["
        << str_join(variance_dist_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Scale"
        << " shape: [" << str_join(scale_shape) << "] "
        << "src_dims_mapping: [" << str_join(scale_dims_mapping) << "] "
        << "dst_dims_mapping: ["
        << str_join(scale_dist_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Bias"
        << " shape: [" << str_join(bias_shape) << "] "
        << "src_dims_mapping: [" << str_join(bias_dims_mapping) << "] "
        << "dst_dims_mapping: ["
        << str_join(bias_dist_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Out dims mapping: [" << str_join(out_dist_attr.dims_mapping())
        << "]";
VLOG(4) << "Mean_out dims mapping: ["
        << str_join(mean_dist_attr.dims_mapping()) << "]";
VLOG(4) << "Variance_out dims mapping: ["
        << str_join(variance_dist_attr.dims_mapping()) << "]";
VLOG(4) << "Saved_mean dims mapping: ["
        << str_join(mean_dist_attr.dims_mapping()) << "]";
VLOG(4) << "Saved_variance dims mapping: ["
        << str_join(variance_dist_attr.dims_mapping()) << "]";
VLOG(4) << "Reserve_space dims mapping: ["
        << str_join(reserve_space_dist_attr.dims_mapping()) << "]";
VLOG(4) << std::endl; |
Shall we use the macro `LOG_SPMD_INPUT` or `LOG_SPMD_OUTPUT` to simplify the logging code?
Thanks, I will use them to simplify the logging code in the next commit.
VLOG(4) << "Einsum Notation: " << x_axes << scale_axes << "," << bias_axes
        << "," << mean_out_axes << "," << variance_out_axes << ","
        << saved_mean_axes << "," << saved_variance_axes << ","
        << "-->" << reserve_space_axes << "," << out_grad_axes;
VLOG(4) << "X"
        << " shape: [" << str_join(x_shape) << "] "
        << "src_dims_mapping: [" << str_join(x_dist_attr_src.dims_mapping())
        << "] "
        << "dst_dims_mapping: [" << str_join(x_dims_mapping) << "]";
VLOG(4) << "Mean_out"
        << " shape: [" << str_join(mean_out_shape) << "] "
        << "src_dims_mapping: ["
        << str_join(mean_out.dist_attr().dims_mapping()) << "] "
        << "dst_dims_mapping: [" << str_join(mean_out_attr_dst.dims_mapping())
        << "]";
VLOG(4) << "Variance_out"
        << " shape: [" << str_join(variance_out_shape) << "] "
        << "src_dims_mapping: ["
        << str_join(variance_out.dist_attr().dims_mapping()) << "] "
        << "dst_dims_mapping: ["
        << str_join(variance_out_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Scale"
        << " shape: [" << str_join(scale_shape) << "] "
        << "src_dims_mapping: [" << str_join(scale.dist_attr().dims_mapping())
        << "] "
        << "dst_dims_mapping: [" << str_join(scale_attr_dst.dims_mapping())
        << "]";
VLOG(4) << "Bias"
        << " shape: [" << str_join(bias_shape) << "] "
        << "src_dims_mapping: [" << str_join(bias.dist_attr().dims_mapping())
        << "] "
        << "dst_dims_mapping: [" << str_join(bias_attr_dst.dims_mapping())
        << "]";
<< "]"; |
Shall we use the macro `LOG_SPMD_INPUT` or `LOG_SPMD_OUTPUT` to simplify the logging code?
Thanks, I will use them to simplify the logging code in the next commit.
LGTM
PR Category
Auto Parallel
PR Types
New features
Description
No.13 sync_batch_norm