[DistDialect] add python reshard pass in pir #63362
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
Force-pushed from 88f0cc5 to f31319f
Force-pushed from f31319f to 5d6cc33
@@ -30,10 +30,14 @@ namespace dialect {

 pir::Value shard_tensor(const pir::Value& x,
                         const phi::distributed::ProcessMesh& process_mesh,
-                        const std::vector<int64_t>& dims_mapping) {
+                        const std::vector<int64_t>& dims_mapping,
+                        const std::vector<int64_t>& partial_dims) {
Wouldn't it be better for shard_tensor's parameters to be consistent with the reshard API's? One takes dims_mapping + partial_dims while the other takes placements, which feels inconsistent.
The shard_tensor API in api.py also seems to call this API under the hood. Since this API's interface has changed, the call site in api.py needs to be adapted as well, right? I don't see api.py modified in this PR.
> The shard_tensor API in api.py also seems to call this API under the hood. Since this API's interface has changed, the call site in api.py needs to be adapted as well, right? I don't see api.py modified in this PR.
The shard_tensor parameters have been changed to match reshard, thanks!
Previously a hack and the simplest possible change were used to get the flow working quickly; in the hacked version the last parameter is optional, so the call in api.py does not need to be adapted.
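For context, here is a minimal, self-contained sketch of how a placements-style argument can be lowered to the dims_mapping + partial_dims pair that the C++ shard_tensor above expects. The Placement classes and the helper below are illustrative stand-ins, not Paddle's actual implementation.

class Replicate:      # tensor is fully replicated along this mesh dim
    pass

class Shard:          # tensor dim `dim` is split along this mesh dim
    def __init__(self, dim):
        self.dim = dim

class Partial:        # values along this mesh dim are partial and need a reduce
    pass

def placements_to_dist_attr(placements, tensor_ndim):
    # dims_mapping[i] = mesh dim that shards tensor dim i, or -1 if unsharded
    dims_mapping = [-1] * tensor_ndim
    partial_dims = []
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            dims_mapping[placement.dim] = mesh_dim
        elif isinstance(placement, Partial):
            partial_dims.append(mesh_dim)
    return dims_mapping, partial_dims

# a 2-D tensor on a 2-D mesh: sharded on mesh dim 0, partial on mesh dim 1
print(placements_to_dist_attr([Shard(0), Partial()], tensor_ndim=2))  # ([0, -1], [1])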
pir::IrContext* ctx = pir::IrContext::Instance();
// support amp for shard_tensor in the future
paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status;
for (size_t i = 0; i < partial_dims.size(); ++i) {
  partial_status[partial_dims[i]] = phi::ReduceType::kRedSum;
Isn't hard-coding phi::ReduceType::kRedSum here a bit hacky?
It seems fine; in distributed training the vast majority of cases use sum, and users whose reduction is not sum can set it themselves.
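A small sketch of that policy, mirroring the C++ loop above: every partial mesh dim defaults to a sum reduction, and callers with a different reduce type override individual entries afterwards. The ReduceType enum and helper are illustrative only.

from enum import Enum

class ReduceType(Enum):
    SUM = 0
    MAX = 1

def build_partial_status(partial_dims, overrides=None):
    # default every partial mesh dim to SUM, matching the hard-coded kRedSum above
    partial_status = {d: ReduceType.SUM for d in partial_dims}
    # callers whose reduction is not sum override the relevant entries themselves
    partial_status.update(overrides or {})
    return partial_status

status = build_partial_status([0, 1], overrides={1: ReduceType.MAX})
print(status)  # mesh dim 0 defaults to SUM, mesh dim 1 is overridden to MAX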
test/auto_parallel/reshard_p_to_r.py (Outdated)
    initializer=paddle.nn.initializer.Uniform(),
)

shard_tensor = paddle._pir_ops.shard_tensor(
Suggested change:
- shard_tensor = paddle._pir_ops.shard_tensor(
+ shard_tensor = paddle._C_ops.shard_tensor(
Done, thx!
Force-pushed from 008c092 to 9e547dc
op_target_dist_attr = op.attrs()[
    "op_dist_attr"
].result_dist_attr(0)
reshard_func = choose_reshard_func(
src_dist_attr and dst_dist_attr alone are not enough to translate every reshard case; extra information such as cur_rank is also needed. For example, with the same src={replicated(), mesh=[0]} and dst={replicated(), mesh=[1]},
the reshard is translated to a send if cur_rank=0 and to a recv if cur_rank=1.
Should cur_rank be passed into choose, or obtained internally from a global environment variable to stay aligned with the dynamic graph?
Please design the signature of choose_reshard_func with reuse in mind: if the dynamic-graph semi-auto-parallel reshard logic is also moved to the Python side later, how will both sides share this one interface?
> src_dist_attr and dst_dist_attr alone are not enough to translate every reshard case; extra information such as cur_rank is also needed. For example, with the same src={replicated(), mesh=[0]} and dst={replicated(), mesh=[1]}, the reshard is translated to a send if cur_rank=0 and to a recv if cur_rank=1.
> Should cur_rank be passed into choose, or obtained internally from a global environment variable to stay aligned with the dynamic graph?
Information like cur_rank is configured internally by launch for static-graph execution and is currently obtained in a way that is already unified between dynamic and static graphs, so fetching it inside reshard is more appropriate. Also, the reshard functions are subclasses of a common base, and not every reshard needs this parameter.
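A minimal sketch of the case being discussed, assuming a hypothetical replicated-to-replicated cross-mesh rule that fetches the current rank internally rather than taking it as a parameter; the class name, the dict-based dist attrs, and the environment-variable lookup are assumptions for illustration, not the PR's actual code.

import os

def get_cur_rank():
    # assumption: the launcher exports the rank via PADDLE_TRAINER_ID
    return int(os.environ.get("PADDLE_TRAINER_ID", "0"))

class RToRCrossMeshSketch:
    def is_suitable(self, src, dst):
        # both sides replicated, but living on different single-rank meshes
        return (src["placements"] == ["replicated"]
                and dst["placements"] == ["replicated"]
                and src["mesh"] != dst["mesh"])

    def reshard(self, src, dst):
        cur_rank = get_cur_rank()
        if cur_rank in src["mesh"]:
            return "send"   # the source rank emits a send op
        if cur_rank in dst["mesh"]:
            return "recv"   # the destination rank emits a recv op
        return None         # this rank does not participate in the transfer

# src={replicated(), mesh=[0]}, dst={replicated(), mesh=[1]}:
# rank 0 lowers the reshard to a send, rank 1 to a recv
rule = RToRCrossMeshSketch()
src = {"placements": ["replicated"], "mesh": [0]}
dst = {"placements": ["replicated"], "mesh": [1]}
print(rule.is_suitable(src, dst), rule.reshard(src, dst))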
> Please design the signature of choose_reshard_func with reuse in mind: if the dynamic-graph semi-auto-parallel reshard logic is also moved to the Python side later, how will both sides share this one interface?
After a preliminary check, reusing the current Python interfaces and functions should basically be fine.
    "op_dist_attr"
].result_dist_attr(0)
reshard_func = choose_reshard_func(
    op_operand_dist_attr, op_target_dist_attr
Align the naming with PIR: operand pairs with result.
    op_operand_dist_attr, op_target_dist_attr
)
reshard_func.reshard(
    new_program, op, op_operand_dist_attr, op_target_dist_attr
The arguments feel a bit redundant; op should already contain the op_operand_dist_attr and op_target_dist_attr information.
> The arguments feel a bit redundant; op should already contain the op_operand_dist_attr and op_target_dist_attr information.
program and op are static-graph-only concepts; if only op were passed, the reshard interface could not be reused uniformly across dynamic and static graphs.
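A sketch of the signature trade-off, assuming a hypothetical base class whose reshard() takes the dist attrs explicitly so that a dynamic-graph caller could pass None for the static-graph-only program/op arguments. Illustrative only, not the PR's actual class hierarchy.

class ReshardFunctionSketch:
    def is_suitable(self, src_dist_attr, dst_dist_attr):
        raise NotImplementedError

    def reshard(self, program, op, src_dist_attr, dst_dist_attr):
        # program/op carry the static-graph context (insertion point, op rewriting);
        # a dynamic-graph caller could pass program=None, op=None and work purely
        # from the dist attrs, which is why they stay explicit arguments here.
        raise NotImplementedError

class SameStatusReshardSketch(ReshardFunctionSketch):
    def is_suitable(self, src_dist_attr, dst_dist_attr):
        return src_dist_attr == dst_dist_attr

    def reshard(self, program, op, src_dist_attr, dst_dist_attr):
        # nothing to do: source and destination already agree on every rank
        return op

rule = SameStatusReshardSketch()
print(rule.is_suitable({"mesh": [0, 1]}, {"mesh": [0, 1]}))  # True in both modes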
    return True

def reshard(
    self, program, op, src_dist_attr, dst_dist_attr, remove_op=True
In what scenario is the op not removed?
> In what scenario is the op not removed?
Here we need to distinguish whether the op is a reshard op, so remove_op has been renamed to reshard_op. Reshard functions for different rules call each other in a nested way, so we must know whether the current op is a reshard op: a reshard op needs to be removed, and therefore the insertion point relative to the op (before vs. after) differs.
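A rough sketch of that control flow; the helper below only prints what the real pass would do with PIR insertion points and op erasure, and the names are hypothetical.

def run_reshard_rule(op_name, build_comm_ops, reshard_op=True):
    # stand-ins: the real pass calls paddle.pir insertion-point helpers and erases ops
    if reshard_op:
        # lowering the dist_op.reshard op itself: build the new ops in its place
        # and erase it once they are wired up
        print(f"set insertion point before {op_name}")
        build_comm_ops()
        print(f"erase {op_name}")
    else:
        # nested call from another rule on a regular op: keep that op and append
        # the communication ops right after it
        print(f"set insertion point after {op_name}")
        build_comm_ops()

run_reshard_rule("dist_op.reshard", lambda: print("  build c_allreduce_sum_"), True)
run_reshard_rule("pd_op.matmul", lambda: print("  build c_allreduce_sum_"), False)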
paddle.pir.set_insertion_point_after(op)
group = new_process_group(src_mesh.process_ids)
reduced_value = paddle._pir_ops.c_allreduce_sum_(
    op_value, group.id, False, False
The input of the allreduce should be the input of the reshard, not its output.
> The input of the allreduce should be the input of the reshard, not its output.
Here op_value depends on the op type: if the current op is a reshard op, op_value is the reshard input; otherwise it is the op's output. This check is needed because the insertion point differs depending on whether the op is a reshard op (see the explanation of the reshard_op parameter above).
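And a tiny sketch of the op_value choice described above (the dict-based op is hypothetical, for illustration only):

def pick_collective_input(op, reshard_op=True):
    # for the reshard op itself the collective consumes the reshard input;
    # for a regular op reached through a nested rule it consumes the op's output
    return op["input"] if reshard_op else op["output"]

op = {"name": "dist_op.reshard", "input": "x_partial", "output": "x_replicated"}
print(pick_collective_input(op, reshard_op=True))   # x_partial
print(pick_collective_input(op, reshard_op=False))  # x_replicated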
'builtin.parameter',
'pd_op.data',
'dist_op.shard_tensor',
'pd_op.c_allreduce_sum_',
We also need to check the dist attrs of the op and of its input/output tensors.
> We also need to check the dist attrs of the op and of its input/output tensors.
Added.
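Roughly the shape such a check could take in the unit test, written here against mocked ops so the snippet is self-contained; the real test would iterate dist_program.global_block().ops and inspect op.dist_attr as in the snippets elsewhere in this thread.

from types import SimpleNamespace

# mocked ops standing in for dist_program.global_block().ops after the pass
ops = [
    SimpleNamespace(name='builtin.parameter', dist_attr=None),
    SimpleNamespace(name='pd_op.data', dist_attr=None),
    SimpleNamespace(name='dist_op.shard_tensor', dist_attr=SimpleNamespace(mesh=[0, 1])),
    SimpleNamespace(name='pd_op.c_allreduce_sum_', dist_attr=SimpleNamespace(mesh=[0, 1])),
]

expected = ['builtin.parameter', 'pd_op.data',
            'dist_op.shard_tensor', 'pd_op.c_allreduce_sum_']
assert [op.name for op in ops] == expected

# beyond the op names, the inserted collective must also carry a dist attr
for op in ops:
    if op.name == 'pd_op.c_allreduce_sum_':
        assert op.dist_attr is not None and op.dist_attr.mesh == [0, 1]
print("checks passed")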
HIDDEN_SIZE = 8
MP_SIZE = 2

with paddle.pir_utils.IrGuard():
The unit test should build the network via dy2static (dynamic-to-static) by default.
> The unit test should build the network via dy2static (dynamic-to-static) by default.
A dy2static networking test has been added.
HIDDEN_SIZE = 8
MP_SIZE = 2

with paddle.pir_utils.IrGuard():
The unit test should build the network via dy2static (dynamic-to-static) by default.
> The unit test should build the network via dy2static (dynamic-to-static) by default.
Done
reshard_tensor = paddle._pir_ops.reshard(
    shard_tensor, self._out_mesh, [dist.Replicate()]
)
dist_program = apply_reshard_pass_v2(main_program)
The communication ops need to be given distributed attributes so that they can pass the dist2dense pass.
> The communication ops need to be given distributed attributes so that they can pass the dist2dense pass.
Added.
    return None


def register_reshard_func(reshard_func):
The classification and naming of the reshard_funcs could take torch and oneflow as references. (Design-wise, this version can first align with the dynamic-graph semi-auto-parallel implementation; a new design can be adopted later when dynamic and static graphs are unified.)
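For readers unfamiliar with the pattern, a self-contained sketch of a reshard-function registry and dispatcher in the spirit of register_reshard_func / choose_reshard_func; the concrete rule class and the string-based dist attrs are hypothetical.

_g_reshard_funcs = []

def register_reshard_func(reshard_func):
    _g_reshard_funcs.append(reshard_func)

def choose_reshard_func(src_dist_attr, dst_dist_attr):
    # return the first registered rule that can handle this (src, dst) pair
    for func in _g_reshard_funcs:
        if func.is_suitable(src_dist_attr, dst_dist_attr):
            return func
    return None

class PToRReshardSketch:
    """Hypothetical rule: partial -> replicated via an allreduce."""
    def is_suitable(self, src, dst):
        return src == "partial" and dst == "replicated"

    def reshard(self, src, dst):
        return "insert c_allreduce_sum_"

register_reshard_func(PToRReshardSketch())
print(choose_reshard_func("partial", "replicated").reshard("partial", "replicated"))
print(choose_reshard_func("shard", "replicated"))  # None: no rule registered yet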
LGTM
LGTM
for op in program.global_block().ops:
    if op.name() == "pd_op.send_v2":
        op.dist_attr = (
            paddle.base.libpaddle.pir.create_op_dist_attribute(
                src_mesh, [src_dist_attr], []
            )
        )
    elif op.name() == "pd_op.recv_v2":
        op.dist_attr = (
            paddle.base.libpaddle.pir.create_op_dist_attribute(
                dst_mesh, [], [dst_dist_attr]
            )
        )

return recv_value.get_defining_op(), dst_dist_attr
why do this?
> why do this?
here dst_dist_attr is actually not necessary, will fix in next pr, thx!
LGTM
PR Category: Auto Parallel
PR Types: New features
Description: Pcard-67164