[PIR-Auto-Parallel]refactor recompute pass in PIR mode #69681
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open-source project!
2. set `recomput_id` attr for recompute ops
Force-pushed from d40ff3c to 2f7cdb3
```python
return segment_num + 1, rc_op_num
```

```python
def run_test_cases(self):
    self.strategy._recompute.enable = False
```
Why is `strategy.recompute` used in the pass but `strategy._recompute` used here?
`strategy.recompute` (in the pass) is the `recompute` attribute of the class `paddle.distributed.Strategy()`, while `strategy._recompute` (in the engine) is the `_recompute` attribute of the class `paddle.distributed.fleet.auto.Strategy()`. Both classes describe the recompute configuration, and one is converted to the other in the `to_static` function.
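The relationship described above can be sketched in plain Python. This is a hypothetical mock, not Paddle's real API: the class names, the `RecomputeConfig` holder, and the `convert_strategy` helper are all stand-ins for illustration only.

```python
# Hypothetical sketch: two strategy classes holding the same recompute
# configuration under different attribute names, with a conversion step
# analogous to what to_static does. None of these names are Paddle's API.
class RecomputeConfig:
    def __init__(self, enable=False):
        self.enable = enable


class DistStrategy:  # stand-in for paddle.distributed.Strategy
    def __init__(self):
        self.recompute = RecomputeConfig()  # public attribute, used in the pass


class FleetAutoStrategy:  # stand-in for paddle.distributed.fleet.auto.Strategy
    def __init__(self):
        self._recompute = RecomputeConfig()  # private attribute, used in the engine


def convert_strategy(engine_strategy):
    """Mimic the conversion: the engine-side _recompute config
    populates the pass-side recompute config."""
    pass_strategy = DistStrategy()
    pass_strategy.recompute.enable = engine_strategy._recompute.enable
    return pass_strategy


engine = FleetAutoStrategy()
engine._recompute.enable = True
assert convert_strategy(engine).recompute.enable is True
```

The point is only that the two attributes describe the same setting from different entry points, so disabling `strategy._recompute` in the test also disables the pass.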
```python
assert (
    base_segment_num < segment_num_1
    and segment_num_1 < segment_num_2
    and segment_num_2 < segment_num_3
)
```
Please check the results more precisely, e.g. `assert base_segment_num == XX` and `assert op_num == XX`.
done
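For illustration, exact-value checks of the kind the reviewer asks for could be sketched as follows. The `count_recompute_segments` helper and the dict-based op list are hypothetical stand-ins, not Paddle's API:

```python
# Hypothetical sketch: derive segment and op counts from op attributes,
# then assert exact values instead of only ordering relations.
def count_recompute_segments(ops):
    """Return (number of distinct fwd_recompute_id values, number of tagged ops)."""
    ids = {op["fwd_recompute_id"] for op in ops if "fwd_recompute_id" in op}
    tagged = sum(1 for op in ops if "fwd_recompute_id" in op)
    return len(ids), tagged


ops = [
    {"name": "matmul", "fwd_recompute_id": 0},
    {"name": "relu", "fwd_recompute_id": 0},
    {"name": "matmul", "fwd_recompute_id": 1},
    {"name": "add"},  # not inside any recompute segment
]
segment_num, rc_op_num = count_recompute_segments(ops)
assert segment_num == 2  # exact check, as suggested
assert rc_op_num == 3
```

Exact-value assertions catch regressions that a monotonicity check like `a < b < c` would silently pass.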
```python
rc_end_id = len(block.ops)
for idx in range(rc_begin_id, rc_end_id):
    rc_op = block.ops[idx]
    rc_op.set_int_attr("fwd_recompute_id", _g_recompute_idx)
```
Why use the prefix "fwd"?
`fwd_recompute_id` marks the checkpointed ops in the forward pass that need to be recomputed. `bwd_recompute_id` corresponds to `fwd_recompute_id` and is newly added on the recompute ops inserted in the backward pass, which facilitates debugging. There are comments in the code that explain this.
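The pairing described above can be sketched in isolation. This is a hypothetical model, assuming a toy `Op` class in place of a PIR operation; only the attribute names `fwd_recompute_id` and `bwd_recompute_id` come from the discussion:

```python
# Hypothetical sketch of the fwd/bwd recompute-id pairing. "Op" is a toy
# stand-in for a PIR operation; Paddle's real op interface differs.
class Op:
    def __init__(self, name):
        self.name = name
        self.attrs = {}

    def set_int_attr(self, key, value):
        self.attrs[key] = value


def tag_forward_segment(ops, recompute_idx):
    # Checkpointed forward ops carry fwd_recompute_id.
    for op in ops:
        op.set_int_attr("fwd_recompute_id", recompute_idx)


def clone_for_backward(fwd_ops):
    # Recompute clones inserted in the backward carry the matching
    # bwd_recompute_id, so forward segments and their backward
    # recomputations can be paired up when debugging.
    bwd_ops = []
    for op in fwd_ops:
        clone = Op(op.name)
        clone.set_int_attr("bwd_recompute_id", op.attrs["fwd_recompute_id"])
        bwd_ops.append(clone)
    return bwd_ops


fwd = [Op("matmul"), Op("relu")]
tag_forward_segment(fwd, 0)
bwd = clone_for_backward(fwd)
assert all(op.attrs["bwd_recompute_id"] == 0 for op in bwd)
```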
LGTM
PR Category
Auto Parallel
PR Types
Performance
Description
Refactor the recompute pass based on PIR (referring to the recompute implementation under the legacy IR: #38920). PCard-88114