CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
API improvement paddle.nanmedian 易用性提升 #62624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
❌ The PR is not created using PR's template. You can refer to this Demo. |
out->set_dims(common::make_ddim(out_dim)); | ||
auto median_dim = out_dim; | ||
median_dim.push_back(2); | ||
median_index->set_dtype(DataType::INT64); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
index的shape是不是不对,如果是取avg用法,则index的size是out的两倍,如果是取min用法,则index的shape应与out一致
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改为avg模式下index的size是out的两倍,min模式下与out一致
python/paddle/tensor/stat.py
Outdated
Tensor, results of median along ``axis`` of ``x``. The output dtype is the same as `x`. | ||
((Tensor, Tensor), optional) | ||
If ``mode`` == 'avg', the result will be the tensor of median values; | ||
If ``mode`` == 'min' and ``axis`` is None or tuple, the result will be the tensor of median values; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
分为两种:
- axis为int
- axis不为int(为None或tuple)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
python/paddle/tensor/stat.py
Outdated
((Tensor, Tensor), optional) | ||
If ``mode`` == 'avg', the result will be the tensor of median values; | ||
If ``mode`` == 'min' and ``axis`` is None or tuple, the result will be the tensor of median values; | ||
If ``mode`` == 'min' and ``axis`` is not None, the result will be a tuple of two tensors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tuple情况下,也是返回1个吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,tuple,list,None都返回1个
python/paddle/tensor/stat.py
Outdated
If ``mode`` == 'min' and ``axis`` is not None, the result will be a tuple of two tensors | ||
containing median values and their indices. | ||
|
||
When ``mode`` == 'avg', if data type of ``x`` is float64, data type of median values will be float64, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是和输入x的类型一致吧,和paddle.median不太一样
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是和输入x的类型一致,已修改
python/paddle/tensor/stat.py
Outdated
name (str, optional): Name for the operation (optional, default is None). | ||
For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor, results of median along ``axis`` of ``x``. The output dtype is the same as `x`. | ||
((Tensor, Tensor), optional) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些文字要写 nanmedian index value,不要照抄 paddle.median
的文档
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
python/paddle/tensor/stat.py
Outdated
|
||
When ``mode`` == 'avg', if data type of ``x`` is float64, data type of median values will be float64, | ||
otherwise data type of median values will be float32. | ||
When ``mode`` == 'min', the data type of median values will be the same as ``x``. The data type of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要照抄 paddle.median
文档
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
python/paddle/tensor/stat.py
Outdated
>>> print(y5_index.numpy()) | ||
[1 1 1] | ||
|
||
>>> y6, y6_index = x.nanmedian(1, mode='min') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
举一个axis为None的例子,此时无index
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已增加
python/paddle/tensor/stat.py
Outdated
attrs=attrs, | ||
) | ||
indices.stop_gradient = True | ||
if mode == 'min' and need_index: | ||
return out, indices[..., 0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
能否在infermeta就对index设置一个与out一致的shape,不用浪费一半的操作,mode='avg'时的index其实没用
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改infermeta和kernel中对应处理的地方
int64_t pre_dim) { | ||
CUDA_KERNEL_LOOP(index, pre_dim) { | ||
int64_t offset = index * stride; | ||
printf("index: %d\n", index); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
打印信息可以删掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
python/paddle/tensor/stat.py
Outdated
If ``mode`` == 'min' and ``axis`` is int, the result will be a tuple of two tensors | ||
containing nanmedian values and their indices. | ||
|
||
When ``mode`` == 'avg', the output dtype is the same as `x`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data_type应该不用额外强调,out与x一致,index是int64也是一般常识
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,那这里就和修改之前一样直接写 nanmedian values 的 data_type 与 x 一致了
python/paddle/tensor/stat.py
Outdated
When ``mode`` == 'min', the data type of nanmedian values will be the same as ``x``. | ||
If indices are retured, the data type will be int64. | ||
|
||
The data type of nanmedian values is the same as ``x``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个直接删掉也可以吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
python/paddle/tensor/stat.py
Outdated
If ``mode`` == 'avg', the result will be the tensor of nanmedian values; | ||
If ``mode`` == 'min' and ``axis`` is not int (None or list or tuple), the result will | ||
be the tensor of nanmedian values; | ||
If ``mode`` == 'min' and ``axis`` is int, the result will be a tuple of two tensors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这样写吧:
If mode
== 'min' and axis
is int, the result will be a tuple of two tensors (nanmedian value and nanmedian index). Otherwise, only nanmedian value will be returned.
看起来简介易懂点
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@NKNaN 中文文档也同时修改一下 |
PR types
New features
PR changes
OPs
Description
增加参数
mode
:默认值 'avg' ,另外可取 'min' 。当所需要计算的 tensor 在axis
轴上有偶数个元素时, 'avg' 表示计算结果为中间两个数的算术平均值;'min' 则为二者的最小值。返回值:当
mode = 'min'
且axis
不为 None 或 tuple 时,返回值为 (median_values, median_indices) ;其他情况下返回值为 median_values 的 tensor 。参考