refactor shard optimizer to support unification of dynamic and static… #63542
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
force-pushed from b64f931 to 7b50abf
new_placement[self._sharding_mesh_axis] = dist.Replicate()
out_param = dist.reshard(param, param.process_mesh, new_placement)
if in_pir_mode():
    paddle.assign(out_param, param)
This assign incurs a D2D copy and allocates a new chunk of device memory, so during static-graph semi-auto sharding training the same parameter lives in memory as three copies at once: old_replicated_param, new_shard_param, and new_replicated_param. That raises the peak memory usage and defeats the sharding optimizer's goal of reducing memory.
We need some mechanism like share_data_with(new_replicated_param, old_replicated_param).
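For illustration, a minimal dygraph sketch of the difference between the two approaches. `_share_buffer_to` and `_is_shared_buffer_with` are internal Paddle helpers used here only as assumed stand-ins for a buffer-sharing primitive, not necessarily the API the final fix uses:

import paddle

param = paddle.randn([1024, 1024])   # old replicated parameter
out = paddle.empty_like(param)       # buffer for the re-replicated parameter

# Copy path (what the snippet above does): paddle.assign performs a
# device-to-device copy into a separate allocation, so both buffers
# are alive at the memory peak.
paddle.assign(param, out)

# Sharing path along the lines the reviewer suggests: let `out` reuse
# `param`'s storage instead of allocating a new buffer.
# NOTE: the exact call and its direction are an assumption here.
param._share_buffer_to(out)
assert out._is_shared_buffer_with(param)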
done, thanks!
force-pushed from d630621 to 516e1e4
force-pushed from 0e14447 to e37b6f7
force-pushed from 295eb07 to 5fbc3cb
def get_value_placement(value):
    dist_attr = value.dist_attr()
    assert dist_attr is not None, "Can't get placement for a dense value."
    mesh = dist_attr.process_mesh
    dims_mapping = dist_attr.dims_mapping
    partial_status = dist_attr.partial_status
    placements = to_placements(dims_mapping, mesh, partial_status)
    return placements
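For context, a small self-contained sketch of the dims_mapping-to-placements convention this helper relies on; the mesh and mapping values are illustrative, not taken from the PR:

import paddle.distributed as dist

# A 2-D process mesh: axis 0 = data parallel, axis 1 = model parallel.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# dims_mapping has one entry per *tensor* dim: the mesh axis the dim
# is split along, or -1 if it is not split. Here a 2-D weight is
# sharded on its second dim along mesh axis 1.
dims_mapping = [-1, 1]

# to_placements inverts this into one placement per *mesh* axis:
# mesh axis 0 shards nothing -> Replicate(); mesh axis 1 shards
# tensor dim 1 -> Shard(1).
expected = [dist.Replicate(), dist.Shard(1)]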
We can add placement as a property of Value in the future to make dynamic and static modes consistent.
done, thanks!
force-pushed from 5959bbe to 4c7f04f
    self._shard_fn, (ShardingStage1, ShardingStage2)
):
    # In PIR mode, the reshard pass automatically handles the inplace case, so no extra work is required here.
    if not isinstance(param, pir.Value):
PIR handles the inplace case automatically, but doesn't reshard still need to be called manually?
The reshard calls are already handled in reshard_pass; they have all been hoisted to the very front.
LGTM
LGTM for python/paddle/distributed/__init__.py
PR Category
Auto Parallel
PR Types
New features
Description
Other
pcard-67164