CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[ Auto Parallel ]replace softmax_with_cross_entropy with c_softmax_with_cross_entropypass #68182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ Auto Parallel ]replace softmax_with_cross_entropy with c_softmax_with_cross_entropypass #68182
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
e6e5741
to
636c46a
Compare
d4d774f
to
c18990c
Compare
d1c9616
to
c572acb
Compare
c572acb
to
ff62f94
Compare
70cd4dd
to
7c08d92
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a todo: add llama testing with this pass
@@ -340,7 +340,7 @@ class _DPOptimizationConfig(TypedDict, total=False): # noqa: PYI049 | |||
set_field_default_config( | |||
MP_OPTIMIZATION, "allreduce_matmul_grad_overlapping", False | |||
) | |||
|
|||
set_field_default_config(MP_OPTIMIZATION, "replace_with_c_softmax", False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
replace_with_c_softmax
-> replace_with_parallel_cross_entropy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
super().__init__() | ||
self._in_pir_mode = paddle.base.framework.get_flags( | ||
"FLAGS_enable_pir_api" | ||
)["FLAGS_enable_pir_api"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self._in_pir_mode is not used
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
python/paddle/distributed/passes/auto_parallel_replace_with_parallel_cross_entropy.py
Show resolved
Hide resolved
627d818
to
c741a52
Compare
c741a52
to
23fa385
Compare
706a82e
to
c0a9eeb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
Auto Parallel
PR Types
Performance
Description
基于PIR实现静半
parallel_cross_entroy
通信算子的标记优化,在 mp 列切 + hard label 的情况下,将softmax_with_cross_entropy
算子替换使用通信效率更高的c_softmax_with_cross_entropypass
算子perf in llama-7B dp1mp2pp1 with using pass
close pass: 323.2 token/card/s
open pass: 329.0 token/card/s
Pcard-70448