[AutoParallel] Add Global mesh and sub mesh reshard function #61796
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
    const TensorDistAttr& out_dist_attr,
    DistTensor* out) override;

std::string Name() override { return "GlobalToSubPPMeshReshardFunction"; }
sub to global
Done
const ProcessMesh& in_process_mesh = in_dist_attr.process_mesh();
const ProcessMesh& out_process_mesh = out_dist_attr.process_mesh();

RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() ==
Is this condition needed?
For now we plan to support only the case where the global mesh has one more dimension than the sub mesh. That covers the cases encountered so far; if more needs come up later, we will add new functionality.
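For illustration, a minimal Python sketch of that assumption (hypothetical mesh values, assuming the paddle.distributed.ProcessMesh API): a 2-D global mesh whose leading "pp" dimension indexes two 1-D sub meshes.

import paddle.distributed as dist

# Hypothetical 2-D global mesh: one more dimension than each sub mesh,
# matching the in_process_mesh.ndim() == out_process_mesh.ndim() + 1 check.
global_mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "dp"])

# The two sub meshes obtained by fixing the leading "pp" index.
sub_mesh0 = dist.ProcessMesh([0, 1], dim_names=["dp"])
sub_mesh1 = dist.ProcessMesh([2, 3], dim_names=["dp"])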
bool SubPPMeshToGlobalReshardFunction::IsSuitable(
    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
  RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated(0));
Is this condition needed?
Same as above.
if (IsCurRankInMesh(in_process_mesh)) {
  const DenseTensor& in_dense_value = in.value();
  std::vector<int64_t>& recv_vec = send2recv_map[cur_global_rank];
  for (int64_t recv_id : recv_vec) {
Use broadcast to reduce the number of communications?
OK, I'll give it a try, thanks.
After checking, broadcast cannot be used here. The reason is that the receiving rank has no way to obtain the shape of the send tensor (the send tensor lives on another rank); to use broadcast, the shape information would have to be sent over first. So for now this keeps using the send/recv ops to get the flow working end to end.
Thinking about it some more: later we could probably derive the local shape from the global shape plus the sharding info, which would avoid sending the shape information.
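A rough sketch of that idea in plain Python (a hypothetical helper, not this PR's implementation): given the global shape, the dims_mapping, and the mesh shape, the local shape can be derived on the receiving rank without any shape exchange, assuming each sharded dimension divides evenly.

def infer_local_shape(global_shape, dims_mapping, mesh_shape):
    # dims_mapping[i] == -1 means tensor dim i is replicated; otherwise it
    # is split across mesh dimension dims_mapping[i]. Assumes even splits.
    local_shape = list(global_shape)
    for i, mesh_dim in enumerate(dims_mapping):
        if mesh_dim != -1:
            local_shape[i] = global_shape[i] // mesh_shape[mesh_dim]
    return local_shape

# A [8, 16] tensor split along dim 0 over mesh dim 0 of a [2, 2] mesh
# gives a [4, 16] local shard on each receiving rank.
print(infer_local_shape([8, 16], [0, -1], [2, 2]))  # [4, 16]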
)  # 1.0
global_input.stop_gradient = False
# forward on mesh0
input_mesh0 = dist.reshard(
What if reshard is not called?
Currently, not calling it raises an error, because dynamic semi-auto parallel requires all input tensors of an operator to have the same mesh during computation. Later we plan to hide this reshard call for the sub mesh case, by invoking reshard_function inside ConvertAllInputsToDistTensor.
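For context, a hedged sketch of the explicit reshard that is currently required (Python; mesh values and placements are illustrative, assuming the paddle.distributed shard_tensor/reshard API):

import paddle
import paddle.distributed as dist

# Hypothetical global mesh and the sub mesh of pp stage 0.
global_mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "dp"])
mesh0 = dist.ProcessMesh([0, 1], dim_names=["dp"])

global_input = dist.shard_tensor(
    paddle.ones([4, 8]), global_mesh, [dist.Replicate(), dist.Replicate()]
)

# Without this reshard, an op on mesh0 raises an error, because dynamic
# semi-auto parallel requires all inputs of an op to share the same mesh.
input_mesh0 = dist.reshard(global_input, mesh0, [dist.Replicate()])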
@@ -14,4 +14,5 @@ collect_srcs(
  r_to_x_reshard_function.cc
  nd_mesh_reshard_function.cc
  same_status_reshard_function.cc
  global_and_sub_pp_mesh_reshard_function.cc
For the name of this reshard, might it be better not to introduce the pp concept? Reshard is generic and has no notion of a parallel strategy.
OK.
std::vector<ProcessMesh> GetSubPPMesh(const ProcessMesh& process_mesh) {
  const std::vector<int64_t>& shape = process_mesh.shape();
  const std::vector<int64_t>& process_ids = process_mesh.process_ids();
  const std::vector<std::string>& dim_names = process_mesh.dim_names();
  int64_t total_process_num = process_ids.size();
  int64_t sub_process_num = total_process_num / shape[0];
  std::vector<int64_t> sub_process_mesh_shape(shape.begin() + 1, shape.end());
  std::vector<std::string> sub_process_mesh_dim_names(dim_names.begin() + 1,
                                                      dim_names.end());

  std::vector<ProcessMesh> sub_process_meshes;
  for (int i = 0; i < shape[0]; ++i) {
    int64_t start_position = i * sub_process_num;
    int64_t end_position = start_position + sub_process_num;
    std::vector<int64_t> sub_process_ids(process_ids.begin() + start_position,
                                         process_ids.begin() + end_position);

    sub_process_meshes.emplace_back(ProcessMesh(
        sub_process_mesh_shape, sub_process_ids, sub_process_mesh_dim_names));
  }
  return sub_process_meshes;
}

bool is_sub_mesh(const ProcessMesh& global_mesh, const ProcessMesh& sub_mesh) {
  std::vector<ProcessMesh> sub_process_meshes = GetSubPPMesh(global_mesh);
  for (const ProcessMesh& mesh : sub_process_meshes) {
    if (mesh == sub_mesh) {
      return true;
    }
  }
  return false;
}
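For reference, a minimal Python sketch of what GetSubPPMesh computes, using hypothetical values (shape and process ids only; not part of the PR):

def get_sub_pp_meshes(shape, process_ids):
    # Split the global mesh along its leading dimension, mirroring GetSubPPMesh.
    sub_num = len(process_ids) // shape[0]
    return [
        (shape[1:], process_ids[i * sub_num:(i + 1) * sub_num])
        for i in range(shape[0])
    ]

# A [2, 2] global mesh over ranks [0, 1, 2, 3] yields two [2] sub meshes.
print(get_sub_pp_meshes([2, 2], [0, 1, 2, 3]))
# [([2], [0, 1]), ([2], [2, 3])]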
A generic helper like this could perhaps go into reshard_utils.
OK.
RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() ==
                          out_process_mesh.ndim() + 1);
Can this generically support conversion from a 3-D mesh to a 1-D mesh?
For now only the one-dimension up/down case is supported, which covers the cases encountered so far; more functionality can be added later based on actual needs.
RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.dims_mapping() !=
                          out_dist_attr.dims_mapping());
Shouldn't this check the mesh rather than dims_mapping? Essentially, is the intent to keep cross-mesh same-status reshards from entering the s_to_s function?
Let me confirm.
After checking: before this fix, the s_to_s_cross_mesh reshard_function already checks that the meshes differ, so a check that the sharding states differ needs to be added here to keep such cases from entering this function.
LGTM
PR types
New features
PR changes
Others
Description
Pcard-73145
Add a reshard function between the global mesh and pp sub meshes, to support global operator inputs for pp in the dynamic semi-auto parallel scenario.
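A hedged end-to-end usage sketch of the feature (Python; mesh/rank values and the simple compute are illustrative, assuming the paddle.distributed API):

import paddle
import paddle.distributed as dist

# Hypothetical global mesh covering both pp stages, plus its two sub meshes.
global_mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "dp"])
mesh0 = dist.ProcessMesh([0, 1], dim_names=["dp"])
mesh1 = dist.ProcessMesh([2, 3], dim_names=["dp"])

# A replicated global operator input on the global mesh.
global_input = dist.shard_tensor(
    paddle.ones([4, 8]), global_mesh, [dist.Replicate(), dist.Replicate()]
)
global_input.stop_gradient = False

# Global-to-sub-mesh reshard: stage 0 consumes the input on mesh0, then the
# intermediate result is resharded onto mesh1 for stage 1.
input_mesh0 = dist.reshard(global_input, mesh0, [dist.Replicate()])
out_mesh0 = input_mesh0 * 2
input_mesh1 = dist.reshard(out_mesh0, mesh1, [dist.Replicate()])

# Sub-to-global reshard: bring the final result back onto the global mesh.
out_global = dist.reshard(
    input_mesh1 + 1, global_mesh, [dist.Replicate(), dist.Replicate()]
)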