CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[Paddle TensorRT] Support isnan, group_norm and take_along_axis #70817
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) { | ||
return false; | ||
} | ||
#if !IS_TRT_VERSION_GE(8200) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pir-trt只支持8.6以上,这个不用检查了
return false; | ||
#else | ||
pir::Value index_var_name = op.operand_source(1); | ||
auto index_var_name_type = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个使用pir::GetDataTypeFromValue
pir::Value index_var_name = op.operand_source(1); | ||
auto index_var_name_type = | ||
index_var_name.type().dyn_cast<paddle::dialect::DenseTensorType>(); | ||
auto index_shape = index_var_name_type.dims(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
得到输入的shape使用pir::GetShapeFromValue
python/paddle/tensorrt/impls/math.py
Outdated
def isnan_converter(network, paddle_op, inputs): | ||
input_tensor = inputs[0] | ||
version_list = get_trt_version_list() | ||
if version_list >= [10, 1, 0]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
10.1版本先不需要写
python/paddle/tensorrt/impls/math.py
Outdated
|
||
equal_tensor = trt_equal(network, input_tensor, input_tensor) | ||
layer = network.add_unary(equal_tensor, trt.UnaryOperation.NOT) | ||
cast_layer = network.add_identity(layer.get_output(0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
旧ir-trt没出现cast_layer,看一下为啥加这个
@@ -152,3 +157,39 @@ def instance_norm_converter(network, paddle_op, inputs): | |||
) | |||
instance_norm_layer = network.add_plugin_v2(instance_norm_inputs, plugin) | |||
return instance_norm_layer.get_output(0) | |||
|
|||
|
|||
@converter_registry.register( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
旧ir-trt是采用的通用plugin方式,这个是在哪里写的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TRT8.6新增normalization层,可以实现NCHW的group_norm,layer_norm的convert也是用的这个接口实现,TRT没有适配这个接口。
torch和onnx也用的这种方式。
self.max_shape = {"X": [5, 4, 10], "Index": [5, 4, 10]} | ||
|
||
def test_trt_result(self): | ||
self.check_trt_result() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试下fp16
self.max_shape = {"x": [5, 3]} | ||
|
||
def test_trt_result(self): | ||
self.check_trt_result() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试下fp16
self.max_shape = {"x": [6, 32, 64, 64]} | ||
|
||
def test_trt_result(self): | ||
self.check_trt_result() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试下fp16
PR Category
Inference
PR Types
Others
Description
card-71500
add isnan, group_norm and take_along_axis