CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【TRT_Converter】add tile & share_data op TRT_Converter #69086
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -1485,6 +1485,25 @@ class TanhOpPattern : public pir::OpRewritePattern<paddle::dialect::TanhOp> { | |||
} | |||
}; | |||
|
|||
class TileOpPattern : public pir::OpRewritePattern<paddle::dialect::TileOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此op应该是无限制进入trt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改!感谢~
slice_layer.set_input(2, output_shape_tensor) | ||
|
||
version_list = get_trt_version_list() | ||
if version_list >= [8, 6, 1]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
大于8.6.0版本
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改!感谢~
def test_trt_result(self): | ||
self.check_trt_result() | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
再加一个rank==repeat_rank的情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改!感谢~
|
||
|
||
@converter_registry.register("pd_op.share_data", trt_version="8.x") | ||
def where_converter(network, paddle_op, inputs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个converter的名字为什么叫"where_converter"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改!感谢~
@@ -81,5 +86,19 @@ def test_trt_result(self): | |||
self.check_marker(expected_result=False) | |||
|
|||
|
|||
class TestShare_DataTRTPatternCase1(TensorRTBaseTest): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TestShare_DataTRTPatternCase1命名不符合规则
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改!感谢~
def setUp(self): | ||
self.python_api = paddle.tile | ||
self.api_args = { | ||
"x": np.random.randn(1, 2, 3).astype("int32"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
类型改为int64
def setUp(self): | ||
self.python_api = api_wrapper | ||
self.api_args = { | ||
"x": np.random.rand(4, 3, 5).astype("float32"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
再添加一个int64的单测
PR Category
Inference
PR Types
Others
Description
card-71500
add pir_converter tile & share_data