CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[Typing][B-28] Add type annotations for python/paddle/distribution/uniform.py
#65660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
python/paddle/distribution/uniform.py
def __init__(self, low, high, name=None): | ||
def __init__( | ||
self, | ||
low: float | list | tuple | np.ndarray | Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注意下,我们不允许 list、tuple 这种泛型不写泛型参数的写法,这里 list
、tuple
、np.ndarray
都是泛型
对于 list
、tuple
,一般常用 Sequence[T]
来作为输入
对于 np.ndarray
,一般常用 npt.NDArray[Any]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
@@ -100,7 +106,12 @@ class Uniform(distribution.Distribution): | |||
[0.50000000]) | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否有公开的类属性可以标注下呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
就像 paddle.distribution.Beta 那样给 Uniform 的类属性加个 low 和 high 吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯~
low: float | ||
| Sequence[float] | ||
| npt.NDArray[np.float32 | np.float64] | ||
| Tensor, | ||
high: float | ||
| Sequence[float] | ||
| npt.NDArray[np.float32 | np.float64] | ||
| Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
low: float | |
| Sequence[float] | |
| npt.NDArray[np.float32 | np.float64] | |
| Tensor, | |
high: float | |
| Sequence[float] | |
| npt.NDArray[np.float32 | np.float64] | |
| Tensor, | |
low: ( | |
float | |
| Sequence[float] | |
| npt.NDArray[np.float32 | np.float64] | |
| Tensor | |
), | |
high: ( | |
float | |
| Sequence[float] | |
| npt.NDArray[np.float32 | np.float64] | |
| Tensor | |
), |
这里我们统一使用这种代码风格~
Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
low: float | Sequence[float] | npt.NDArray[np.float32 | np.float64] | Tensor | ||
high: float | Sequence[float] | npt.NDArray[ | ||
np.float32 | np.float64 | ||
] | Tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
刚验证了一下,Uniform
实例中的 low high
应该只能是 Tensor
,与下面 __init__
输入的 low high
应该不是一个东西,可以看作输入输出的关系:
In [6]: import paddle
...: from paddle.distribution import Uniform
...: paddle.seed(2023)
...:
...: # Without broadcasting, a single uniform distribution [3, 4]:
...: u1 = Uniform(low=3.0, high=4.0)
In [7]: u1.low
Out[7]:
Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
3.)
实例初始化的时候,会把不同类型的 low high
统一转换为 Tensor
,因此,实例中的 low high
也只能是 Tensor
~
帮忙看看是不是这样?
另外,Uniform
实际上还有几个属性,比如:
self.all_arg_is_float = False
self.batch_size_unknown = False
self.name = name if name is not None else 'Uniform'
self.dtype = 'float32'
只是,感觉当初设计成私有属性会比较好,所以加不加我觉得无所谓,如果遇到其他地方多注意一下就好 ~ 😶🌫️
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
实例初始化的时候,会把不同类型的
low high
统一转换为Tensor
,因此,实例中的low high
也只能是Tensor
~帮忙看看是不是这样?
嗯嗯,是这样子的。所以这里添加类属性的意义是想表示 所有的实例都会有这个属性,起一个提示(初始化之后的)类型的作用?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯,比如以下这段代码:
import paddle
import paddle.distribution
beta = paddle.distribution.Uniform(0.3, 0.4)
reveal_type(beta.low)
使用 mypy 进行检查:
> mypy --show-traceback --config-file=pyproject.toml test_tmp.py
在标注类属性之前,结果是:
test_tmp.py:5:13: note: Revealed type is "Any"
mypy 或者其他工具如果没有发现类中单独标注属性类型,则会通过
标注类属性之后,结果是:
test_tmp.py:5:13: note: Revealed type is "paddle.tensor.tensor.Tensor"
这样,后面的代码便以此推导的类型进行检查 ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
了解啦
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯嗯,是这样子的。所以这里添加类属性的意义是想表示 所有的实例都会有这个属性,起一个提示(初始化之后的)类型的作用?
嗯嗯,我这里补充下,这里提示的是实例属性的类型,而不是类本身属性的类型
这里标注的是 self.xxx = yyy
绑定后这个属性可能的类型
class Foo:
xxx: int | str
def __init__(self, x):
if x:
self.xxx = 1
else:
self.xxx = ""
当然静态检查工具可能可以分析出这个结果,但这需要「分析」的时间成本,而且也可能因为「实现」不同导致静态分析结果不一致,因此我们提供确定的类型信息会更合适些~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR Category
User Experience
PR Types
Improvements
Description