Add phi kernel fp8_fp8_half_gemm_fused #64955
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
cudaGetDevice(&dev);
if (dev == 0) {
  std::ofstream outfile;
  outfile.open(config_filename_, std::ios::out | std::ios::trunc);
The way the file is created could be further optimized.
Keeping it in the destructor still looks appropriate; the whole feature will be moved out later and turned into a utility API.
  infile.close();
}

std::string config_filename_{"/tmp/paddle_cublaslt_cache"};
This should be a configurable parameter, with the current directory as the default.
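A minimal sketch of what such a configurable path could look like, assuming a hypothetical PADDLE_CUBLASLT_CACHE_DIR environment variable (the variable name and helper are illustrative, not Paddle's actual flag):

#include <cstdlib>
#include <string>

// Resolve the cache file path: use the environment override when present,
// otherwise fall back to the current working directory.
// PADDLE_CUBLASLT_CACHE_DIR is a hypothetical name used for illustration.
std::string GetCublasLtCachePath() {
  const char* dir = std::getenv("PADDLE_CUBLASLT_CACHE_DIR");
  std::string base = (dir != nullptr) ? dir : ".";
  return base + "/paddle_cublaslt_cache";
}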
  return &algo_in_map;
}

~CublasLtAlgoCache() {
shape_range_info serializes its data to disk in the destructor; see https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/api/analysis_predictor.cc#L3196 for reference.
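For reference, a minimal sketch of the serialize-on-destruction pattern being pointed to; the class and member names below are placeholders, not Paddle's actual ones:

#include <cstdint>
#include <fstream>
#include <string>
#include <unordered_map>

// Illustrative cache that persists its contents when destroyed, mirroring
// the shape_range_info pattern referenced above.
class AlgoCacheSketch {
 public:
  ~AlgoCacheSketch() { SerializeToFile(); }

 private:
  void SerializeToFile() {
    std::ofstream out(path_, std::ios::out | std::ios::trunc);
    if (!out.is_open()) return;  // best effort: silently skip on failure
    out << entries_.size() << "\n";
    for (const auto& kv : entries_) {
      out << kv.first << " " << kv.second << "\n";
    }
  }

  std::string path_{"./paddle_cublaslt_cache"};
  std::unordered_map<int64_t, int64_t> entries_;  // hashed key -> algo index
};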
int64_t value) {
  *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}
};
The tuning code above looks generic to matrix multiplication rather than specific to FP8; it should later be extracted as general-purpose tuning code.
Yes, the global-search matmul tuning will be turned into a standalone tool later.
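For context, the snippet above is the classic boost-style hash combine. A self-contained sketch of how a GEMM cache key could be accumulated with it (GemmShapeKey and its parameter set are illustrative; the real code hashes more attributes such as dtypes and the epilogue):

#include <cstdint>
#include <functional>

// Boost-style hash combine: folds one value into an accumulated seed.
void HashValue(int64_t* seed, const std::hash<int64_t>& hash_fn,
               int64_t value) {
  *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}

// Illustrative cache key over the GEMM shape (m, n, k).
int64_t GemmShapeKey(int64_t m, int64_t n, int64_t k) {
  std::hash<int64_t> hash_fn;
  int64_t seed = 0;
  HashValue(&seed, hash_fn, m);
  HashValue(&seed, hash_fn, n);
  HashValue(&seed, hash_fn, k);
  return seed;
}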
&epilogue,
sizeof(epilogue));
PADDLE_CUBLASLT_STATUS_CHECK(cublasLtMatmulDescSetAttribute);
}
An else branch needs to be added here.
Added an else branch that raises a PADDLE_THROW error.
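A sketch of the shape of that guard, assuming the surrounding code maps the activation attribute to a cuBLASLt epilogue (the function name and message text are illustrative; the CUBLASLT_EPILOGUE_* values are real cuBLASLt enums):

#include <string>
#include <cublasLt.h>
#include "paddle/phi/core/enforce.h"

// Map the activation attribute to a cuBLASLt epilogue; anything
// unrecognized now fails loudly instead of being silently ignored.
cublasLtEpilogue_t EpilogueFor(const std::string& activation_type) {
  if (activation_type == "gelu") {
    return CUBLASLT_EPILOGUE_GELU_BIAS;
  } else if (activation_type == "relu") {
    return CUBLASLT_EPILOGUE_RELU_BIAS;
  } else if (activation_type == "identity") {
    return CUBLASLT_EPILOGUE_BIAS;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "fp8_fp8_half_gemm_fused does not support activation_type %s.",
        activation_type));
  }
}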
ctx, batch_count, m, n, k, x, y, scale, bias, activation_type, out);
}

}  // namespace cutlass_internal
The code above is cuBLASLt, not cutlass_xx.
Same as above: the namespace is reused; it will be renamed together with the cutlass one.
namespace phi {
namespace fusion {
namespace cutlass_internal {
cublaslt_internal ?
We forgot to differentiate this when splitting it out from the cutlass kernel. The upcoming cutlass PR shares the phi kernel and API definitions with this PR, so it is hard to rename the namespace separately; the best option is to rename both sides to fp8_internal once both are merged.
    return False
if get_cuda_version() < 12010:
    return False
return True
The code above duplicates the previous PR.
In the unit tests, yes; but this code does not seem easy to make generic, does it?
    return False
if get_cuda_version() < 12010:
    return False
return True
Same as above: duplicated code.
)
# there exists some problem in cpu fp8 cast
if self.device == "gpu":
    self.assertTrue(paddle.equal_all(input2, expect))
What exactly is the problem on CPU here? It should be fixed later, right?
The skip was left in this unit test by mistake. Non-arithmetic data-manipulation ops such as cast, full, and reshape all support FP8; matmul is only supported on specific GPUs and CUDA versions.
expect = paddle.to_tensor([[1, 1]]).astype("float32")
# there exists some problem in cpu fp8 full
if self.device == "gpu":
    self.assertTrue(paddle.equal_all(expect, input_fp32))
Same as above.
The skip has been removed from the unit test.
        paddle.cast(output_bf16, "float32"),
        paddle.to_tensor(expect_result),
    )
)
The code above computes the gelu/relu and bias combinations, but the results are never checked for correctness here.
OK, correctness checks for the gelu/relu and bias combinations have been added.
python/paddle/tensor/linalg.py
@@ -324,6 +324,164 @@ def __check_input(x, y):
    return out

def fp8_fp8_fp16_gemm_fused(
This is not the right place for the API.
How about adding an attribute to control the output dtype? Then a single API would suffice, which should remove a lot of duplicated code.
We considered that before; controlling it through an attribute is less convenient and direct than separate APIs.
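For context, a sketch of the two API shapes being debated; the signatures below are illustrative, not the ones merged:

// Option A (reviewer's suggestion): one entry point, output dtype as an
// attribute, so the fp16/bf16 paths share one API.
enum class OutDType { kFloat16, kBFloat16 };
void Fp8Fp8HalfGemmFused(/* x, y, bias, scale, activation_type, */
                         OutDType out_dtype /* , out */);

// Option B (kept in this PR): one entry point per output dtype, which the
// author considers more direct to call.
void Fp8Fp8Fp16GemmFused(/* x, y, bias, scale, activation_type, out */);
void Fp8Fp8Bf16GemmFused(/* x, y, bias, scale, activation_type, out */);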
#define PADDLE_CUBLASLT_STATUS_CHECK(name)                                 \
  PADDLE_ENFORCE_EQ(                                                       \
      status,                                                              \
      CUBLAS_STATUS_SUCCESS,                                               \
      phi::errors::External(                                               \
          #name                                                            \
          " execution error; "                                             \
          "refer to https://docs.nvidia.com/cuda/cublas/index.html for "   \
          "more information"))
Doesn't Paddle already provide a utility like this?
There is no existing definition of the cuBLAS error enum, as far as I know.
Then please add this part to paddle/phi/core/enforce.h.
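A sketch of the kind of status-to-string helper that could live there; the function name is illustrative, while the enum values are cuBLAS's own:

#include <cublas_v2.h>

// Translate a cublasStatus_t into a readable name for error messages.
inline const char* CublasStatusToString(cublasStatus_t status) {
  switch (status) {
    case CUBLAS_STATUS_SUCCESS:          return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:  return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:     return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:    return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:    return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:    return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:   return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:    return "CUBLAS_STATUS_NOT_SUPPORTED";
    default:                             return "unknown cuBLAS status";
  }
}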
namespace phi {
namespace fusion {
namespace cutlass_internal {
Adding this new namespace does not seem necessary.
It is developed together with the upcoming cutlass kernels, which reuse the same phi kernel and API definitions and live in the same namespace; they have not been differentiated yet. It can be renamed to fp8_internal after both are merged. The namespace itself should be needed, since some functions and definitions in the cutlass implementation might otherwise collide with names elsewhere.
#include "paddle/phi/common/place.h" | ||
#include "paddle/phi/core/allocator.h" | ||
|
||
namespace dyl = phi::dynload; |
This abbreviation seems unnecessary; the full name is only a few letters longer.
It is used in more than sixty places, which is quite a lot.
  infile.close();
}

std::string config_filename_{"./paddle_cublaslt_cache"};
What is this path? Using a relative path like this does not seem appropriate.
It currently uses the working directory; the whole feature will be extracted later, with every parameter configurable.
HashValue_(seed, hash_fn, static_cast<int64_t>(batch_offset));
}

void HashValue_(int64_t* seed,
Paddle's coding style has no rule or precedent for class member functions ending with an underscore.
Renamed to the normal member-function naming convention.
Changing it to a single API would also work.
LGTM for your excellent work
LBTM for PR-CI-Paddle-Doc-Preview
LGTM for const_cast and @unittest.SkipIf
class TestFP8CastOp(unittest.TestCase):
    def setUp(self):
        self.dtype_dict = {
            "float8_e4m3fn": core.VarDesc.VarType.FP8_E4M3FN,
Suggest dropping the dependency on core.VarDesc and using the DataType from phi directly, e.g. paddle.float32.
no docs changes, LGTM
PR Category
Performance Optimization
PR Types
New features
Description
pcard-71500
Add the API and phi kernel for fp8_fp8_half_gemm_fused, which can fuse gemm+bias+scale+act (cuBLASLt with global-search tuning).
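A minimal reference for the fused computation's semantics, assuming out = act(scale * (x @ y) + bias) as described above; the naive float loop is illustrative only (the kernel itself takes FP8 inputs and produces half/bf16 output via cuBLASLt):

#include <algorithm>
#include <vector>

// Naive reference: out = act(scale * (x @ y) + bias), with relu standing
// in for the configurable activation. x is m x k, y is k x n (row-major),
// bias has n elements.
std::vector<float> FusedGemmReference(const std::vector<float>& x,
                                      const std::vector<float>& y,
                                      const std::vector<float>& bias,
                                      float scale, int m, int n, int k,
                                      bool use_relu) {
  std::vector<float> out(static_cast<size_t>(m) * n, 0.f);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += x[i * k + p] * y[p * n + j];
      float v = scale * acc + bias[j];
      out[i * n + j] = use_relu ? std::max(v, 0.f) : v;
    }
  }
  return out;
}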