[Inference]Adapt generic plugin for PIR #66634
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
CHECK_EQ(phi_kernel_contexts_[data_type]->OutputsSize(), getNbOutputs());
(*phi_kernels_[data_type])(phi_kernel_contexts_[data_type].get());

if (op_name_ == "pd_op.argsort") {
I think that everywhere from supportsFormatCombination through enqueue, some handling specific to each pd_op type is needed. Should this op-specific handling be factored out into its own place?
It can be factored out, but that part of the code would need to be re-abstracted and redesigned.
                         input_numel * data_type_and_size.second,
                         place));
(*dense_tensor_inputs_)[i] =
    std::move(phi::DenseTensor(input_alloc, input_meta));
Is this copy of the input tensor necessary?
This is not a copy.
nvinfer1::DataType data_type;
// input
if (op_name_ == "pd_op.embedding") {
  data_type = input_desc[1].type;
The way data_type is obtained here is rather crude; the real situation is likely more complex than this.
This logic is carried over from the old generic_plugin. I suggest keeping it as-is for now and improving it later if problems come up.
} else {
  data_type = input_desc[0].type;
}
CHECK((data_type == nvinfer1::DataType::kFLOAT) ||
Isn't this too strict?
Let's align with the old design first; we can relax it later as functionality is extended.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for serialize
PR Category
Inference
PR Types
New features
Description
Pcard-71500
This PR adapts the TensorRT generic plugin to PIR. The core work is as follows:
1. Enable dynamic-shape InferMeta to run under PIR and develop a dedicated registration mechanism for it; this PR completes the InferMeta adaptation for 3 operators.
2. Implement serialization and deserialization for the generic plugin.
3. Implement the generic plugin's enqueue execution, supporting PIR operators with mutable attributes.
4. Develop the PIRGenericPluginCreator module, so generic plugins can be created directly through a creator.
5. Refactor the generic plugin's base components and reorganize the code directory, so that different PIR operators can implement their own plugin functions within the generic plugin.
6. Add unit tests.