[Auto Parallel] Support gradient_merge pass for PIR #66641
Conversation
… fit_grad_merge_for_pir
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
        break
for _, new_grad in new_params_to_grads:
    new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)
    new_grad.get_defining_op().op_role = int(OpRole.Backward)
The scale_ op here should be classified as OpRole.Optimize, not OpRole.Backward.
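A minimal sketch of the suggested change, reusing the names from the quoted diff (new_params_to_grads, k_steps, OpRole come from the pass code shown above); only the op_role assignment differs:

```python
# Sketch of the reviewer's suggestion, based on the diff above: the averaging
# scale_ runs once per optimizer step, so tag it as Optimize, not Backward.
for _, new_grad in new_params_to_grads:
    new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)
    new_grad.get_defining_op().op_role = int(OpRole.Optimize)
```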
def _pir_remove_cast_for_master_grad(main_program):
    main_block = main_program.global_block()
    for op in main_block.ops:
        if _is_master_grad_cast_op(main_block, op):
The _is_master_grad_cast_op function may not work as-is under PIR; it needs to be adjusted to fit PIR's AMP support.
main_block = main_program.global_block()

for idx, op in list(enumerate(main_block.ops)):
    if is_data_parallel_reduce_op(op):
Whether the functionality of is_data_parallel_reduce_op is kept needs further discussion.
test/auto_parallel/pir/mlp_demo.py
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = accumulate_steps
A unit test with gradient_merge enabled could be added, comparing the loss against the run without it.
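A hedged sketch of what such a test could look like; run_mlp_demo is a hypothetical helper standing in for whatever mlp_demo.py actually exposes, and the tolerance is illustrative:

```python
import numpy as np

def test_gradient_merge_matches_baseline():
    # run_mlp_demo is a hypothetical helper returning the per-step loss list.
    # Baseline: no gradient merge, the whole global batch in one step.
    base_losses = run_mlp_demo(accumulate_steps=1)
    # Gradient merge: accumulate 4 micro-batch gradients per optimizer step.
    gm_losses = run_mlp_demo(accumulate_steps=4)
    # With the same effective global batch size the losses should match closely.
    np.testing.assert_allclose(base_losses, gm_losses, rtol=1e-5)
```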
… fit_grad_merge_for_pir
… fit_grad_merge_for_pir
"pd_op.c_allreduce_sum", | ||
"pd_op.c_allreduce_avg", | ||
"pd_op.c_reduce_sum", | ||
"pd_op.c_reduce_avg", |
Some ops are missed here, e.g. the reduce_scatter op used by sharding. The logic for identifying DP gradient-synchronization ops is reused in many passes, so it would be better to factor it into a shared utils function and manage it in one place (see the sketch below). It has little impact on the current PR and can be upgraded in a follow-up sharding PR. @winter-wang
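A sketch of the shared helper the comment suggests; the function name and the set contents beyond the four ops in the diff are illustrative, and the reduce_scatter op name in particular is an assumption about sharding's PIR op:

```python
# Illustrative shared util; op names beyond the four listed in the diff are assumptions.
_DP_GRAD_SYNC_OPS = {
    "pd_op.c_allreduce_sum",
    "pd_op.c_allreduce_avg",
    "pd_op.c_reduce_sum",
    "pd_op.c_reduce_avg",
    "pd_op.reduce_scatter",  # assumed name for sharding's reduce-scatter op
}

def is_dp_gradient_sync_op(op):
    """Return True if op synchronizes data-parallel gradients."""
    return op.name() in _DP_GRAD_SYNC_OPS
```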
# NOTE(sonder): When "@RENAME@" is in the input name, it means the op has been renamed.
# Such input names are caused by the shared-parameter policy.
# Gradient merge should only accumulate the gradients of ops without renaming.
if "@RENAME" in op_input_names[0]:
"@rename" 作为GM 逻辑中关键标志,最好不要裸写在代码的每一个地方。最好定义一个 GM_SUFFIX 全局变量统一管理。 避免多处地方不同修改。
"@rename" 不是 grad_merge 引入的,是共享参数策略带来的
op.op_role = OpRole.Optimize
main_block.move_op_to_block_end(op)

if op.name() in ["pd_op.c_allreduce_sum", "pd_op.c_reduce_sum"]:
How do we distinguish this from the allreduce used in the backward pass of TP?
This will be addressed later by adding an attribute to the allreduce ops inserted by DP.
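A hedged sketch of that follow-up idea; the attribute name and its presence on the op are assumptions, not something this PR adds:

```python
# Hypothetical: the DP pass would stamp an attribute (here "dp_grad_sync") on the
# allreduce ops it inserts, so GM can tell them apart from TP's backward allreduce,
# which shares the same op type.
def is_dp_backward_allreduce(op):
    if op.name() not in ["pd_op.c_allreduce_sum", "pd_op.c_reduce_sum"]:
        return False
    return "dp_grad_sync" in op.attrs()  # assumed attribute set by the DP pass
```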
@@ -526,12 +618,48 @@ def parse_program(
    return grad_to_gradient_merge


def _pir_parse_program(
A key piece of logic seems to be missing: GM sub-graph schedule orchestration.
The most important upgrade the PIR GM implementation needs relative to the original program-IR GM:
program-IR GM: GM scheduling is implemented with control flow.
PIR GM: GM scheduling is implemented through sub-graph schedule orchestration similar to PP; ideally PP and GM reuse the same framework logic.
Split the computation graph into three sub-graphs, FW-BW-OPT, and implement the GM logic by orchestrating the scheduling order of these sub-graphs (similar to pipeline schedule orchestration),
e.g. FW-BW-FW-BW-FW-BW-FW-BW-OPT.
This is not considered in this version of the PR; it will be implemented later once the design is settled.
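For reference, the FW-BW-OPT orchestration described above can be sketched as a plain job list (purely illustrative, not part of this PR):

```python
def gradient_merge_schedule(k_steps):
    """Build the sub-graph job order for one optimizer step: FW-BW repeated k times, then OPT."""
    jobs = []
    for _ in range(k_steps):
        jobs += ["forward", "backward"]
    jobs.append("optimizer")
    return jobs

# FW-BW-FW-BW-FW-BW-FW-BW-OPT for k_steps = 4
assert gradient_merge_schedule(4) == ["forward", "backward"] * 4 + ["optimizer"]
```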
gard_defining_op = grad.get_defining_op()
paddle.pir.set_insertion_point_after(gard_defining_op)

new_gradient_merge_var = main_block.add_kwarg(
This needs to stay compatible with the master-grad logic in AMP: if master grad is enabled in AMP, the accumulated persistable grad should reuse the same master grad as AMP.
LGTM
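A minimal sketch of the reuse idea above; amp_master_grads is a hypothetical mapping, not an existing Paddle structure:

```python
def get_accumulation_target(grad, amp_master_grads):
    # amp_master_grads (hypothetical) maps a low-precision grad value to the fp32
    # master grad AMP already created; reuse it instead of a second buffer.
    return amp_master_grads.get(grad, grad)
```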
def _pir_remove_cast_for_master_grad(main_program):
    main_block = main_program.global_block()
    for op in main_block.ops:
        if _is_master_grad_cast_op(main_block, op):
            main_program.remove_op(op)
The adaptation logic here still has some issues; the next PR will fix it and add a unit test for the AMP case.
PR Category
Auto Parallel
PR Types
Not User Facing
Description
Pcard-76459
Adapt the gradient_merge pass for PIR.
Depends on PR:
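For context, enabling the pass follows the strategy config quoted in the review above; a minimal sketch, assuming the test builds its strategy with paddle.distributed.Strategy():

```python
import paddle.distributed as dist

strategy = dist.Strategy()  # assumption: how mlp_demo.py constructs its strategy
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4  # accumulate 4 micro-batch gradients per optimizer step
```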