[Hackathon 7th No.32] Enhance paddle.nn.functional.scaled_dot_product_attention #69099
Conversation
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
OK, thanks! I have looked at Paddle's backend selection; the current code modifies Paddle's backend selection and also adjusts the mask handling.
Do I need to strictly align with torch here, or is it enough for the backends that paddle flash_attention already supports to keep working under Paddle's sdpa? Reference: torch sdpa code
Sorry to inform you that 2f261aa's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
So we support three of torch's backends, Flash (GPU), Efficient, and Math, and the remaining CUDNN and Override are not supported, right? The backends Paddle already supports need to be aligned with torch, and backends Paddle does not currently have can be left out.
Yes, got it. Thanks!
The PR has been fixed, with the following added:
Could @zhwesky2010 please help review it again? Thanks!
By the way, here is the test code:

```python
# Compare Paddle's scaled_dot_product_attention against torch's under matching kernel backends.
import torch
from torch.backends.cuda import sdp_kernel as torch_sdp_kernel
import paddle
from paddle.nn.functional import sdp_kernel as paddle_sdp_kernel
import numpy as np
from numpy.testing import assert_allclose


def create_float_attention_mask(batch_size, seq_lens, dtype='float16'):
    max_seq_len = max(seq_lens)
    # Paddle mask with 1e4
    paddle_mask = paddle.zeros([batch_size, 1, max_seq_len, max_seq_len], dtype=dtype)
    for i in range(batch_size):
        seq_len = seq_lens[i]
        mask = paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=dtype)) - 1
        paddle_mask[i, 0, :seq_len, :seq_len] = mask * 1e4
    # Torch mask with 1e9
    torch_mask = paddle.zeros([batch_size, 1, max_seq_len, max_seq_len], dtype=dtype)
    for i in range(batch_size):
        seq_len = seq_lens[i]
        mask = paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=dtype)) - 1
        torch_mask[i, 0, :seq_len, :seq_len] = mask * 1e9
    torch_mask = torch_mask.numpy()
    torch_mask = torch.from_numpy(torch_mask).to(dtype=torch.float16).cuda()
    return paddle_mask, torch_mask


def test_sdpa_alignment(batch_size=2, seq_len=1024, num_heads=8, head_dim=64):
    np.random.seed(42)
    torch.manual_seed(42)
    paddle.seed(42)
    query_np = np.random.randn(batch_size, seq_len, num_heads, head_dim).astype(np.float16)
    key_np = np.random.randn(batch_size, seq_len, num_heads, head_dim).astype(np.float16)
    value_np = np.random.randn(batch_size, seq_len, num_heads, head_dim).astype(np.float16)
    paddle_mask, torch_mask = create_float_attention_mask(batch_size, [seq_len] * batch_size, dtype='float16')
    query_paddle = paddle.to_tensor(query_np)
    key_paddle = paddle.to_tensor(key_np)
    value_paddle = paddle.to_tensor(value_np)
    # torch expects [batch, num_heads, seq_len, head_dim], hence the transpose.
    query_torch = torch.from_numpy(query_np.transpose(0, 2, 1, 3)).cuda()
    key_torch = torch.from_numpy(key_np.transpose(0, 2, 1, 3)).cuda()
    value_torch = torch.from_numpy(value_np.transpose(0, 2, 1, 3)).cuda()

    # Flash Attention
    with paddle_sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False), \
         torch_sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        try:
            paddle_flash = paddle.nn.functional.scaled_dot_product_attention(
                query_paddle, key_paddle, value_paddle,
                attn_mask=paddle_mask,
                dropout_p=0.0,
                is_causal=False
            ).numpy()
            torch_flash = torch.nn.functional.scaled_dot_product_attention(
                query_torch, key_torch, value_torch,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=True
            ).cpu().numpy().transpose(0, 2, 1, 3)
            print("\nComparing Flash Attention:")
            abs_diff = np.abs(paddle_flash - torch_flash)
            print(f"Maximum absolute difference: {np.max(abs_diff)}")
            print(f"Mean absolute difference: {np.mean(abs_diff)}")
            assert_allclose(paddle_flash, torch_flash, rtol=1e-2, atol=1e-2)
            print("Flash Attention results match within tolerance")
        except Exception as e:
            print(f"Flash Attention comparison failed: {e}")

    # Math Attention
    with paddle_sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False), \
         torch_sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        try:
            paddle_math = paddle.nn.functional.scaled_dot_product_attention(
                query_paddle, key_paddle, value_paddle,
                attn_mask=paddle_mask,
                dropout_p=0.0,
                is_causal=False
            ).numpy()
            torch_math = torch.nn.functional.scaled_dot_product_attention(
                query_torch, key_torch, value_torch,
                attn_mask=torch_mask,
                dropout_p=0.0,
                is_causal=False
            ).cpu().numpy().transpose(0, 2, 1, 3)
            print("\nComparing Math Attention:")
            abs_diff = np.abs(paddle_math - torch_math)
            print(f"Maximum absolute difference: {np.max(abs_diff)}")
            print(f"Mean absolute difference: {np.mean(abs_diff)}")
            assert_allclose(paddle_math, torch_math, rtol=1e-2, atol=1e-2)
            print("Math Attention results match within tolerance")
        except Exception as e:
            print(f"Math Attention comparison failed: {e}")

    # Memory Efficient Attention
    with paddle_sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True), \
         torch_sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
        try:
            paddle_mem = paddle.nn.functional.scaled_dot_product_attention(
                query_paddle, key_paddle, value_paddle,
                attn_mask=paddle_mask,
                dropout_p=0.0,
                is_causal=False
            ).numpy()
            torch_mem = torch.nn.functional.scaled_dot_product_attention(
                query_torch, key_torch, value_torch,
                attn_mask=torch_mask,
                dropout_p=0.0,
                is_causal=False
            ).cpu().numpy().transpose(0, 2, 1, 3)
            print("\nComparing Memory Efficient Attention:")
            abs_diff = np.abs(paddle_mem - torch_mem)
            print(f"Maximum absolute difference: {np.max(abs_diff)}")
            print(f"Mean absolute difference: {np.mean(abs_diff)}")
            assert_allclose(paddle_mem, torch_mem, rtol=1e-2, atol=1e-2)
            print("Memory Efficient Attention results match within tolerance")
        except Exception as e:
            print(f"Memory Efficient Attention comparison failed: {e}")


if __name__ == "__main__":
    for seq_len in [512]:
        print(f"\nTesting with sequence length: {seq_len}")
        test_sdpa_alignment(
            batch_size=2,
            seq_len=seq_len,
            num_heads=8,
            head_dim=64
        )
```

cc: @zhwesky2010
Regarding the transpose(0, 2, 1, 3) applied to the returned result: is that because the tensor layouts differ? Is this diff reasonable?
Diff context:

```python
if "xpu" in place:
if place == "XPU":
```
While running the XPU CI... I found that on XPU, a tensor's .place can return the string "XPU".
Diff context:

```diff
@@ -318,7 +374,7 @@ def flash_attention(
     """
     head_dim = query.shape[3]
-    sdp_func_name = _select_sdp(head_dim)
+    sdp_func_name = _select_sdp(head_dim, query.dtype, query.place)
```
query.place
Its possible values aren't the strings XPU, CPU, and GPU, are they? How does the check above get hit?
The "XPU"/"CPU"/"GPU" strings are a safeguard: there is a risk that a string-valued variable ends up in _select_sdp, and this actually happens when running CI. The expected form is of course something that .is_gpu_place() can handle.
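To make the safeguard concrete, here is a minimal sketch of the kind of defensive check being described; `_is_gpu_place_like` is a hypothetical helper for illustration, not the PR's actual code:

```python
def _is_gpu_place_like(place) -> bool:
    # `place` is normally a Paddle Place object exposing .is_gpu_place(), but in
    # some CI environments (e.g. XPU) it can arrive as a plain string like "XPU".
    if isinstance(place, str):
        return "gpu" in place.lower()
    return place.is_gpu_place()


# Both forms are handled:
print(_is_gpu_place_like("XPU"))  # False
print(_is_gpu_place_like("GPU"))  # True
```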
return "flash_attn" | ||
|
||
# not use sdp_kernel | ||
if g_enable_flash is None: | ||
if "gpu" not in place: | ||
arch = _get_arch_info() |
_get_arch_info is only needed in the GPU case, right? This could be moved further down.
OK, updated.
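A minimal sketch of the reordering being suggested, assuming the names from the diff context above (`_get_arch_info`, `g_enable_flash`); the stub and the SM-80 threshold are illustrative, not Paddle's actual implementation:

```python
def _get_arch_info():
    # Stand-in for the real compute-capability query named in the diff above.
    return 80


def _select_backend_sketch(place, g_enable_flash=None):
    # When the user has not forced a kernel via sdp_kernel(...), fall back to
    # automatic selection.
    if g_enable_flash is None:
        if "gpu" not in str(place).lower():
            return "math"            # non-GPU: no architecture check needed
        arch = _get_arch_info()      # GPU only: architecture matters here
        return "flash_attn" if arch >= 80 else "mem_efficient"
    return "flash_attn"


print(_select_backend_sketch("Place(gpu:0)"))  # flash_attn
print(_select_backend_sketch("Place(cpu)"))    # math
```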
Diff context:

```python
    # handle bfloat16/fp16 case
    elif place.is_gpu_place():
        if dtype == paddle.bfloat16 or dtype == paddle.float16:
```
Has the logic of these branches been aligned with torch? Is it consistent?
This part isn't aligned yet; I'll revise it with reference to torch's approach.
Yes, this is reasonable. Paddle's mem efficient path also needs a transpose. In PyTorch, flash attention accepts inputs of shape [bs, num_head, seq_length, head_dim]. In phi, Paddle's flash attention should transpose internally, but the math and mem efficient kernels both expect the [bs, num_head, seq_length, head_dim] layout, so those two are transposed manually.
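To illustrate the layout difference described above (shapes match the test script earlier in the thread; this is just a sketch of the transpose, not the kernel code):

```python
import numpy as np

batch, seq_len, num_heads, head_dim = 2, 512, 8, 64

# Paddle's public scaled_dot_product_attention takes [batch, seq_len, num_heads, head_dim].
x_api_layout = np.random.randn(batch, seq_len, num_heads, head_dim).astype(np.float16)

# The torch-style / math / mem-efficient kernels expect [batch, num_heads, seq_len, head_dim],
# hence the transpose(0, 2, 1, 3) in the comparison script.
x_kernel_layout = x_api_layout.transpose(0, 2, 1, 3)

print(x_api_layout.shape)     # (2, 512, 8, 64)
print(x_kernel_layout.shape)  # (2, 8, 512, 64)
```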
PyTorch still passes; let me take another screenshot. As for the second item, let me think about how to handle it.
@zhwesky2010 the test has been added.
Diff context:

```python
    # not use sdp_kernel
    if (
        g_enable_flash is None
```
Where are these three variables set? It seems users have no way to specify them?
There is a way:

```python
from paddle.nn import sdp_kernel

with sdp_kernel(
    enable_flash=False, enable_math=True, enable_mem_efficient=False
):
    ...  # call scaled_dot_product_attention here; only the math kernel is enabled
```
> Where are these three variables set? It seems users have no way to specify them?

See the last test; that is exactly how it is done. You can either manually control which kernel is enabled, or leave the choice to the default backend selection.
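For readers wondering how a context manager can drive those module-level flags, here is a minimal illustrative sketch; the flag names follow the diff context above, but the implementation shown is not Paddle's actual code:

```python
from contextlib import contextmanager

# Module-level flags consulted by the backend-selection logic
# (names follow the diff context above).
g_enable_flash = None
g_enable_math = None
g_enable_mem_efficient = None


@contextmanager
def sdp_kernel_sketch(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    """Illustrative stand-in for sdp_kernel: set the flags, restore them on exit."""
    global g_enable_flash, g_enable_math, g_enable_mem_efficient
    prev = (g_enable_flash, g_enable_math, g_enable_mem_efficient)
    g_enable_flash, g_enable_math, g_enable_mem_efficient = (
        enable_flash,
        enable_math,
        enable_mem_efficient,
    )
    try:
        yield
    finally:
        g_enable_flash, g_enable_math, g_enable_mem_efficient = prev


with sdp_kernel_sketch(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    print(g_enable_flash, g_enable_math, g_enable_mem_efficient)  # False True False
print(g_enable_flash, g_enable_math, g_enable_mem_efficient)      # None None None
```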
LGTM
@yinfan98 please look into why CI is failing, and don't introduce any compatibility risk. The scenarios that previously used the flash_attn backend have not changed, right? Otherwise performance could change.
@zhwesky2010 No change: flash_attn still uses the previous code path, and the backend selection is the same as before.
Then let's merge it.
LGTM
@zhwesky2010 The coverage check failed and I'm not sure how to fix it; it doesn't look like it was caused by my changes...
LGTM
@yinfan98 please fix the Windows pipeline.
I've hit the rerun limit, so let me just rerun everything...
PR Category
User Experience
PR Types
Improvements
Description
Modify the Paddle sdpa code to support selecting among the math, mem efficient, and flash backends, and align the selection logic with the way torch does it.
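A usage sketch of the enhancement under the two modes described above (shapes and dtype are illustrative; the import path follows the test script in this thread and assumes a GPU build, since float16 attention generally requires one):

```python
import paddle
import paddle.nn.functional as F
from paddle.nn.functional import sdp_kernel  # path as used in the test script above

q = paddle.randn([2, 512, 8, 64]).astype('float16')
k = paddle.randn([2, 512, 8, 64]).astype('float16')
v = paddle.randn([2, 512, 8, 64]).astype('float16')

# Default: let the backend pick among flash / mem_efficient / math.
out_auto = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Manual: restrict the choice to the math kernel only.
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out_math = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out_auto.shape, out_math.shape)  # both [2, 512, 8, 64]
```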