[PIR AMP]Gen AMP logic code in PIR APIs #61399
Conversation
… refine-trace-amp
… gen-pir-amp-code
namespace paddle {
namespace dialect {

phi::DataType GetPromoteType(
Much of the logic in this file duplicates the handling in amp_utils.h under eager mode, and it looks like there is corresponding logic for every op. Could the code be reused so that a single change takes effect for both the static graph and the dynamic graph?
Reuse is worth considering here; the shared logic under eager mode may need to be extracted into a separate directory so that eager and PIR can each call it.
The related functions take different input types under the dynamic graph and under PIR (e.g. tensor vs. value), so the code is not reused here. I am not sure whether templates would allow a single copy of the code to be shared; I will try that in a separate follow-up PR, and I will also address 杰哥's review suggestions there.
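For reference, one possible shape for such a shared helper is a function template parameterized on the value type; the names GetPromoteTypeImpl and GetDataTypeOf below are purely illustrative sketches, not existing Paddle APIs:

template <typename TensorOrValue>
phi::DataType GetPromoteTypeImpl(
    const std::string& op_name,
    const std::vector<std::vector<TensorOrValue>>& amp_values_vector,
    const phi::DataType& amp_dtype) {
  phi::DataType dst_type = amp_dtype;
  for (const auto& values : amp_values_vector) {
    for (const auto& value : values) {
      // GetDataTypeOf would be specialized for egr::Tensor and pir::Value.
      if (GetDataTypeOf(value) == phi::DataType::FLOAT32) {
        dst_type = phi::DataType::FLOAT32;
      }
    }
  }
  return dst_type;
}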
    const phi::DataType& amp_dtype) {
  auto dst_type = amp_dtype;
  // only consider the dtype of input(X).
  if (op_name == "batch_norm" || op_name == "layer_norm" ||
An anonymous namespace could be added to this file with a const unordered_set that replaces this logic; the two ifs here could then be joined with &&.
Line 28 can be optimized at the same time:
const auto& HandleSpecialOp = [&](){....};
HandleSpecialOp();
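For reference, a minimal sketch of that suggestion, assuming <string> and <unordered_set> are included and the set lives in a file-local anonymous namespace (the set name and any ops beyond those visible in the diff are illustrative):

namespace {
// Ops whose promote dtype only depends on input(X); a membership test here
// replaces the chained op_name == "..." comparisons.
const std::unordered_set<std::string> kDtypeFromInputXOps = {
    "batch_norm", "layer_norm", /* ... other special ops ... */};
}  // namespace

// At the call site the two ifs can then be joined with &&:
// if (kDtypeFromInputXOps.count(op_name) && <second condition>) { ... }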
  if (egr::Controller::Instance().GetCurrentAMPState()->GetAmpDtype() ==
      "float16") {
    if (op_name == "fused_attention") {
      for (size_t i = 0; i < amp_values_vector.size(); i++) {
The for loop here can be extracted into a lambda and placed at line 38:
const auto& HandleFuseAttention = [&](){...};
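For reference, a sketch of that extraction, assuming amp_values_vector and the original loop body are available in the enclosing scope (the lambda name follows the suggestion above):

const auto& HandleFuseAttention = [&]() {
  for (size_t i = 0; i < amp_values_vector.size(); ++i) {
    // original per-index handling for fused_attention
  }
};
if (op_name == "fused_attention") {
  HandleFuseAttention();
}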
"float16") { | ||
if (op_name == "fused_attention") { | ||
for (size_t i = 0; i < amp_values_vector.size(); i++) { | ||
if (i != 3 || i != 4 || i != 9 || i != 10) { |
Magic numbers should be avoided in the code; the 3, 4, 9, 10 here should be defined as named variables inside the lambda, e.g.
const size_t xxx_index = 3;
or a set could be defined:
const std::unordered_set<size_t> skip_value_indices = {/*xxx_index=*/ 3, ...};
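For reference, a sketch with the indices named (the set name is illustrative). Note that as written the chain i != 3 || i != 4 || ... is always true, so a membership test over a skip set also makes the intended skip logic explicit:

const std::unordered_set<size_t> skip_value_indices = {3, 4, 9, 10};
const size_t value_length = amp_values_vector.size();
for (size_t i = 0; i < value_length; ++i) {
  if (skip_value_indices.count(i) == 0) {
    // handle the i-th input as before; indices 3, 4, 9 and 10 are skipped
  }
}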
        }
      }
    } else if (op_name == "fused_feedforward") {
      for (size_t i = 0; i < amp_values_vector.size(); i++) {
Same as above. In addition, the size() call should be hoisted out of the loop so it is not invoked O(N) times, e.g.
const size_t value_length = amp_values_vector.size();
<< " input(" << input_name << " to dst_dtype(" | ||
<< phi::DataTypeToString(dst_dtype) << ")."; | ||
if ((op_name == "batch_norm" || op_name == "layer_norm" || | ||
op_name == "sync_batch_norm" || op_name == "weight_only_linear") && |
Same as above.
  }
  if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
    if (input_name == "LnScale" || input_name == "LnBias" ||
        input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
A const set could be defined in an anonymous namespace here.
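For reference, a minimal sketch of that suggestion, using an illustrative set name and only the input names visible in the diff:

namespace {
// Input names matched by the original chained comparisons.
const std::unordered_set<std::string> kSpecialInputNames = {
    "LnScale", "LnBias", "Ln2Scale", "Ln2Bias"};
}  // namespace

// ...
if ((op_name == "fused_attention" || op_name == "fused_feedforward") &&
    kSpecialInputNames.count(input_name)) {
  // same handling as before
}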
  }

  if (use_promote) {
    if (paddle::imperative::AmpOperators::Instance()
The two branches of the outermost if-else can each be extracted into a lambda to reduce nesting.
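For reference, a sketch of that refactor with illustrative lambda names, assuming the two original branch bodies move into the lambdas unchanged:

const auto& HandlePromote = [&]() {
  // body of the original if (use_promote) branch
};
const auto& HandleCast = [&]() {
  // body of the original else branch
};
if (use_promote) {
  HandlePromote();
} else {
  HandleCast();
}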
The optimizations suggested in the review comments will need to be fixed in a separate PR.
LGTM
PR types
Others
PR changes
Others
Description
Pcard-67164
This is the second PR for supporting AMP under PIR; it mainly implements the automatic generation of the AMP logic inside the PIR APIs.