CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[Typing][A-86, A-92] Add type annotations for python/paddle/vision/{datasets/cifar.py,image.py}
#65386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@86kkd 这几个文件,可以先参考一下 https://github.com/cattidea/paddlepaddle-stubs/tree/main/paddle-stubs/vision/datasets ~
主要是需要增加类属性,如:
class Cifar10(Dataset):
mode: Any = ...
backend: Any = ...
data_file: Any = ...
transform: Any = ...
dtype: Any = ...
不过,这里不能写 Any
,具体怎么写还需要看一下 ~
另外,@SigureMo ,python/paddle/vision/transforms/transforms.py
是否可以加点东西:
from __future__ import annotations
import numpy as np
import numpy.typing as npt
from typing import overload, TypeVar, Any, Sequence, Callable, Tuple, Union
class Tensor: ...
class PILImage: ...
_DataT = TypeVar("_DataT", Tensor, PILImage, npt.NDArray[Any])
# ---------- transform ----------
_TransformTensor = Callable[[Tensor], Tensor]
_TransformTensors = Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]]
_TransformPILImage = Callable[[PILImage], PILImage]
_TransformPILImages = Callable[[Tuple[PILImage, ...]], Tuple[PILImage, ...]]
_TransformNDArray = Callable[[npt.NDArray[Any]], npt.NDArray[Any]]
_TransformNDArrays = Callable[
[Tuple[npt.NDArray[Any], ...]], Tuple[npt.NDArray[Any], ...]
]
_Transform = Union[
_TransformTensor,
_TransformTensors,
_TransformPILImage,
_TransformPILImages,
_TransformNDArray,
_TransformNDArrays,
]
# ---------- transform end ----------
class BaseTransform:
def __init__(self, keys=None):
self.keys = keys
@overload
def __call__(self, inputs: _DataT) -> _DataT: ...
@overload
def __call__(self, inputs: tuple[_DataT, ...]) -> tuple[_DataT, ...]: ...
def __call__(self, inputs):
if isinstance(inputs, (Tensor, PILImage, np.ndarray)):
return inputs
return tuple(inputs)
class Resize(BaseTransform):
def __init__(self, size, interpolation="bilinear", keys=None): ...
@overload
def __call__(self, inputs: _DataT) -> _DataT: ...
@overload
def __call__(self, inputs: tuple[_DataT, ...]) -> tuple[_DataT, ...]: ...
def __call__(self, inputs):
if isinstance(inputs, (Tensor, PILImage, np.ndarray)):
return inputs
return tuple(inputs)
class Compose:
transforms: Sequence[BaseTransform]
def __init__(self, transforms: Sequence[BaseTransform]) -> None:
self.transforms = transforms
@overload
def __call__(self, data: _DataT) -> _DataT: ...
@overload
def __call__(self, data: tuple[_DataT, ...]) -> tuple[_DataT, ...]: ...
def __call__(self, data):
if isinstance(data, (Tensor, PILImage, np.ndarray)):
return data
return tuple(data)
class Cifar10:
transform: _Transform
def __init__(self, transform: _Transform) -> None:
self.transform = transform
t1 = Resize(1)
cifar = Cifar10(transform=t1) # ok
t2 = Compose([Resize(1), Resize(2)])
cifar = Cifar10(transform=t2) # ok
def trans_pass(a: Tensor) -> Tensor:
return Tensor()
def trans_fail(a: Tensor) -> tuple[Tensor, Tensor]:
return Tensor(), Tensor()
cifar = Cifar10(transform=trans_pass) # ok
cifar = Cifar10(transform=trans_fail) # This should fail
添加以上 transform
内的东西 ~ 这样 datasets 里面的这几个地方,都可以用 _Transform
~
实际上,个人感觉,python/paddle/vision/transforms/transforms.py
在设计的时候,就应该有一个 Protocol
,而由于缺少这个 Protocol
,导致 BaseTransform
和 Compose
的 __call__
写法(输入参数名)都不一致 ~
backend=None, | ||
data_file: str | None = None, | ||
mode: str = 'train', | ||
transform: Callable[[(list | tuple)], BaseTransform] | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是个什么写法?
python/paddle/vision/datasets/cifar.py
python/paddle/vision/{datasets/cifar.py, image.py}
python/paddle/vision/{datasets/cifar.py, image.py}
python/paddle/vision/{datasets/cifar.py,image.py}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
…datasets/cifar.py,image.py}` (PaddlePaddle#65386) --------- Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
…datasets/cifar.py,image.py}` (PaddlePaddle#65386) --------- Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
PR Category
User Experience
PR Types
Improvements
Description
类型标注:
Related links
@SigureMo @megemini