CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
add pd_op.mean and pd_op.sqrt converter for PIR TRT #67974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
87461a7
to
743e818
Compare
|
||
std::vector<int64_t> dims = dim_attr.data().GetData(); | ||
for (auto x : dims) { | ||
if (x == 0 || (x + input_shape.size() == 0)) return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段关于非动态shape的删除,pir-trt只支持动态shape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
!input_type.isa<pir::Int64Type>() && | ||
!input_type.isa<pir::Float32Type>() && | ||
!input_type.isa<pir::Float64Type>()) { | ||
VLOG(3) << "The type of input is not int32 or int64 or float32 or float64"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里加一个算子名称,比如pd_op.mean只支持这四种类型输入
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
edd793b
to
4919f7e
Compare
@@ -0,0 +1,54 @@ | |||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CMakeLists.txt加上时间限制,否则ci超时
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
2ecf5e2
to
74d0070
Compare
00fb00c
to
3ec1c8d
Compare
|
||
@converter_registry.register("pd_op.sqrt_", trt_version="8.x") | ||
@converter_registry.register("pd_op.sqrt", trt_version="8.x") | ||
def mean_converter(network, paddle_op, inputs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注册的是sqrt,名字叫mean_converter?另外sqrt代码应该是在ops.py下,不是unary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我改下
@@ -0,0 +1,38 @@ | |||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,单测需要放到test_converter_ops.py里
@@ -0,0 +1,46 @@ | |||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mean对应paddle api里在stat.py里
@@ -0,0 +1,56 @@ | |||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件命名test_converter_stat.py
618084a
to
5a78b82
Compare
4cdd84d
to
c16b3ee
Compare
PR Category
Inference
PR Types
New features
Description
pcard-71500
添加了pd_op.mean和pd_op.sqrt到PIR TRT转换支持