CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
【Infer Symbolic Shape BUAA No.76】Add mean_all op #66609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for InferSymbolicShape interface
@@ -576,6 +576,28 @@ bool MinOpInferSymbolicShape(pir::Operation *op, | |||
return MaxOpInferSymbolicShape(op, infer_context); | |||
} | |||
|
|||
bool MeanAllOpInferSymbolicShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
添加OpTest单测,coverage 流水线未覆盖
test/legacy_test/test_mean_op.py
Outdated
@@ -60,6 +60,7 @@ def init_dtype_type(self): | |||
|
|||
def test_check_output(self): | |||
self.check_output(check_pir=True) | |||
self.check_output(check_symbol_infer=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这行删掉,只需要调用一次check_output, 且check_pir 和 check_symbol_infer 都是默认开启的
test/legacy_test/test_mean_op.py
Outdated
@@ -43,7 +43,7 @@ def reduce_mean_wrapper(x, axis=0, keepdim=False, reduce_all=False): | |||
|
|||
class TestMeanOp(OpTest): | |||
def setUp(self): | |||
self.op_type = "mean" | |||
self.op_type = "mean_all" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是mean的单测不能修改,可以copy一个mean_all的单测
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
test/legacy_test/test_mean_all_op.py
Outdated
@@ -0,0 +1,730 @@ | |||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mean_all 和mean 在底层是一个算子,不用分两个文件,在test_mean_op.py 里添加几个mean_all 的单测就行了
* mean_all * test_op * nothing * test mean_all * test mean * test mean * add mean_all optest
PR Category
CINN
PR Types
improvements
Description
添加mean_all算子符号推导接口