paddle.linalg.matrix_rank: add atol and rtol attributes (usability improvement)
#66929
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
python/paddle/tensor/linalg.py
Outdated
```python
if use_atol_rtol:
    if rtol is None:
        rtol = full([], 0.0, x.dtype)
    if (atol is None) or (
```
Under use_atol_rtol, once rtol is None, atol can no longer be None, right? Placing use_default_tol here seems odd; rename it to use_default_atol? Or simply check atol == 0 in the kernel and treat 0 as use_default_atol. For the tol path, also unify the parameter name to use_default_atol.
It looks like it should be written like this?

```python
if rtol is None:
    rtol = full([], 0.0, x.dtype)
if atol is None:
    atol = full([], 0.0, x.dtype)
    use_default_atol = True
```
```python
if rtol is None:
    rtol = full([], 0.0, x.dtype)
    use_default_tol = True
if atol is None:
```
Shouldn't it be use_default_tol=True when atol is None?
Here PyTorch's design appears to be: when rtol is None, if atol is None or set to 0, the default tol value is used; if atol is set to a value greater than 0, then rtol is 0.
Equivalent to this:

| atol_input | rtol_input | atol | rtol |
|---|---|---|---|
| None or = 0.0 | None | 0.0 | max(m,n)*eps |
| not None, > 0.0 | None | not None, > 0.0 | 0.0 |
| None | not None | 0.0 | not None |
| not None | not None | not None | not None |
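The defaulting rules in the table above can be sketched as a small plain-Python helper. The name `resolve_tols` and its signature are illustrative only (not part of the Paddle or PyTorch API), and scalar tolerances are assumed:

```python
def resolve_tols(atol_input, rtol_input, m, n, eps):
    """Return the effective (atol, rtol) pair per the table above."""
    if rtol_input is None:
        if atol_input is None or atol_input == 0.0:
            # neither tolerance given: fall back to the relative default
            return 0.0, max(m, n) * eps
        # a positive atol alone disables the relative tolerance
        return atol_input, 0.0
    # rtol given explicitly: atol simply defaults to 0 when omitted
    return (0.0 if atol_input is None else atol_input), rtol_input
```

This covers scalar inputs only; Tensor-valued tolerances need the kernel-side handling debated in this thread.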
So matrix_rank_atol_rtol should now be changed to:

```cpp
if (use_default_tol) {
  // check whether atol is 0:
  // atol == 0: rtol takes the default, tol = max(m,n)*eps*sigma_1;
  //   (atol == 0 corresponds to atol_input = None or 0.0, rtol_input = None)
  // atol != 0: rtol is 0 at this point, so tol = max{atol, rtol*sigma_1} = atol;
  //   (corresponds to atol_input not None and > 0.0, rtol_input = None)
} else {
  tol = max{atol, rtol*sigma_1}
}
```
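A runnable sketch of that branch, with plain Python standing in for the C++ kernel. The function `rank_from_singular_values` and its argument names are illustrative only; `sigma` is the list of singular values:

```python
def rank_from_singular_values(sigma, atol, rtol, use_default_tol, m, n, eps):
    """Count singular values above the resolved tolerance."""
    sigma_1 = max(sigma)  # largest singular value
    if use_default_tol:
        if atol == 0.0:
            # atol_input was None or 0.0, rtol_input was None:
            # use the default relative tolerance
            tol = max(m, n) * eps * sigma_1
        else:
            # atol_input > 0.0, rtol_input was None: rtol is 0, so tol == atol
            tol = atol
    else:
        tol = max(atol, rtol * sigma_1)
    return sum(1 for s in sigma if s > tol)
```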
LGTM
Sorry to inform you that e4ccdf1's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
@NKNaN Please fix the CI.
```python
if rtol is None:
    rtol = full([], 0.0, x.dtype)
    use_default_tol = True
if atol is None:
```
I think what this passage means is: if atol is set to a value greater than 0, then rtol defaults to 0; if atol is not set (so atol defaults to 0), then rtol defaults to epsilon() * max(rows, cols). So rtol's default should be determined by atol; rtol's default is not decided on its own, but by atol.
- both atol and rtol are not None: use the actual values
- atol is not None, rtol is None (rtol then defaults to 0)
- atol is None (defaults to 0), rtol is None (rtol then defaults to eps*max)
- atol is None (defaults to 0), rtol is not None

So use_default_tol decides between cases 2 and 3, i.e. it marks whether atol is None.
The tol handling below follows the same logic; tol is an alias of atol.
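Read literally, this comment's proposal for the Python side would look roughly like the sketch below (illustrative only; `resolve` is not a real Paddle helper, and scalar tolerances are assumed):

```python
def resolve(atol, rtol):
    """use_default_tol marks whether atol was omitted (None)."""
    use_default_tol = False
    if rtol is None:
        rtol = 0.0
    if atol is None:
        atol = 0.0
        use_default_tol = True  # kernel would then pick rtol = eps*max
    return atol, rtol, use_default_tol
```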
Should this passage go into the docstring notes?
> The tol handling below follows the same logic; tol is an alias of atol.

If we want the semantics to stay consistent throughout, then yes, putting use_default_tol=True in the atol-is-None branch is more appropriate. The way I wrote it effectively lets both cases 2 and 3 enter the kernel's use_default_tol branch, where whether rtol takes eps*max or 0 is then decided by whether atol is 0; atol being 0 can mean either that atol took its default or that the user explicitly set atol to 0.
Then I'll move use_default_tol=True into the atol-is-None branch and revise.
> Should this passage go into the docstring notes?

Item 4 of the Notes is this passage.
I think the input cases for atol and rtol are as follows:
- both atol and rtol are not None: use the actual values
- atol is not None and > 0, rtol is None (rtol then defaults to 0)
- atol is not None and == 0, rtol is None (rtol then defaults to eps*max)
- atol is None (defaults to 0), rtol is None (rtol then defaults to eps*max)
- atol is None (defaults to 0), rtol is not None

I feel that putting use_default_tol=True in the atol-is-None branch is problematic. First, use_default_tol=True can only go under case 4; cases 4 and 5 cannot be merged, because checking in the kernel whether rtol is 0 cannot tell whether the rtol input was None or literally 0, i.e. it cannot distinguish the following two cases:
```python
x = torch.ones(size=[3, 4, 5, 5])
atol = None
rtol = None
res = torch.linalg.matrix_rank(x, atol, rtol, hermitian=True)
# tensor([[1, 1, 1, 1],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])
# here the kernel uses rtol = max(m,n)*eps
```

```python
x = torch.ones(size=[3, 4, 5, 5])
atol = None
rtol = 0.0
res = torch.linalg.matrix_rank(x, atol, rtol, hermitian=True)
# tensor([[5, 5, 5, 5],
#         [5, 5, 5, 5],
#         [5, 5, 5, 5]])
# here the kernel uses rtol = 0.0
```

In theory the rank should be 1 in both cases, but since this API is designed to count the eigenvalues greater than max(atol, rtol*sigma_1), and floating-point precision issues exist, the results are all 5.
Nor can cases 3 and 4 be merged on the Python side by checking atol == 0 there, because atol may be a Tensor some of whose entries are 0. So use_default_tol=True can only go under case 4, while cases 1, 2, 3 and 5 enter the use_default_tol=False branch in the kernel. The computation for case 3 can be viewed as

```python
rtol = paddle.where(atol == 0, eps * max(m, n), rtol)
```

so case 3 cannot be grouped with 1, 2 and 5 either, because, for the same reason, checking in the kernel whether atol is 0 also conflates an atol input of None on the Python side with an input that is literally 0. For example, the following three cases would be indistinguishable:
```python
x = torch.ones(size=[3, 4, 5, 5])
atol = 0.0
rtol = None
res = torch.linalg.matrix_rank(x, atol, rtol, hermitian=True)
# tensor([[1, 1, 1, 1],
#         [1, 1, 1, 1],
#         [1, 1, 1, 1]])
# atol=0, rtol=None falls under case 3
# here the kernel uses rtol = max(m,n)*eps
```

```python
x = torch.ones(size=[3, 4, 5, 5])
atol = None
rtol = 0.0
res = torch.linalg.matrix_rank(x, atol, rtol, hermitian=True)
# tensor([[5, 5, 5, 5],
#         [5, 5, 5, 5],
#         [5, 5, 5, 5]])
# atol=None, rtol=0 falls under case 5
# here the kernel uses rtol = 0.0
```

```python
x = torch.ones(size=[3, 4, 5, 5])
atol = 0.0
rtol = 0.0
res = torch.linalg.matrix_rank(x, atol, rtol, hermitian=True)
# tensor([[5, 5, 5, 5],
#         [5, 5, 5, 5],
#         [5, 5, 5, 5]])
# atol=0.0, rtol=0.0 falls under case 1
# here the kernel uses rtol = 0.0
```
So in the kernel, whether rtol defaults to 0 or eps*max is indeed decided by atol, but that must be restricted to the case where rtol=None on the Python side; in other words, cases 2, 3 and 4 need to be merged on the Python side. That is, when rtol=None the kernel is steered into a single branch, inside which atol's value then decides rtol's default.
Second, as for the semantics of use_default_tol: since it is used inside the kernel, I still think this tol should refer to the final result tol = max(atol, rtol*sigma_1) computed in the kernel, so letting this variable steer the kernel into the 2/3/4 branch also makes sense, because only when rtol=None can the final tol possibly be the default_tol, i.e. eps*max*sigma_1.
The main question now is whether 0 and None count as the same case. How about keeping it simple: for atol, 0 is None and None is 0:
- If rtol is None
  - atol > 0: set rtol to 0
  - atol = 0/None: set rtol to eps*max
- If rtol is not None
  - atol > 0: keep the setting as-is
  - atol = 0/None
    - if rtol is set to a positive number: keep the setting as-is
    - if rtol is set to 0, we get the double-zero case, which is debatable: should rtol be set to eps*max, or should the double-zero interpretation apply?

There is only one debatable case; once it is settled, it will be clear whether flags like use_default are still needed.
When rtol is set to 0 and atol is also set to 0, the computation follows the double-zero case:

```python
x = torch.ones(size=[3, 4, 5, 5])
atol = 0.0
rtol = 0.0
res = torch.linalg.matrix_rank(x, atol, rtol, hermitian=True)
# tensor([[5, 5, 5, 5],
#         [5, 5, 5, 5],
#         [5, 5, 5, 5]])
# atol=0.0, rtol=0.0 falls under case 1
# here the kernel uses rtol = 0.0
```
> The main question now is whether 0 and None count as the same case. How about keeping it simple: for atol, 0 is None and None is 0:
> - If rtol is None
>   - atol > 0: set rtol to 0
>   - atol = 0/None: set rtol to eps*max
> - If rtol is not None
>   - atol > 0: keep the setting as-is
>   - atol = 0/None
>     - if rtol is set to a positive number: keep the setting as-is
>     - if rtol is set to 0, we get the double-zero case, which is debatable: should rtol be set to eps*max, or should the double-zero interpretation apply?
>
> There is only one debatable case; once it is settled, it will be clear whether flags like use_default are still needed.

Then it seems this can finally be simplified to just 2 cases, with no need for use_default-style flags. First, when atol is None, set it directly to 0; then check rtol:
- If rtol is None, first set it to 0, then check in C++ whether atol is also 0; if so, set rtol to eps*max.
- If rtol is not None, no check or modification is needed.

This should also line up with torch, and should be the most simplified implementation: no flags needed, and the API/yaml/kernel can all be simplified one more round.
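The simplified two-case scheme could be sketched like this in plain Python (illustrative names only; in the real implementation the kernel is C++ and rtol would arrive as an optional tensor input rather than `None`):

```python
def python_side(atol, rtol):
    # atol: None is treated the same as 0.0; rtol is forwarded as optional
    return (0.0 if atol is None else atol), rtol

def kernel_side(sigma, atol, rtol, m, n, eps):
    sigma_1 = max(sigma)  # largest singular value
    if rtol is None:
        # rtol omitted: default to eps*max(m, n) only when atol is 0
        rtol = max(m, n) * eps if atol == 0.0 else 0.0
    tol = max(atol, rtol * sigma_1)
    return sum(1 for s in sigma if s > tol)
```

Note how the explicit double-zero input (atol=0.0, rtol=0.0) takes the second path and keeps tol=0, while (None, None) takes the defaulting path, matching the torch behavior shown above.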
Api-Benchmark reports a performance regression. Would it be better to keep the tol-only branch written as before, unchanged?
If updating the old branch to the new one has an impact, the old branch can be kept unchanged.
Sorry to inform you that 3248f18's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
@NKNaN Take a look at this approach and see how it works:
Yes, it can match torch; I'll simplify the kernel further.
LGTM. Let's implement this version: there is no backward-incompatibility impact, the new kernel design is cleaner, and it supports using both the atol and rtol parameters.
PR Category
User Experience
PR Types
Improvements
Description
Purpose: add `atol` and `rtol` parameters.
Changes: add a new matrix_rank_atol_rtol op, with new inputs atol and rtol.
#66070 (comment)