support flash attention with sparse mask #62029
Conversation
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
Sorry to inform you that the CIs for 17ed1c6 have been passing for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
17ed1c6 to f2ae287 (Compare)
Updated the documentation to conform to the style guidelines.
LGTM for docs
LGTM for YAML
@@ -859,6 +859,17 @@
    func : flash_attn_unpadded_grad
    data_type: q

- backward_op : flash_attn_with_sparse_mask_grad
  forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0)
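To make the YAML signature above concrete, here is a minimal sketch (not taken from this PR) of the dense mask that `attn_mask_start_row_indices` appears to encode: for each key column j, score-matrix rows at or below `start_row_indices[..., j]` are masked out. This reading of the semantics is an assumption inferred from the parameter names; the actual fused kernel never materializes such a dense tensor.

```python
import paddle

def dense_mask_from_start_rows(attn_mask_start_row_indices, seq_len_q):
    """Expand per-column start-row indices into a dense additive mask.

    attn_mask_start_row_indices: int32 tensor of shape
    [batch, num_heads, seq_len_k]; for key column j, query rows
    i >= attn_mask_start_row_indices[..., j] are assumed to be masked.
    Returns a float32 mask [batch, num_heads, seq_len_q, seq_len_k]
    holding 0 for visible positions and -inf for masked ones.
    """
    row_ids = paddle.arange(seq_len_q, dtype="int32").reshape(
        [1, 1, seq_len_q, 1]
    )
    starts = attn_mask_start_row_indices.unsqueeze(2)  # [b, h, 1, sk]
    blocked = row_ids >= starts  # broadcasts to [b, h, sq, sk]
    neg_inf = paddle.full(blocked.shape, float("-inf"), dtype="float32")
    return paddle.where(blocked, neg_inf, paddle.zeros(blocked.shape, "float32"))
```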
If this operator is to support auto parallel later, please add the corresponding sharding (spmd) inference rules.
OK, will do.
The API is fine; please attach the Chinese documentation PR link in the description above.
Added the corresponding Chinese documentation: PaddlePaddle/docs#6554
LGTM for API
* add flash attention with sparse mask
* fix doc
* Update python/paddle/nn/functional/flash_attention.py
* Update python/paddle/nn/functional/flash_attention.py
* Update python/paddle/nn/functional/flash_attention.py
* Update python/paddle/nn/functional/flash_attention.py
* fix docstring

---------

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
Co-authored-by: zachary sun <sunzhongkai@baidu.com>
This reverts commit e05764a.
PR types
New features
PR changes
APIs
Description
support flash attention with sparse mask
Pcard-73145
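For reference, a hedged usage sketch of the Python API this PR adds. The entry point lives in python/paddle/nn/functional/flash_attention.py, but treat the exact keyword names (`attn_mask_start_row_indices`, `attn_mask_start_row`, `dropout_p`, `is_causal`) as assumptions inferred from the YAML signature above and sibling flash-attention APIs; running it requires a GPU build of Paddle with flash-attention support.

```python
# Hedged sketch: argument names are assumptions inferred from the YAML
# signature; requires a GPU build of PaddlePaddle with flash-attention.
import paddle
from paddle.nn.functional.flash_attention import flash_attn_with_sparse_mask

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = paddle.randn([batch, seq_len, num_heads, head_dim], dtype="float16")
k = paddle.randn([batch, seq_len, num_heads, head_dim], dtype="float16")
v = paddle.randn([batch, seq_len, num_heads, head_dim], dtype="float16")

# Start every column's mask past the last row, i.e. mask nothing beyond
# what is_causal already removes.
start_rows = paddle.full([batch, num_heads, seq_len], seq_len, dtype="int32")

out = flash_attn_with_sparse_mask(
    q,
    k,
    v,
    attn_mask_start_row_indices=start_rows,
    attn_mask_start_row=0,
    dropout_p=0.0,
    is_causal=True,
)
print(out.shape)  # expected: [2, 128, 8, 64]
```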