[XPU] Support bf16 clip_grad and redirect xpu kernel to clamp_grad #69723
Conversation
Your PR was submitted successfully. Thank you for contributing to this open-source project!
@@ -72,7 +72,7 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
   PADDLE_ENFORCE_EQ(r,
                     0,
                     common::errors::Fatal(
-                        "cast_v2 failed, related variable `r` is %d", r));
+                        "cast failed, related variable `r` is %d", r));
This could be rewritten in the PADDLE_ENFORCE_XDNN_SUCCESS style in a follow-up PR.
          reinterpret_cast<XPUDataType*>(x_grad->data<T>()),
          x.numel(),
          static_cast<XPUDataType>(min.to<T>()),
          static_cast<XPUDataType>(max.to<T>()));
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad");
Consider renaming clip_grad here to clamp_grad.
Got it; will change it in a follow-up PR.
@@ -274,69 +274,135 @@ void XPUFusedRotaryEveryTwo(const Context& dev_ctx,
                            DenseTensor* out_q,
                            DenseTensor* out_k,
                            DenseTensor* out_v) {
  auto single_func = &xpu::rotary_embedding_v3_single<XPUType, XPUSCType>;
This is a temporary change; it will be cleaned up after xhpc is updated.
Will this cause any correctness issues?
No correctness issues. The forward and backward interfaces are currently inconsistent (one takes std::string, the other char *), so I wrote a big if to handle them separately.
LGTM
"BLHD", | ||
true); | ||
PADDLE_ENFORCE_XDNN_SUCCESS(ret, single_func_name); | ||
if (is_bwd) { |
Why are the forward and backward passes handled separately here?
The forward and backward interfaces in xhpc are inconsistent (one takes std::string, the other char *), so the code cannot be shared.
I see, so they should be unified to std::string later, right?
They should be unified to char *, actually.
LGTM
PR Category
Custom Device
PR Types
New features
Description