CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[DistDialect] add reshard op and api #62718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DistDialect] add reshard op and api #62718
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
const phi::distributed::ProcessMesh& process_mesh, | ||
const std::vector<int64_t>& dims_mapping) { | ||
pir::IrContext* ctx = pir::IrContext::Instance(); | ||
paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不能使用默认 partial_status
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前版本可以这样,未来再修改
pir::AttributeMap attributes) { | ||
VLOG(4) << "Start build ReShardOp"; | ||
|
||
// Temporary restriction, will support input use_empty false in the future |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reshard 不需要这个限制
|
||
PADDLE_ENFORCE(attributes.find("tensor_dist_attr") != attributes.end(), | ||
phi::errors::NotFound( | ||
"'tensor_dist_attr' Attribute is expected for ShardOp")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reshardop
@@ -287,6 +287,36 @@ TEST(shard_tensor_op_replicate_test, base) { | |||
EXPECT_EQ(shard_op.attribute<OperationDistAttribute>("op_dist_attr") | |||
.process_mesh_attr(), | |||
mesh_attr); | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单测应该包括:
case1: 改变 mesh
case2: 改变 dims mapping
case3: 同时改变 mesh 和 dims mapping
@@ -386,6 +386,16 @@ def reshard(dist_tensor, mesh, placements): | |||
dist_attr._set_partial_dims(partial_dims) | |||
|
|||
return paddle.base.core.reshard(dist_tensor, dist_attr) | |||
elif paddle.framework.in_pir_mode(): | |||
assert isinstance( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在还没给 reshard op 写对应的 反向逻辑, 还不能支持 reshard op 前向组网
void ReShardOp::Build(pir::Builder& builder, | ||
pir::OperationArgument& argument, | ||
pir::Value input, | ||
pir::AttributeMap attributes) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reshard OP 没有其他 attribute, dst_tensor_dist_attr 直接当做 argument 传入build 会更清晰?
auto shard_size = process_mesh_shape[dims_mapping[i]]; | ||
PADDLE_ENFORCE( | ||
global_dims[i] % shard_size == 0, | ||
phi::errors::PreconditionNotMet( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
phi::errors::PreconditionNotMet( | |
common::errors::PreconditionNotMet( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
errors都是定义在common命名空间的。建议直接使用common命名空间。 这样可以和pir、phi、common库同时保持一致。
f22a2c8
to
ad534b9
Compare
a1c3b81
to
34de7ac
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
Others
Description
Pcard-67164