CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[PIR+CINN]Support multi-thread Pre-Compile for Lowering FusionOp #62952
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
} | ||
// Build and trigger compilaion cache. | ||
PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget()); | ||
pir_compiler.Build(groups); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处统一多线程编程加速
using BroadcastTreeInfoMap = | ||
std::unordered_map<GroupPtr, | ||
std::shared_ptr<BroadcastTreeInfo>, | ||
SharedGroupHasher, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里目前复用了现有代码,但个人评估是有潜在hash安全风险的,下个PR将统一优化
paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.cc
Outdated
Show resolved
Hide resolved
paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.cc
Outdated
Show resolved
Hide resolved
using CacheValue = std::shared_ptr<pir::CompilationResult>; | ||
|
||
static CompilationCache& Instance() { | ||
static CompilationCache instance; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thread_local?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
懒汉式的单例实现采用的是static 静态变量,这一点是保证了线程安全的。感谢提醒,这里我发现cache_需要添加shared_mutex使其读写是线程安全的,更准确的说,应该要独立实现一个线程安全的uordered_map
CompilationCache() = default; | ||
CINN_DISALLOW_COPY_AND_ASSIGN(CompilationCache); | ||
|
||
std::unordered_map<size_t, CacheValue> cache_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个设计很奇怪。为什么把一个hash_value当作key?这是表示只要hash value一样,其CacheKey一定一样吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里理想态上应该有一层单独的数据结构,类似GroupInfo,其只是记录一个Group真正与Compiler相关的信息,而且不能记录Operation*(因为是面向子图的通用缓存,每个子图在Program析构之后,operation*都将是非法指针)。
关于GroupInfo的设计,是下个PR单独来设计实现的,故这里只是临时设计为了size_t。
|
||
struct PreAnalysisInfo { | ||
GroupInfoMap group_infos; | ||
BroadcastTreeInfoMap broadcast_tree_infos; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
broadcast tree
是用于broadcast类型子图编译时的局部概念,目前看在Analysis阶段好像没有用到broadcast tree
,这里如果不放到AnalysisInfo
中会不会对降低代码复杂度有帮助呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO:BroadcastTree无需针对每个group进行构造,可以单独实现HasMultiBranch函数逻辑,命中后再按需构建BroadcastTree,如线下沟通,后面单独PR
void FusionOpAnalysis::GatherGroup(pir::Operation* fusion_op) { | ||
std::shared_ptr<Group> group_ptr = RebuildGroup(fusion_op, is_dy_shape_); | ||
VLOG(6) << "Gather Group " << group_ptr->FuncName() | ||
<< " for fusion_op : " << fusion_op->id(); | ||
pre_analysis_info_->group_infos.insert({fusion_op, group_ptr}); | ||
if (is_dy_shape_) { | ||
auto broadcast_tree_info = std::make_shared<BroadcastTreeInfo>(group_ptr); | ||
pre_analysis_info_->broadcast_tree_infos.insert( | ||
{group_ptr, broadcast_tree_info}); | ||
} | ||
} | ||
|
||
void FusionOpAnalysis::RunImpl(pir::Operation* op) { | ||
if (op->isa<cinn::dialect::FusionOp>()) { | ||
GatherGroup(op); | ||
return; | ||
} | ||
for (uint32_t i = 0; i < op->num_regions(); ++i) { | ||
for (auto& block : op->region(i)) { | ||
for (auto& op : block) { | ||
RunImpl(&op); | ||
} | ||
} | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些成员函数声明和实现同一个文件中的可以放在一起吗?现在这些函数需要多次滚动屏幕查找
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个Pass已经非常大了,我计划后面把这个模块拆成目录
const auto& EnqueueGroup = [&](const GroupPtr& group) { | ||
const bool has_broadcast_tree = | ||
pre_analysis_info_->broadcast_tree_infos.count(group) > 0; | ||
if (has_broadcast_tree) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以直接看group的维度符号里是否有Broadcast来判断有没有broadcast分支,避免broadcast_tree的概念扩散到外部
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
会单独提重构PR来优化
PR types
New features
PR changes
Others
Description
Pcard-67164