CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
在Upsample和interpolate函数中加入recompute_scale_factor参数 #71997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
modified: python/paddle/nn/functional/common.py modified: python/paddle/nn/layer/common.py new file: test/legacy_test/test_interp_recompute_scale_factor.py
modified: test/legacy_test/test_interp_recompute_scale_factor.py
你的PR提交成功,感谢你对开源项目的贡献! |
modified: python/paddle/nn/functional/common.py modified: python/paddle/nn/layer/common.py
modified: test/legacy_test/CMakeLists.txt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paconvert也加下测试,在paconvert/tests目录,这个会和pytorch的运行结果作对比
@@ -493,6 +500,15 @@ def interpolate( | |||
"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear" | |||
) | |||
|
|||
if ( | |||
recompute_scale_factor is not None | |||
and recompute_scale_factor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这几个条件语句可以优化下写法,使其更易读一些
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到,已修改。
modified: test/legacy_test/test_interp_recompute_scale_factor.py
modified: python/paddle/nn/functional/common.py
@@ -493,6 +500,11 @@ def interpolate( | |||
"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear" | |||
) | |||
|
|||
if recompute_scale_factor and size is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个放到下面的
if out_shape is not None:
分支下吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到,已修改。
@@ -645,7 +657,7 @@ def _is_list_or_tuple_(data): | |||
attrs['out_h'] = out_shape[1] | |||
attrs['out_w'] = out_shape[2] | |||
|
|||
else: | |||
elif scale is not None and recompute_scale_factor is not True: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
分支逻辑梳理一下,是否是:
if out_shape is not None:
pass
elif scale is not None:
if recompute_scale_factor:
pass
else:
pass
else:
raise error
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到,已按此修改。
modified: python/paddle/nn/functional/common.py
@@ -550,8 +562,6 @@ def _is_list_or_tuple_(data): | |||
|
|||
out_shape = size | |||
scale = scale_factor | |||
if out_shape is not None and scale is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个为什么删除?
@@ -493,6 +500,11 @@ def interpolate( | |||
"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear" | |||
) | |||
|
|||
if recompute_scale_factor and size is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个并没有修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
抱歉,这个上面的看错了,已修改。
modified: python/paddle/nn/functional/common.py
test/legacy_test/CMakeLists.txt
Outdated
@@ -833,6 +833,7 @@ set_tests_properties(test_multiprocess_dataloader_iterable_dataset_static | |||
set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120) | |||
set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120) | |||
set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120) | |||
set_tests_properties(test_interp_recompute_scale_factor PROPERTIES TIMEOUT 120) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个耗时比较长,有办法精简下case缩短下时间不?
modified: test/legacy_test/CMakeLists.txt modified: test/legacy_test/test_interp_recompute_scale_factor.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@Qin-sx 中文文档也同步修改
您好,中文文档应该已经修改了 |
Sorry to inform you that 44860c7's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
请merge下develop,重新跑下CI |
PR Category
User Experience
PR Types
Improvements
Description
Upsample参数实际调用了interpolate参数。
参考Pytorch,recompute_scale_factor为True时计算output size,然后重置scale参数。
API文档修改:PaddlePaddle/docs#7205
PaConvert修改:PaddlePaddle/PaConvert#567