CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[CINN] Add InferSymbolicShapeInterface for (frame, prune_gate_by_capacity
) op
#68644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CINN] Add InferSymbolicShapeInterface for (frame, prune_gate_by_capacity
) op
#68644
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
Please resolve the conflict |
… infer_symbolic_shape_for_frame_prune_gate_by_capacity
@@ -74,7 +74,7 @@ def init_attrs(self): | |||
|
|||
def test_check_output(self): | |||
paddle.enable_static() | |||
self.check_output() | |||
self.check_output(check_pir=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为什么打开的pir的flag
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为了符号推导检查,这个 paddle.signal.frame api 已经适配了 PIR,应该需要打开 pir 检查(感觉这个单测文件都需要打开 pir😂
Paddle/python/paddle/signal.py
Lines 143 to 145 in 0419d33
return _C_ops.frame(x, frame_length, hop_length, axis) | |
elif in_pir_mode(): | |
return _C_ops.frame(x, frame_length, hop_length, axis) |
bool contain_unknow_dim; | ||
for (size_t i = 0; i < x_shape.size(); i++) { | ||
if (x_shape[i].isa<int64_t>()) { | ||
contain_unknow_dim = true; | ||
break; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DimExpr 能同时表示动态和静态维度,这里不需要区分是否包含动态维度,可以删掉这部分代码
break; | ||
} | ||
} | ||
if (x_shape_or_data.data().has_value() || contain_unknow_dim) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里条件判断改成判断 seq_length 是否是int64
n_frames = infer_context->GetNextSymName(); | ||
} else { | ||
n_frames = symbol::DimExpr( | ||
(seq_length.dyn_cast<int64_t>() - frame_length) / hop_length + 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DimExpr 可以直接运算,不需要dyn_cast 为 int64
int64_t n_worker = op->attribute<pir::Int64Attribute>("n_worker").data(); | ||
int64_t expert_count_num_ele = 1; | ||
for (const auto &i : expert_count_shape) { | ||
if (i.isa<int64_t>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DimExpr可以直接运算
} else { | ||
PADDLE_THROW(::common::errors::InvalidArgument( | ||
"The shape of expert_count must be known.")); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删掉这部分代码
const auto &expert_count_shape = expert_count_shape_or_data.shape(); | ||
int64_t n_expert = op->attribute<pir::Int64Attribute>("n_expert").data(); | ||
int64_t n_worker = op->attribute<pir::Int64Attribute>("n_worker").data(); | ||
int64_t expert_count_num_ele = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
expert_count_num_ele 声明成DimExpr吧
PADDLE_ENFORCE_EQ( | ||
expert_count_num_ele, | ||
n_expert * n_worker, | ||
common::errors::Unavailable( | ||
"The number of elements for expert_count is ( %ld ) incorrect. " | ||
"Because the number of expert_count must equal the " | ||
"product of n_worker ( %ld ) and n_expert ( %ld ). " | ||
"Please input appropriate expert_count again!", | ||
expert_count_num_ele, | ||
n_worker, | ||
n_expert)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ENFORCE 修改为 EqualConstrain
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
if (!seq_length.isa<int64_t>()) { | ||
n_frames = infer_context->GetNextSymName(); | ||
} else { | ||
n_frames = symbol::DimExpr((seq_length - frame_length) / hop_length + 1); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该可以直接使用DimExpr进行计算,不用区分是否是int64_t吧?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Enforce那里需要区分,这里不需要区分
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改~
Sorry to inform you that 8fce60d's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
PR Category
CINN
PR Types
Others
Description
Add InferSymbolicShapeInterface for (
frame, prune_gate_by_capacity
) op