dist.to_static support pir program #62560
Your PR was submitted successfully. Thank you for contributing to this open-source project!
@@ -37,18 +39,45 @@ void DistDialect::initialize() {
void DistDialect::PrintType(pir::Type type, std::ostream &os) const {
  if (auto dist_dense_tensor_type = type.dyn_cast<DistDenseTensorType>()) {
    // Todo: Design the dist dense tensor type print format.
    os << dist_dense_tensor_type.dense_tensor_type();
    os << type.dialect().name();
Suggested change:
- os << type.dialect().name();
+ os << name();
Mostly they are equal, but I wonder whether there could be a type that does not belong to the dist dialect?
)

if isinstance(var_spec, DistributedInputSpec):
    dist_dense_tensor_type = paddle.base.libpaddle.pir.create_dist_dense_tensor_type_by_dense_tensor(
To be changed to use the shard_tensor Python API in the future, once that API is adapted for PIR. @hitywt
agree
main_program = dist_model._engine._fwd_main_progs["train"]
for op in main_program.global_block().ops:
    tensor = op.result(0)
    if op.name() == 'pd_op.data':
Enable the check for "builtin.parameter" after the shard_tensor API is adapted for PIR. @hitywt
        self.assertEqual(tensor.process_mesh.process_ids, [0, 1])
        self.assertEqual(tensor.dims_mapping, [-1, -1])
        self.assertEqual(tensor.partial_dims, set())
    else:
Enable the check for all other forward computation ops after build is adapted for DistTensor. @winter-wang
LGTM
* auto_parallel engine build pir program * skip prepare_op_amp_options in build_program * add ut * fix cmake * remove print
PR types
New features
PR changes
APIs
Description
dist.to_static support pir program
Pcard-76459
In the original static auto-parallel process, the dynamic model is first converted into a serial program without DistTensors, and then the dist attributes are added to the program.
However, the new static auto-parallel implementation infers the SPMD info during the build operation, so the inputs and parameters in the dynamic model that are DistTensors should be converted to the static DistDenseTensorType natively and eagerly. This PR handles the inputs that are sharded by the dataloader; the parameters will be handled in the next PR.
For example, given the DemoNet model, call shard_dataloader to make the model inputs DistTensors.
The serial main program is then a mix of DenseTensor and DistDenseTensor.
Currently, the inputs of the model, input0 and label0, are DistDenseTensor.
TODO: make the parameters DistDenseTensors as well.
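The unit test in this PR asserts that each input tensor carries a process mesh with process_ids [0, 1], dims_mapping [-1, -1], and an empty partial_dims set. As a rough illustration of what a dims_mapping expresses, the sketch below computes the per-rank shape of a tensor from its global shape, the mesh shape, and the dims_mapping. It is plain Python, not Paddle code: the helper local_shape is hypothetical, the example shapes are made up, and it assumes the usual convention that -1 marks a replicated tensor axis while a non-negative entry names the mesh axis the tensor axis is split along, with even division only.

```python
def local_shape(global_shape, mesh_shape, dims_mapping):
    """Hypothetical helper: per-rank shape of a tensor under a dims_mapping.

    Assumed convention: dims_mapping[i] == -1 means tensor axis i is
    replicated; otherwise it is the index of the mesh axis that tensor
    axis i is split along. Uneven splits are rejected for simplicity.
    """
    if len(global_shape) != len(dims_mapping):
        raise ValueError("dims_mapping needs one entry per tensor axis")

    shape = []
    for size, mesh_axis in zip(global_shape, dims_mapping):
        if mesh_axis == -1:
            # Replicated axis: every rank holds the full extent.
            shape.append(size)
            continue
        parts = mesh_shape[mesh_axis]
        if size % parts != 0:
            raise ValueError("this sketch only handles even splits")
        # Sharded axis: each rank holds an equal slice.
        shape.append(size // parts)
    return shape


# A 1-D mesh over ranks [0, 1] has shape [2].
replicated = local_shape([4, 16], [2], [-1, -1])
sharded_on_batch = local_shape([4, 16], [2], [0, -1])
```

With dims_mapping [-1, -1], as in the test assertions, every rank sees the full tensor shape; a mapping of [0, -1] would instead halve the first axis across the two ranks.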