[Auto Parallel] fix save load state_dict #66266
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open source project!
@@ -437,6 +439,15 @@ def load_state_dict(
rank_to_files, missing_keys = get_rank_to_files(
    path, flat_state_dict, process_group, use_dist
)

gloabl_rank_to_files = []
get_rank_to_files already derives its result from global_data_files; do we still need all_gather_object(gloabl_rank_to_files) here?
all_gather is not needed here, but get_rank_to_files needs some changes. The state_dict differs between dynamic and static semi-auto parallel: under static semi-auto, the state_dict is partial, so the derived necessary_files is also partial, and therefore an all_gather must be performed on necessary_files.
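A minimal sketch of the merge described above: each rank holds only a partial necessary_files list, and the union of all ranks' lists gives the global set. The gather is simulated in-process here; the real code would use paddle.distributed.all_gather_object, and the file names are hypothetical.

```python
# Hedged sketch: under static semi-auto parallel each rank's state_dict
# is partial, so the necessary_files it derives are partial too. After
# an all_gather of those lists, every rank takes the union.

def merge_necessary_files(partial_lists):
    """Union per-rank partial file lists into one sorted global list."""
    merged = set()
    for files in partial_lists:  # one entry per rank after the gather
        merged.update(files)
    return sorted(merged)

# Simulated per-rank results (hypothetical checkpoint file names).
rank0_files = ["0_0.distcp"]
rank1_files = ["1_0.distcp"]
print(merge_necessary_files([rank0_files, rank1_files]))
```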
LGTM
@@ -549,9 +566,11 @@ def load_state_dict(
    storage_chunk_tensor, src=src_rank, group=process_group
)
else:
    tmp_tensor = paddle.assign(cur_chunk_tensor)
Why use tmp_tensor? Please add a comment.
The memory held by cur_chunk_tensor may be non-contiguous, and the broadcast API does not support this type of tensor.
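A hedged illustration of the contiguity issue, using NumPy rather than Paddle: a strided view such as a column slice is non-contiguous in memory, and collectives that expect a flat buffer need a contiguous copy first. Here np.ascontiguousarray plays the role that paddle.assign plays in the diff above.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)
col = a[:, 1]                     # strided view: non-contiguous memory

# A copy laid out contiguously is safe to hand to a broadcast-style API.
tmp = np.ascontiguousarray(col)

print(col.flags["C_CONTIGUOUS"])  # False
print(tmp.flags["C_CONTIGUOUS"])  # True
```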
)

if src_rank == item.rank:
    # assign value locally
    paddle.assign(storage_chunk_tensor, cur_chunk_tensor)
if src_rank == paddle.distributed.get_rank():
What about the else branch?
The condition src_rank == item.rank will be satisfied by all ranks, but only one rank needs to perform the assignment operation. Additionally, when src_rank != paddle.distributed.get_rank(), storage_chunk_tensor may be None.
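A minimal sketch of the guard being discussed: every rank sees src_rank == item.rank, but only the source rank itself holds storage_chunk_tensor, so the local assignment must also check the current rank. Ranks are simulated here; the real code uses paddle.distributed.get_rank(), and the helper name is hypothetical.

```python
def maybe_assign_locally(src_rank, item_rank, cur_rank, storage_chunk):
    """Return the locally assigned chunk, or None if this rank
    should wait for the broadcast instead."""
    if src_rank == item_rank and src_rank == cur_rank:
        # Only the source rank assigns; on other ranks
        # storage_chunk may be None.
        return storage_chunk
    return None

# Source rank holds data; a non-source rank holds None.
print(maybe_assign_locally(0, 0, 0, [1.0, 2.0]))  # assigns locally
print(maybe_assign_locally(0, 0, 1, None))        # defers to broadcast
```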
We’ll add comments in the next PR.
@@ -24,6 +24,7 @@ class LocalTensorMetadata:

global_offset: Tuple[int]
local_shape: Tuple[int]
dtype: str
Why do we need dtype?
In static mode, the state_dict is incomplete, so the previous code would trigger a KeyError. Here, the dtype is stored in advance.
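A hedged sketch of the metadata change above: recording dtype alongside the shard layout lets a rank with only a partial state_dict know each tensor's dtype without looking the key up, avoiding the KeyError. The field names mirror the diff; the rest of the class body is assumed.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class LocalTensorMetadata:
    global_offset: Tuple[int, ...]
    local_shape: Tuple[int, ...]
    dtype: str  # stored in advance for static-mode partial state_dicts

meta = LocalTensorMetadata(
    global_offset=(0, 0), local_shape=(2, 4), dtype="float32"
)
print(meta.dtype)  # available without touching the (partial) state_dict
```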
PR Category
Auto Parallel
PR Types
Others
Description
Fix save/load of state_dict.