API Improvement: fix paddle.median (usability improvement) #64444
Your PR was submitted successfully. Thank you for contributing to this open-source project!
python/paddle/tensor/stat.py
Outdated
@@ -463,6 +463,28 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
>>> print(median_indices)
Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True,
[1, 1, 1])

>>> # cases containing nan values
>>> x = paddle.to_tensor(np.array([[1,2,3,float('nan')],[1,2,3,4],[float('nan'),1,2,3]])
Isn't this x a float64 Tensor?
Yes. The intent here is to add an example where the input contains nan.
@@ -230,6 +289,47 @@ def test_index_odd_case(self):
np.testing.assert_allclose(out.numpy(), [4.0, 14.0, 24.0])
np.testing.assert_equal(index.numpy(), [4, 4, 4])

def test_nan(self):
Isn't the case that fails to run the int32/int64 one? How is that related to nan? Also, I don't see int32/int64 being tested here.
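int32/int64 are tested above, in test_median_static and test_median_dygraph.
int32/int64 failed because, when I added the min branch earlier, I placed it before the part that handles nan. Even with no nan in the input, an int32/int64 input would fail here: out_tensor has dtype int32/int64, while the sum that follows is cast to float64.
The current change lets the min and avg branches handle nan separately. avg keeps the previous logic: a float32 input gives a float32 output, and every other input gives a float64 output. For min, I changed the cast dtype here to x.dtype so that the input and output dtypes match, and I also added the corresponding handling for index when the index has to be returned.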
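The difference between the two modes described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the Paddle implementation: `median_1d` is a hypothetical helper that works on a flat list, and it models "the output keeps the input type" by returning an existing element in `min` mode.

```python
import math


def median_1d(values, mode="avg"):
    """Median of a non-empty 1-D sequence; any nan makes the result nan."""
    if mode not in ("avg", "min"):
        raise ValueError("mode must be 'avg' or 'min'")
    if not values:
        raise ValueError("input must not be empty")
    # A nan anywhere in the input propagates to the output.
    if any(isinstance(v, float) and math.isnan(v) for v in values):
        return math.nan
    ordered = sorted(values)
    lower = ordered[(len(ordered) - 1) // 2]
    upper = ordered[len(ordered) // 2]
    if mode == "min":
        # Return an existing element, so an int input gives an int output.
        return lower
    # 'avg' averages the two middle elements and always returns a float.
    return (lower + upper) / 2


print(median_1d([1, 2, 3, 4], mode="min"))          # 2
print(median_1d([1, 2, 3, 4], mode="avg"))          # 2.5
print(median_1d([1.0, math.nan, 3.0], mode="min"))  # nan
```

In this sketch the integer input survives `min` mode unchanged, while `avg` mode always promotes to a float, which mirrors the dtype behavior discussed in the comment.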
@@ -521,6 +543,11 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
),
The dtype setting on line 525 isn't really reasonable. Put it under the avg branch only, so that it doesn't affect the min branch.
@@ -538,12 +565,29 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None):
out_idx = paddle.slice(
idx, axes=[axis], starts=[kth], ends=[kth + 1]
)
# if contain nan on axis, return nan for that axis
out_tensor = out_tensor + paddle.sum(
This last astype(x.dtype) isn't needed, is it?
paddle.sum returns int64 when its input is int32; the final astype(x.dtype) is there for the int32 case.
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/sum_cn.html
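The cast-back pattern can be illustrated with NumPy as a stand-in. This is an analogy, not Paddle code: NumPy's `sum` may also widen a small integer dtype (the exact accumulator type is platform dependent), so the sketch only relies on the final `astype` restoring the input dtype.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)

# The reduction may return a wider integer type than the input.
summed = np.sum(x, axis=1, keepdims=True)

# Casting back guarantees the result matches the input dtype again.
restored = summed.astype(x.dtype)

print(summed.dtype, restored.dtype)
```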
axis=axis,
keepdim=True,
).astype(x.dtype)
if need_idx:
Let's not make big changes to the nan handling for now; keep the previous logic. The main goal is to accommodate the effect of dtype.
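OK. The `if need_idx` branch here handles the case where the input contains nan and the index has to be returned. Should it be removed? If it is removed, does that mean no index is returned when the input contains nan? torch's median currently does return an index when the input contains nan. I didn't consider this case when I added the min branch earlier, so I'd like to fill it in here.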
Then go ahead and add it.
@@ -164,6 +182,40 @@ def test_median_exception(self):
self.assertRaises(ValueError, paddle.median, x, 2, False, 'max')
self.assertRaises(ValueError, paddle.median, paddle.to_tensor([]))

def test_nan(self):
Please add a separate, dedicated test for int32/int64. There's no need to dig further into the nan issue for now; just keep the previous logic.
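A dedicated integer test of the kind requested here could be checked against a small NumPy reference. The sketch below is an assumption about what such a reference might look like: `median_min_reference` is a hypothetical helper, not part of Paddle or of this PR's test suite, and it assumes `min` mode picks the lower of the two middle elements.

```python
import numpy as np


def median_min_reference(x, axis):
    """Lower median along `axis`, keeping the input dtype."""
    ordered = np.sort(x, axis=axis)
    kth = (x.shape[axis] - 1) // 2  # index of the lower middle element
    return np.take(ordered, kth, axis=axis)


for dtype in (np.int32, np.int64):
    x = np.arange(12, dtype=dtype).reshape(3, 4)
    expected = median_min_reference(x, axis=1)
    # An integer input should give an integer output of the same dtype.
    print(dtype.__name__, expected.tolist(), expected.dtype)
```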
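Line 525 sets the dtype to fp32/fp64, which isn't really reasonable. Put it under the avg branch only. Please change this.
PR-CI-Static-Check reports an error in the example code; that needs to be fixed.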
index_along_axis = paddle.argsort(
x_all_zero, axis=axis, stable=True
)
nan_index = paddle.sum(
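If there are multiple nans, taking paddle.sum also seems problematic. With multiple nans, the index should be computed from the position of the first nan.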
Fixed.
Fixed.
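The concern about multiple nans can be reproduced with a small NumPy sketch. This is an illustration in NumPy rather than the Paddle code from the diff, and the variable names are made up for the example: summing the positions of all nans gives a wrong index as soon as a row has more than one nan, whereas `argmax` over the nan mask returns the first one.

```python
import numpy as np

x = np.array(
    [
        [1.0, np.nan, np.nan, 4.0],  # two nans, at positions 1 and 2
        [1.0, 2.0, 3.0, 4.0],        # no nan
        [np.nan, 1.0, 2.0, 3.0],     # one nan, at position 0
    ]
)
is_nan = np.isnan(x)
positions = np.arange(x.shape[1])

# Summing positions works only when a row has at most one nan.
summed = (positions * is_nan).sum(axis=1)

# argmax on a boolean mask returns the first True position (0 if none).
first_nan = np.argmax(is_nan, axis=1)

# Rows without nan also report 0, so a separate mask is still required.
has_nan = is_nan.any(axis=1)

print(summed.tolist())     # [3, 0, 0]
print(first_nan.tolist())  # [1, 0, 0]
print(has_nan.tolist())    # [True, False, True]
```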
index_along_axis * x_isnan, axis=axis, keepdim=True
)
nan_index_mask = paddle.sum(x_isnan, axis=axis, keepdim=True)
out_idx = (
This can be simplified:
out_idx = out_idx * paddle.logical_not(nan_index_mask) + nan_index
Fixed.
python/paddle/tensor/stat.py
Outdated
index_along_axis * x_isnan, axis=axis, keepdim=True
)
nan_index_mask = paddle.sum(x_isnan, axis=axis, keepdim=True)
out_idx = out_idx * paddle.logical_not(nan_index_mask) + nan_index
out_idx = out_idx * paddle.logical_not(nan_index_mask).astype('int64') + nan_index
Fixed.
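The suggested masking expression can be tried out in NumPy. This is a sketch with made-up values, not the Paddle code path: the arrays stand in for `out_idx`, `nan_index` and `nan_index_mask` from the diff, and the explicit `astype("int64")` is kept to mirror the reviewer's suggestion even though NumPy itself would accept a boolean operand.

```python
import numpy as np

# Hypothetical per-row values, shaped like keepdim=True outputs.
out_idx = np.array([[2], [1], [3]], dtype=np.int64)         # median index
nan_index = np.array([[1], [0], [0]], dtype=np.int64)       # first nan, 0 if none
nan_index_mask = np.array([[2], [0], [1]], dtype=np.int64)  # nan count per row

# 1 where the row has no nan, 0 where it has at least one.
keep = np.logical_not(nan_index_mask).astype("int64")

# Rows with nan take nan_index; all other rows keep out_idx.
combined = out_idx * keep + nan_index

print(combined.tolist())  # [[1], [1], [0]]
```

The expression relies on `nan_index` being 0 for rows without nan, so that adding it leaves those rows unchanged.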
LGTM
@NKNaN The corresponding Chinese documentation needs to be submitted as well.
PR Category
User Experience
PR Types
Bug fixes
Description
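Fix paddle.median not supporting non-floating-point input types in the min branch.
Because paddle.topk does not support the bool type, neither the avg nor the min branch can support bool input unless it is handled separately. Should support for the bool type be added?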