【Hackathon 7th No.32】Enhance paddle.nn.functional.scaled_dot_product_attention #70166
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
        },
    )
    return out
elif sdp_func_name == "mem_efficient":
One question: the mem_efficient branch of nn.functional.flash_attention calls memory_efficient_attention directly, whereas this code uses variable_length_memory_efficient_attention plus some extra assembly logic. Is that because of the difference in attn_mask support?
Yes, memory_efficient_attention does not support a mask. This calls the variant that does support a mask and flattens seq_lens into the arguments.
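For context, a minimal sketch of the kind of assembly described here, assuming variable_length_memory_efficient_attention is exposed under paddle.incubate.nn.functional and takes per-batch seq_lens/kv_seq_lens plus a mask; the exact signature, tensor layout, and length construction are assumptions for illustration, not the PR's actual code:

```python
import paddle
from paddle.incubate.nn.functional import (
    variable_length_memory_efficient_attention,
)


def _mem_efficient_with_mask(query, key, value, attn_mask, scale):
    # Assumed layout: [batch, num_heads, seq_len, head_dim].
    batch_size = query.shape[0]
    q_seq_len = query.shape[2]
    kv_seq_len = key.shape[2]
    # "Flatten" the sequence lengths into explicit per-sample tensors,
    # since the variable-length kernel takes lengths as inputs instead
    # of inferring them from padding.
    seq_lens = paddle.full([batch_size], q_seq_len, dtype="int32")
    kv_seq_lens = paddle.full([batch_size], kv_seq_len, dtype="int32")
    return variable_length_memory_efficient_attention(
        query, key, value, seq_lens, kv_seq_lens,
        mask=attn_mask, scale=scale,
    )
```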
@@ -191,6 +306,54 @@ def _select_sdp(head_dim: int) -> str:
    return "mem_efficient"


def _select_sdp_for_sdpa(query, key, attn_mask, dropout, is_causal) -> str:
Is this compatible with the previous logic?
It is parallel to the previous logic; the flash path does not use this interface.
sdp_func_name = _select_sdp_for_sdpa(
    query, key, attn_mask, dropout_p, is_causal
)

if attn_mask is None:
    # downgraded to ordinary flash attention implementation
    out, _ = flash_attention(query, key, value, dropout_p, is_causal)
This directly calls the flash_attention API above, which re-runs sdp_func_name = _select_sdp(head_dim) and resets the backend. Does that conflict with the _select_sdp_for_sdpa backend-selection logic above?
My understanding of the logic here is:
- If the input has no mask, it goes through the flash path below, and _select_sdp ultimately decides which algorithm is used, so the mask-free case is fully aligned with the previous version.
In other words, once _select_sdp is called again, the earlier choice made by _select_sdp_for_sdpa is overridden. That is the expected behavior.
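For reference, a rough sketch of the dispatch being discussed (the selector call is left as a comment because _select_sdp_for_sdpa is defined in this PR; the masked branches are elided):

```python
from paddle.nn.functional.flash_attention import flash_attention


def sdpa_dispatch_sketch(query, key, value, attn_mask=None,
                         dropout_p=0.0, is_causal=False):
    # In the PR, a backend is first chosen via the new selector:
    #     sdp_func_name = _select_sdp_for_sdpa(
    #         query, key, attn_mask, dropout_p, is_causal
    #     )
    if attn_mask is None:
        # Mask-free input is downgraded to the ordinary flash_attention
        # API. flash_attention internally re-runs _select_sdp(head_dim),
        # so the earlier selection is overridden and the behavior matches
        # the pre-PR version exactly.
        out, _ = flash_attention(query, key, value, dropout_p, is_causal)
        return out
    # Masked input follows the backend chosen by _select_sdp_for_sdpa
    # ("flash" / "mem_efficient" / "math"), as shown in the diff above.
    raise NotImplementedError("masked branches elided in this sketch")
```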
LGTM
@yinfan98 My understanding is:
- The mask-free path is exactly the same as before.
- The masked path uses the new _select_sdp_for_sdpa; is that compatible with the previous behavior?
LGTM
PR Category
User Experience
PR Types
Improvements
Description
Modify the Paddle SDPA code to support choosing among the math, mem_efficient, and flash backends, aligning with the way PyTorch's selection code works.
P.S. Please help review which unit tests still need to be added.
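As a rough illustration of the PyTorch-style priority selection described above (the thresholds and dtype checks are assumptions for illustration; the PR's actual _select_sdp_for_sdpa may use different criteria):

```python
import paddle


def _select_sdp_for_sdpa_sketch(query, key, attn_mask, dropout, is_causal):
    # Try the backends in priority order: flash first, then
    # mem_efficient, and finally the math fallback.
    head_dim = query.shape[-1]
    dtype = query.dtype

    def can_use_flash():
        # Flash kernels typically require fp16/bf16 and a bounded
        # head dim; the exact limits here are illustrative only.
        return dtype in (paddle.float16, paddle.bfloat16) and head_dim <= 256

    def can_use_mem_efficient():
        # The memory-efficient kernel is more permissive on dtype but
        # still bounded in head dim (illustrative threshold).
        return head_dim <= 128

    if can_use_flash():
        return "flash"
    if can_use_mem_efficient():
        return "mem_efficient"
    return "math"
```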