[Auto Parallel] Add spmd rule No.10 for index_select and index_select_grad ops. #72727
Conversation
/re-run all-failed
@Yeenyeong This PR is ready to be reviewed, thanks!
std::unordered_map<std::string, int64_t> axis_to_dim_map =
    ShardingMergeForTensors(
        {{x_axes, x_dims_mapping_dst}, {index_axes, index_dims_mapping_dst}});
Consider:
process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])
x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]])
index = paddle.to_tensor([0, 1, 1, 2], dtype='int32')
# if x_dims_mapping_src is [-1, 0] and index_dims_mapping_src is [0], what results should be obtained?
# if x_dims_mapping_src is [-1, 1] and index_dims_mapping_src is [0], what results should be obtained?
out1 = paddle.index_select(x=x, index=index, axis=0)
And we need to add relevant unit tests
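To make the conflict in the second question concrete, here is a minimal pure-Python sketch of the kind of merge `ShardingMergeForTensors` performs (the function name and einsum-style axis labels below are hypothetical, not the actual C++ implementation): when `x_dims_mapping_src` is `[-1, 0]` and `index_dims_mapping_src` is `[0]`, two different tensor axes claim mesh dim 0, so the merge cannot succeed without resharding one input.

```python
def merge_axis_mappings(pairs):
    """Merge per-tensor (axes, dims_mapping) pairs into one axis -> mesh-dim
    map, reporting conflicts. A sketch only: the real rule reshards an input
    instead of raising."""
    axis_to_dim = {}
    for axes, mapping in pairs:
        for axis, dim in zip(axes, mapping):
            if dim == -1:
                continue  # replicated along this axis
            prev = axis_to_dim.get(axis, -1)
            if prev not in (-1, dim):
                raise ValueError(f"axis {axis!r} sharded on mesh dims {prev} and {dim}")
            axis_to_dim[axis] = dim
    # one mesh dim cannot shard two different tensor axes at the same time
    owner = {}
    for axis, dim in axis_to_dim.items():
        if dim in owner:
            raise ValueError(f"mesh dim {dim} claimed by axes {owner[dim]!r} and {axis!r}")
        owner[dim] = axis
    return axis_to_dim

# x labeled "ij", index labeled "k" (index_select on axis 0 rewrites "i" -> "k")
# case 1: x [-1, 0], index [0] -> mesh dim 0 claimed by both "j" and "k": conflict
try:
    merge_axis_mappings([("ij", [-1, 0]), ("k", [0])])
except ValueError as e:
    print("conflict:", e)

# case 2: x [-1, 1], index [0] -> merges cleanly
print(merge_axis_mappings([("ij", [-1, 1]), ("k", [0])]))  # {'j': 1, 'k': 0}
```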
Thanks for pointing out the key points! Now, if index's shard mesh dim conflicts with x's, the rule keeps the sharding on x and replicates index.
self.assertEqual(
    inferred_output_dist_attrs[0].dims_mapping, [-1, 0, -1]
)
self.assertFalse(inferred_output_dist_attrs[0]._is_partial())
As described above, we need more unit test cases to verify correctness.
Use process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]]).
When x = paddle.arange(24).reshape([3, 4, 2]) and x_dims_mapping_src is [-1, -1, -1], x is replicated on every device.
When index = paddle.to_tensor([0, 1, 1, 2], dtype='int32') and index_dims_mapping_src is [0], index on device 0/1 is [0, 1] and index on device 2/3 is [1, 2].
According to the spmd rule implemented in this PR, when out = paddle.index_select(x=x, index=index, axis=1), out_dims_mapping is [-1, 0, -1], so out on device 0/1 is [[[0, 1], [2, 3]], [[8, 9], [10, 11]], [[16, 17], [18, 19]]] and out on device 2/3 is [[[2, 3], [4, 5]], [[10, 11], [12, 13]], [[18, 19], [20, 21]]].
Can it be partial?
It's not in partial status. I originally wrote the check with assertFalse, but it may be redundant, so I have deleted it now.
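The worked example above can be checked with a small numpy stand-in for the distributed execution (numpy replaces paddle here purely for illustration): each rank holds complete output values for its local index shard, so the output is sharded along dim 1, not partial.

```python
import numpy as np

x = np.arange(24).reshape(3, 4, 2)   # replicated on every device
index_dev01 = np.array([0, 1])       # shard of index on devices 0/1
index_dev23 = np.array([1, 2])       # shard of index on devices 2/3

# each device runs index_select on its local index shard
out_dev01 = np.take(x, index_dev01, axis=1)
out_dev23 = np.take(x, index_dev23, axis=1)

# concatenating the shards along dim 1 reproduces the global result exactly,
# with no cross-device sum needed, so the output is Shard, not Partial
full = np.take(x, np.array([0, 1, 1, 2]), axis=1)
assert np.array_equal(np.concatenate([out_dev01, out_dev23], axis=1), full)
```

The per-device values match the ones listed in the comment above, which is consistent with out_dims_mapping being [-1, 0, -1] with no partial status.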
self.assertFalse(inferred_output_dist_attrs[0]._is_partial())
def test_index_select_backward(self):
    # [-1, -1, -1], [0], [-1, 0, -1], axis=1 --> [-1, -1, -1], [0], [-1, 0, -1], [-1, -1, -1](partial on axis=1 with 0)
Consider:
process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])
x = paddle.arange(24).reshape([3, 4, 2])
index = paddle.to_tensor([0, 1, 1, 2], dtype='int32')
out = paddle.index_select(x=x, index=index, axis=1)
# if x_dims_mapping_src is [1, -1, -1], index_dims_mapping_src is [0], out_grad_dims_mapping is [1, 0, -1], what is shape of out/out_grad, what x_grad_dims_mapping should be? can it be partial?
The shape of out/out_grad is [3, 4, 2], and x_grad_dims_mapping is the same as x's, [1, -1, -1]; it should be partial along index's shard mesh dim. I reuse the forward rule to infer the dims_mapping and additionally set partial_dims.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, so if index is sharded, x_grad will be partial (along axis) and partial_dims equals the shard dim of index.
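A numpy sketch of why x_grad is partial when index is sharded (again a stand-in for the real kernels, not this PR's implementation): index_select's backward is a scatter-add, each rank accumulates only the rows named by its local index shard, and the per-rank results sum (an all-reduce over index's shard mesh dim) to the full gradient.

```python
import numpy as np

x_shape = (3, 4, 2)
full_index = np.array([0, 1, 1, 2])
out_grad = np.arange(24, dtype=float).reshape(3, 4, 2)  # grad w.r.t. out

def scatter_add_grad(idx, og_local):
    # backward of index_select(axis=1): scatter-add out_grad back into x_grad
    g = np.zeros(x_shape)
    np.add.at(g, (slice(None), idx), og_local)
    return g

# single-device reference gradient
x_grad_full = scatter_add_grad(full_index, out_grad)

# index sharded on mesh dim 0: each group sees half of index and of out_grad
g_dev01 = scatter_add_grad(np.array([0, 1]), out_grad[:, 0:2])
g_dev23 = scatter_add_grad(np.array([1, 2]), out_grad[:, 2:4])

# each local x_grad is only a partial sum; all-reducing them gives the full
# gradient, hence x_grad is Partial on the mesh dim that shards index
assert np.array_equal(g_dev01 + g_dev23, x_grad_full)
```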
self.assertEqual(
    inferred_output_dist_attrs[0].dims_mapping, [-1, 0, -1]
)
# [-1, 1, -1], [0], axis=0 -> [-1, 1, -1], [0], [0, 1, -1]
If the inputs are [-1, 1, -1], [0] with axis=1, what will the result be? A unit test should be added for this case.
x's dim along axis will be replicated first, so x becomes [-1, -1, -1] and out_dims_mapping will be [-1, 0, -1].
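The replicate-first resolution described here can be sketched as a small hypothetical helper (names invented for illustration; the real rule lives in the C++ spmd implementation, and this sketch ignores mesh-dim conflicts between index and x's remaining dims, which are handled separately):

```python
def infer_index_select_forward(x_mapping, index_mapping, axis):
    """Hypothetical sketch of the forward rule: x may not stay sharded along
    the selected axis, because index picks arbitrary rows from it."""
    x_dst = list(x_mapping)
    if x_dst[axis] != -1:
        x_dst[axis] = -1  # replicate x along the selected axis first
    # the selected axis is replaced by index's dims_mapping in the output
    out = x_dst[:axis] + list(index_mapping) + x_dst[axis + 1:]
    return x_dst, out

# the case asked about: x [-1, 1, -1], index [0], axis=1
print(infer_index_select_forward([-1, 1, -1], [0], 1))
# -> ([-1, -1, -1], [-1, 0, -1]): x's axis-1 shard is dropped, out follows index

# the axis=0 case from the tested snippet above needs no reshard
print(infer_index_select_forward([-1, 1, -1], [0], 0))
# -> ([-1, 1, -1], [0, 1, -1])
```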
/re-run all-failed
LGTM
PR Category
Auto Parallel
PR Types
New features
Description