[Auto Parallel] Add unshard_dtensor api #60272
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
if dist_tensor.is_dist() is False:
    return dist_tensor
Would it be better to raise an error directly here?
Done
replicate_placements = [dist.Replicate()] * len(placements)

r_dist_tensor = reshard(dist_tensor, mesh, replicate_placements)
return paddle.Tensor(r_dist_tensor._local_value())
If dist_tensor is an EagerParam, what should the behavior be here?
Added handling for EagerParam.
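For reference, here is a minimal sketch (not the PR's exact code) of how the dynamic-mode branch with the EagerParam case might look. `EagerParamBase.from_tensor` and the `trainable` keyword are assumptions and may differ from the actual implementation:

```python
import paddle
import paddle.distributed as dist
from paddle.base.framework import EagerParamBase


def _unshard_dynamic(dist_tensor):
    # Step 1: reshard to Replicate() on every mesh dimension, then take the
    # local value, which now holds the full tensor on each rank.
    mesh = dist_tensor.process_mesh
    replicate_placements = [dist.Replicate()] * len(dist_tensor.placements)
    r_dist_tensor = dist.reshard(dist_tensor, mesh, replicate_placements)
    dense_tensor = paddle.Tensor(r_dist_tensor._local_value())

    # Step 2: if the input was a parameter (EagerParam), wrap the dense value
    # back into a Parameter so optimizers still see a trainable parameter.
    # from_tensor / trainable are assumed helpers here, not confirmed by the PR.
    if isinstance(dist_tensor, EagerParamBase):
        return EagerParamBase.from_tensor(
            dense_tensor, trainable=dist_tensor.trainable
        )
    return dense_tensor
```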
# remove the distributed tensor from dist_context
default_dist_ctx = get_default_distributed_context()
serial_tensor_id = dist_tensor.desc.original_id()
default_dist_ctx._dist_tensors_for_program.pop(serial_tensor_id, None)
This modification would lead to a segmentation fault if the user tries to print a program with dist attrs!
We should also adapt the str function of dist_tensor (the static-mode object), which assumes that every node in the program has a dist_attr in dist_context.
It's ok to print the variable after removing it from dist_context. I tested it with the following sample code:
ori_tensor = paddle.static.data(
    name="input",
    shape=[4, 1024, 512],
    dtype='float32',
)
d_tensor = dist.shard_tensor(ori_tensor, self.mesh, [Shard(0)])

default_dist_context = get_default_distributed_context()
dense_tensor = dist.unshard_dtensor(d_tensor)
print(dense_tensor)  # output dense var
print(d_tensor)  # output dense var
When dist_tensor is not in dist_context, the str function will not add the distributed attribute:
dist_context = get_default_distributed_context()
dist_tensor = dist_context.get_dist_tensor_for_program(self)
if dist_tensor is not None:
    var_str += ", {name} = {value}".format(
        name="dist_attr", value=dist_tensor
    )
return var_str
def unshard_dtensor(dist_tensor):
    """
    Converts a distributed tensor to its original dense tensor. ``unshard_dtensor``
It seems "original dense tensor" is a little bit confusing. Actually, it does two things:
- Make it Replicated on all mesh dims
- Convert it to paddle.Tensor
Done, modified the doc.
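A short hypothetical usage example illustrating both effects described above (the mesh and shapes are illustrative; run under `paddle.distributed.launch` with two processes):

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x = paddle.rand([4, 8])
d_x = dist.shard_tensor(x, mesh, [dist.Shard(0)])  # each rank holds a [2, 8] shard

dense_x = dist.unshard_dtensor(d_x)
print(dense_x.is_dist())  # False: converted back to a plain paddle.Tensor
print(dense_x.shape)      # [4, 8]: full shape, replicated on every rank
```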
dist_tensor.dist_attr = empty_dist_attr

# remove the distributed tensor from dist_context
default_dist_ctx = get_default_distributed_context()
Is the dist_tensor always in the default context?
shard_tensor will add the dist_tensor to the default context. If it's not in the default context, default_dist_ctx._dist_tensors_for_program.pop(serial_tensor_id, None) will do nothing.
def unshard_dtensor(dist_tensor):
    """
    Converts a distributed tensor to its original dense tensor. ``unshard_dtensor``
    can be treated as a reverse operation of ``shard_tensor``.
If they are paired APIs, unshard_tensor would be more suitable.
Only a distributed tensor (dtensor) can be unsharded, and only a dense tensor can be sharded with the shard_tensor API. So it is named unshard_dtensor here.
LGTM
LGTM
LGTM
LGTM for API
LGTM
LGTM
LGTM
* add dynamic part of unshard
* add unshard_dtensor api
* handle Parameter type in unshard_dtensor
PR types
Function optimization
PR changes
APIs
Description
Pcard-76459
Add the unshard_dtensor API. unshard_dtensor converts a distributed tensor to its original dense tensor.
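A hedged example of the Parameter case mentioned in the commit list (the layer, mesh, and placements are illustrative; given the EagerParam handling discussed above, the result is expected to stay a Parameter). Run under `paddle.distributed.launch` with two processes:

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
linear = paddle.nn.Linear(16, 32)

# Shard the weight along its column dimension, then convert it back.
w_dist = dist.shard_tensor(linear.weight, mesh, [dist.Shard(1)])
w_dense = dist.unshard_dtensor(w_dist)

print(type(w_dense))   # expected to remain a Parameter (EagerParam handling)
print(w_dense.shape)   # [16, 32]: full, replicated shape
```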