[auto parallel] Support 1F1B for PIR #66810
Conversation
… fit_grad_merge_for_pir
… fit_grad_merge_for_pir
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
… fit_1f1b_for_pir
), "PIR does not support 1F1B with enable_send_recv_overlap yet." | ||
|
||
types = [FORWARD, BACKWARD, OPT] | ||
sub_program_list = _pir_program_for_fthenb_and_1f1b( |
This graph-splitting functionality should be generalized; the gradmerge side should be able to use it too.
This graph-splitting functionality is already general-purpose; the function name may be causing the confusion. Its main job is to split the program into three sub-programs: forward, backward, and opt.
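For readers of this thread, a minimal sketch of that idea, assuming each op carries an `op_role` attribute; the helper name and role encodings are hypothetical, not the actual `_pir_program_for_fthenb_and_1f1b` implementation:

```python
# Hypothetical sketch: bucket a PIR program's ops into forward /
# backward / optimizer groups by their op_role attribute. The real
# pass clones these buckets into three separate sub-programs.
FORWARD, BACKWARD, OPT = 0, 1, 2  # assumed OpRole encodings

def group_ops_by_role(program):
    groups = {FORWARD: [], BACKWARD: [], OPT: []}
    for op in program.global_block().ops:
        if op.op_role in groups:
            groups[op.op_role].append(op)
    return groups
```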
print("loss_fthenb", loss_fthenb) | ||
print("loss_1f1b", loss_1f1b) | ||
self.assertTrue( | ||
np.allclose( |
This check is too loose. If this is only a comparison between 1F1B and FThenB, it should be possible to assert exact equality (all_equal).
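A sketch of the stricter assertion (`np.array_equal` is standard NumPy; the loss arrays here are stand-ins for the values collected by the test):

```python
import numpy as np

# Stand-ins for the losses gathered by the two schedules.
loss_fthenb = np.array([2.3025851], dtype=np.float32)
loss_1f1b = np.array([2.3025851], dtype=np.float32)

# 1F1B and FThenB run the same per-micro-batch computation, only
# interleaved differently, so with deterministic kernels the losses
# should be bit-identical and a tolerance-free check is possible:
assert np.array_equal(loss_fthenb, loss_1f1b)
```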
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B"
pipeline.accumulate_steps = 2
The pp2-acc2 parallelism here is a bit low. I suggest following up with a unit test at higher parallelism, and additionally testing compatibility with existing feature strategies, e.g. tp2-pp4-acc8-amp-o2.
OK, will do.
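A hedged sketch of what such a configuration might look like, modeled on the pipeline fields used in this test; the amp field names are assumptions to check against the actual Strategy API, and the tp/pp degrees are set by the process mesh rather than here:

```python
# Hypothetical tp2-pp4-acc8-amp-o2 strategy, mirroring the test above.
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B"
pipeline.accumulate_steps = 8  # acc8

amp = strategy.amp       # assumed field on the same strategy object
amp.enable = True
amp.level = "o2"         # assumed spelling for AMP-O2
```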
print("loss_1f1b", loss_1f1b) | ||
self.assertTrue( | ||
np.allclose( | ||
loss_fthenb, loss_1f1b, rtol=self.rtol, atol=self.atol |
Also, the end-to-end precision-alignment check is too coarse; it would be better to add functional checks of the core logic, e.g. verifying the scheduling: with pp=4 and acc=8, check whether the 1F1B and FThenB job orders on each rank match expectations.
Related unit tests have been added.
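As an illustration of what such a scheduling check can assert, a self-contained sketch that derives the expected per-rank 1F1B job order from the standard warmup/steady/cooldown structure (the function and job labels are hypothetical, not the pass's actual names):

```python
def expected_1f1b_schedule(rank, pp_degree, acc_steps):
    # Each rank runs (pp_degree - rank - 1) warmup forwards, then
    # alternates one-forward-one-backward in the steady state, then
    # drains the remaining backwards, and finally the optimizer job.
    warmup = pp_degree - rank - 1
    jobs = [f"forward{i}" for i in range(warmup)]
    f = warmup
    for b in range(acc_steps):
        if f < acc_steps:
            jobs.append(f"forward{f}")
            f += 1
        jobs.append(f"backward{b}")
    jobs.append("optimizer")
    return jobs

# pp=4, acc=8: the last stage (rank 3) has no warmup and starts 1F1B
# immediately, while rank 0 runs 3 warmup forwards first.
assert expected_1f1b_schedule(3, 4, 8)[:4] == [
    "forward0", "backward0", "forward1", "backward1"
]
assert expected_1f1b_schedule(0, 4, 8)[:3] == [
    "forward0", "forward1", "forward2"
]
```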
LGTM
… fit_1f1b_for_pir
… fit_1f1b_for_pir
…into fit_1f1b_for_pir
if mode == "train" and self._strategy.pipeline.enable: | ||
self._strategy.gradient_merge.enable = True | ||
self._strategy.gradient_merge.k_steps = ( | ||
self._strategy.pipeline.accumulate_steps | ||
) | ||
self._strategy.gradient_merge.avg = True |
[New add] This fills in the grad_merge-related strategy configuration for the case where pipeline is enabled.
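A runnable sketch of the effect on a stand-in strategy object (`SimpleNamespace` replaces the real strategy classes purely for illustration):

```python
from types import SimpleNamespace

# Stand-in mirroring the fields used in the engine snippet above.
strategy = SimpleNamespace(
    pipeline=SimpleNamespace(enable=True, accumulate_steps=4),
    gradient_merge=SimpleNamespace(enable=False, k_steps=1, avg=False),
)

# The same propagation logic as above, applied to the stand-in:
if strategy.pipeline.enable:
    strategy.gradient_merge.enable = True
    strategy.gradient_merge.k_steps = strategy.pipeline.accumulate_steps
    strategy.gradient_merge.avg = True

assert strategy.gradient_merge.k_steps == 4  # follows accumulate_steps
```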
@@ -285,7 +289,7 @@ def _pir_append_gradient_merge_backward_op(
     with startup_block:
         paddle.pir.set_insertion_point_to_block_end(startup_block)
         gradient_merge_var = paddle.full(
-            shape=grad.shape, fill_value=float(0), dtype=grad.dtype
+            shape=grad._local_shape, fill_value=float(0), dtype=grad.dtype
[New add] This should be grad._local_shape; using grad.shape directly raises an error when grad is sharded. This case was caught by modifying semi_auto_llama_acc_align.py to run with grad_merge enabled.
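A pure-Python sketch of the global-vs-local shape distinction; the helper is hypothetical, standing in for what `_local_shape` reports on a sharded value:

```python
# Hypothetical helper: a tensor sharded along `shard_dim` across
# `degree` ranks stores only 1/degree of that dimension per rank.
def local_shape(global_shape, shard_dim, degree):
    shape = list(global_shape)
    shape[shard_dim] //= degree
    return shape

# A grad with global shape [8, 4], sharded on dim 0 across 2 ranks:
# the per-rank gradient-merge buffer must be allocated with the local
# shape [4, 4]; passing the global shape would mismatch the shard.
assert local_shape([8, 4], 0, 2) == [4, 4]
```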
opt_ops_use_grad = [
    op
    for op in grad.all_used_ops()
    if op.op_role == int(OpRole.Optimize)
]
grad.replace_grad_users_with(
    new_gradient_merge_var, set(opt_ops_use_grad)
)
[New add] The previous version of this code was wrong; only the inputs of the optimizer ops should be modified.
"pd_op.c_reduce_sum", | ||
"pd_op.c_reduce_avg", | ||
]: | ||
if op.name() in comm_ops: |
[New add] Replaced the hard-coded list with the comm_ops maintained in pir_utils.
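A sketch of the centralization; only the two op names visible in this diff are certain, the rest of the list and the helper are illustrative:

```python
# Hypothetical excerpt of the shared list kept in pir_utils, so every
# pass checks the same set instead of hard-coding its own:
comm_ops = [
    "pd_op.c_reduce_sum",
    "pd_op.c_reduce_avg",
    # ...remaining collective ops tracked centrally
]

def is_comm_op(op) -> bool:
    # Mirrors the `if op.name() in comm_ops:` check in the pass.
    return op.name() in comm_ops
```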
new_grad.get_defining_op().op_role = int(OpRole.Optimize)
scale.get_defining_op().op_role = int(OpRole.Optimize)
[New add] Corrected the op roles of the scale and full ops.
PR Category
auto parallel
PR Types
Not User Facing
Description
Pcard-76459
Adapt the 1F1B pass for PIR.