Move fused_moe_permute and fused_moe_unpermute #73264
Your PR was submitted successfully. Thank you for contributing to this open-source project!
/re-run all-failed
Add these under the incubate directory; there is no need to create a new file.
/re-run all-failed
def moe_permute(
    X: Tensor,
    XScale: Tensor | None,
    expert_routemap_topk: Tensor,
    expert_prob_topk: Tensor,
    topk: int,
    num_experts: int,
    tokens_per_expert: list,
    padding_multiplex: int,
    name: str | None = None,
):
Please add complete API documentation (docstrings).
I'll add it later in a separate PR.
from paddle import Tensor


def moe_unpermute(
Same as above.
/re-run all-failed
/re-run approval
The type hints and naming conventions can be fixed in the next PR.
    tokens_per_expert: list,
    padding_multiplex: int,
    name: str | None = None,
):
Add a return type hint.
    X: Tensor,
    XScale: Tensor | None,
Why does the Python side still have capitalized variable names like X and XScale? Use x and x_scale.
    num_experts: int,
    MP: bool = True,
    name: str | None = None,
):
Same as above.
    unzipped_token_probs: Tensor,
    total_zipped_tokens: int,
    num_experts: int,
    MP: bool = True,
What does MP stand for? Use the full name rather than an abbreviation.
It parallels AMP and is short for MixPrecision. I'll standardize this later, separately, together with the documentation.
    return max_abs_err, max_rel_err


class TestFusedMoePermuteUnpermute(unittest.TestCase):
I suggest adding a version built from a combination of small operators to serve as the reference answer and verify the API's correctness.
OK.
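A reference built from small operators, as suggested above, could look roughly like the NumPy sketch below. It is an assumption-laden simplification and not the kernel's actual contract: `permute_reference` and `unpermute_reference` are hypothetical names, padding (`padding_multiplex`) and the scale tensor are ignored, and the row-map layout is invented for illustration. It groups each token's top-k copies by expert, then gathers them back with a probability-weighted sum.

```python
import numpy as np


def permute_reference(x, expert_ids, num_experts):
    """Group token copies by expert (hypothetical reference, no padding)."""
    num_tokens, topk = expert_ids.shape
    flat_experts = expert_ids.reshape(-1)
    # A stable sort keeps the original token order inside each expert's group.
    order = np.argsort(flat_experts, kind="stable")
    token_idx = order // topk  # source token of each permuted row
    slot_idx = order % topk    # which top-k slot of that token the row came from
    permuted = x[token_idx]
    tokens_per_expert = np.bincount(flat_experts, minlength=num_experts)
    return permuted, token_idx, slot_idx, tokens_per_expert


def unpermute_reference(permuted, token_idx, slot_idx, probs, num_tokens):
    """Weighted reduce of the expert copies back to one row per token."""
    out = np.zeros((num_tokens, permuted.shape[1]), dtype=permuted.dtype)
    weights = probs[token_idx, slot_idx][:, None]
    # np.add.at accumulates correctly when the same token index repeats.
    np.add.at(out, token_idx, permuted * weights)
    return out


rng = np.random.default_rng(0)
num_tokens, hidden, num_experts, topk = 8, 4, 4, 2
x = rng.standard_normal((num_tokens, hidden))
# Each token is routed to `topk` distinct experts.
expert_ids = np.stack(
    [rng.permutation(num_experts)[:topk] for _ in range(num_tokens)]
)
# Routing probabilities that sum to 1 per token.
probs = rng.dirichlet(np.ones(topk), size=num_tokens)

permuted, token_idx, slot_idx, tokens_per_expert = permute_reference(
    x, expert_ids, num_experts
)
restored = unpermute_reference(permuted, token_idx, slot_idx, probs, num_tokens)
```

Because the probabilities sum to 1 for every token, `restored` should match `x` up to floating-point rounding, which gives a simple round-trip check to compare the fused ops against.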
print(
    f"permute-unpermute tokens relative error: {max_rel_err}"
)
self.assertLess(
What causes the error here? Can the result be made bitwise identical? Why is the error tolerance set to 1e-2?
A single token can be broadcast to multiple experts, and gathering it back involves a weighted reduce over that token, which introduces some floating-point precision loss. The loss comes from two sources: (1) floating-point multiplication and (2) accumulation error. The 1e-2 value follows the threshold used for bfloat16 reduce.
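The two error sources named above can be reproduced with a small NumPy sketch. NumPy has no bfloat16 type, so float16 is used here as a stand-in (an assumption of this example); bfloat16 keeps fewer significand bits, so its error would be larger. The data shapes and value ranges are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, hidden, topk = 256, 64, 4

# Low-precision inputs; values in [1, 2) avoid near-zero denominators below.
x = rng.uniform(1.0, 2.0, size=(num_tokens, hidden)).astype(np.float16)
probs = rng.dirichlet(np.ones(topk), size=num_tokens).astype(np.float16)

low = np.zeros_like(x)                      # accumulates in float16
ref = np.zeros(x.shape, dtype=np.float64)   # same inputs, float64 arithmetic
for k in range(topk):
    p_k = probs[:, k:k + 1]
    low = low + p_k * x                     # rounding on multiply and on add
    ref = ref + p_k.astype(np.float64) * x.astype(np.float64)

max_rel_err = float(np.max(np.abs(low.astype(np.float64) - ref) / np.abs(ref)))
```

Each multiply and each add rounds once, so the error grows with top-k. That is why a bitwise-exact round trip is not expected once a token is sent to more than one expert.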
print(
    f"permute-unpermute probs max absolute error: {max_abs_err}, relative error: {max_rel_err}"
)
self.assertLess(
What causes the error here? Can the result be made bitwise identical? Why is the error tolerance set to 1e-5?
Same as above.
/re-run all-failed
for dt in self.DTYPES:
    for expert_num in self.EXPERT_NUMS:
        for topk in self.TOPKS:
Suggested change:
- for dt in self.DTYPES:
-     for expert_num in self.EXPERT_NUMS:
-         for topk in self.TOPKS:
+ for (dt, expert_num, topk) in itertools.product(self.DTYPES, self.EXPERT_NUMS, self.TOPKS):
Use itertools.product to collapse the nested for loops and reduce the indentation depth.
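A minimal standalone check that the suggested `itertools.product` form visits the same combinations, in the same order, as the nested loops. The three lists are hypothetical placeholders, not the values used in the PR's test.

```python
import itertools

# Hypothetical parameter grids, for illustration only.
DTYPES = ["bfloat16", "float32"]
EXPERT_NUMS = [4, 8]
TOPKS = [1, 2]

# Original form: three nested loops.
nested = []
for dt in DTYPES:
    for expert_num in EXPERT_NUMS:
        for topk in TOPKS:
            nested.append((dt, expert_num, topk))

# Suggested form: one loop over the Cartesian product.
flat = []
for dt, expert_num, topk in itertools.product(DTYPES, EXPERT_NUMS, TOPKS):
    flat.append((dt, expert_num, topk))
```

`itertools.product` varies the rightmost iterable fastest, which matches the innermost loop of the nested version.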
):
    print(
        f"Testing with {expert_num} experts, topk {topk}, dtype {dt}"
    )
I missed this earlier: these unconditional print calls need to be removed too. If you want to report the error, use the err_msg field of APIs such as self.assertEqual. Don't print logs unconditionally, since that clutters the CI logs.
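One way to follow this advice is sketched below, under assumptions: the class name, parameter grids, and the constant error value are placeholders, not the PR's test. In `unittest`, the message argument of assertion methods is named `msg`; `err_msg` is the name used by `numpy.testing` helpers such as `assert_allclose`. `subTest` attaches the parameter combination to any failure, so no unconditional `print` is needed.

```python
import itertools
import unittest


class TestQuietAssertions(unittest.TestCase):
    # Hypothetical parameter grids, for illustration only.
    DTYPES = ["float32"]
    EXPERT_NUMS = [4, 8]
    TOPKS = [1, 2]

    def test_relative_error(self):
        for dt, expert_num, topk in itertools.product(
            self.DTYPES, self.EXPERT_NUMS, self.TOPKS
        ):
            # subTest records the parameters and reports them only on failure.
            with self.subTest(dtype=dt, expert_num=expert_num, topk=topk):
                max_rel_err = 1e-3  # placeholder; a real test would measure this
                self.assertLess(
                    max_rel_err,
                    1e-2,
                    msg=f"permute-unpermute tokens relative error: {max_rel_err}",
                )
```

Run with `python -m unittest`, a passing test prints nothing beyond the usual summary, and a failing one shows both the sub-test parameters and the message.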
LGTM
PR Category
User Experience
PR Types
Improvements
Description
Add paddle.nn.functional.fused_moe_permute & paddle.nn.functional.fused_moe_unpermute
kernels/fusion/gpu
pcard-91067