CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[CINN]apply broadcast tree optimization #66537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CINN]apply broadcast tree optimization #66537
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
std::unordered_map<std::string, ::pir::Attribute> jit_kernel_attr = [&]() { | ||
const auto& optional_broadcast_tree = GetBroadcastTreeForOptimize(group); | ||
if (optional_broadcast_tree.has_value()) { | ||
const std::shared_ptr<BroadcastTree> broadcast_tree = | ||
optional_broadcast_tree.value(); | ||
const auto& value_to_dim_expr_idx = | ||
GetGroupDimExprInfo(group).value_to_dim_expr_idx; | ||
return CompileBroadcastTree( | ||
group, *broadcast_tree, value_to_dim_expr_idx); | ||
} else { | ||
return GetJitKernelAttr(group); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前PR暂时还是将Broadcast tree的编译放在ProcessDyShapeGroup中,下一个PR将直接放入GetJitKernelAttr函数的逻辑中并放入PreAnalysis增强系统可维护性
@@ -35,12 +35,8 @@ std::optional<std::shared_ptr<BroadcastTree>> GetBroadcastTreeForOptimize( | |||
bool ContainBroadcastShape(const common::BroadcastLeaf& leaves); | |||
GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group); | |||
|
|||
pir::Operation* CompileBroadcastTreeToConditionBlock( | |||
std::unordered_map<std::string, pir::Attribute> CompileBroadcastTree( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CompileBroadcastTreeToJitKernelOpAttrs
函数签名发生了变化,名字必须做到名副其实。
std::unordered_map<std::string, pir::Attribute>> | ||
CompileGroupAsOpAttribute(const std::vector<OpLoweringGroupPtr>& group_list) { | ||
std::unordered_map<std::string, pir::Attribute> | ||
CompileBroadcastGroupsAsOpAttribute( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CompileBroadcastGroupsAsJitKernelOpAttribute
语义要表达到位
ops_mapper[op] = new_op; | ||
} | ||
|
||
const int& group_idx) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const std::string& name_suffix
这个字段的作用太弱了,其实本函数需要的只是名字后缀。到时候直接把std::to_string(group_idx)
传给本函数就行。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到,上述命名及变量问题都将在下一个PR中进行修改
* [CINN]apply broadcast tree optimization * fix * fix unit test
* [CINN]apply broadcast tree optimization * fix * fix unit test
PR Category
CINN
PR Types
Improvements
Description
Pcard-67164
This PR applies the optimization of broadcast tree lowering to host function.
To learn more about the theoretical explanation of lowering broadcast tree in host wrapper function and the step-by-step arrangement for the Pull Request (PR), please refer to the draft PR65604 for details.