CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[Typing][B-01] Add type annotations for python/paddle/io/reader.py
#65587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
python/paddle/io/reader.py
Outdated
@@ -382,25 +397,37 @@ class DataLoader: | |||
please see :code:`paddle.io.IterableDataset` | |||
""" | |||
|
|||
return_list: bool | |||
collate_fn: default_collate_fn | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要是个 Callable ~ 可以参考以下 default_collate_fn 的输入输出类型,貌似还挺复杂的 ... ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
恩,我试着改改。
python/paddle/io/reader.py
Outdated
collate_fn: default_collate_fn | None | ||
use_buffer_reader: bool | ||
prefetch_factor: int | ||
worker_init_fn: Callable[..., Any] | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考了一下测试用例:
class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase):
def test_main(self):
place = base.CPUPlace()
with base.dygraph.guard(place):
dataset = RangeIterableDataset(0, 10)
def worker_spliter(worker_id):
worker_info = get_worker_info()
dataset = worker_info.dataset
start = dataset.start
end = dataset.end
num_per_worker = int(
math.ceil((end - start) / float(worker_info.num_workers))
)
worker_id = worker_info.id
dataset.start = start + worker_id * num_per_worker
dataset.end = min(dataset.start + num_per_worker, end)
dataloader = DataLoader(
dataset,
places=place,
num_workers=1,
batch_size=1,
drop_last=True,
worker_init_fn=worker_spliter,
)
rets = []
for d in dataloader:
rets.append(d.numpy()[0][0])
assert tuple(sorted(rets)) == tuple(range(0, 10))
这里貌似可以是 Callable[[int], None]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
python/paddle/io/reader.py
Outdated
prefetch_factor: int | ||
worker_init_fn: Callable[..., Any] | None | ||
dataset: Dataset | ||
feed_list: list[Tensor] | tuple[Tensor] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用 Sequence 可否?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到,谢谢
python/paddle/io/reader.py
Outdated
worker_init_fn: Callable[..., Any] | None | ||
dataset: Dataset | ||
feed_list: list[Tensor] | tuple[Tensor] | ||
places: list[Place] | tuple[Place] | list[str] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用 paddle._typing 的 PlaceLike ? Sequence[PlaceLike]
Paddle 的 Place
与 CPUPlace
没有继承关系,所以不能直接用 Place
~
p.s. 我也是最近才发现的 🤣
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到,谢谢
python/paddle/io/reader.py
Outdated
import numpy.typing as npt | ||
|
||
from paddle import Tensor | ||
from python.paddle._typing.device_like import PlaceLike |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from python.paddle._typing.device_like import PlaceLike | |
from python.paddle._typing import PlaceLike |
python/paddle/io/reader.py
Outdated
@@ -382,25 +404,37 @@ class DataLoader: | |||
please see :code:`paddle.io.IterableDataset` | |||
""" | |||
|
|||
return_list: bool | |||
collate_fn: Callable[[_Collate_Fn_State], None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个确实有点麻烦,我试了一下:
from __future__ import annotations
import numbers
import numpy as np
import numpy.typing as npt
from typing import Protocol, overload, Sequence, Mapping, Any, TYPE_CHECKING, AnyStr, TypeVar
import paddle
if TYPE_CHECKING:
from paddle import Tensor
KT = TypeVar('KT')
VT = TypeVar('VT')
@overload
def default_collate_fn(batch: Sequence[npt.NDArray[Any]] | Sequence[numbers.Number]) -> npt.NDArray[Any]: ...
@overload
def default_collate_fn(batch: Sequence[Tensor]) -> Tensor: ...
@overload
def default_collate_fn(batch: Sequence[AnyStr]) -> AnyStr: ...
@overload
def default_collate_fn(batch: Sequence[Mapping[KT, VT]]) -> Mapping[KT, VT]: ...
@overload
def default_collate_fn(batch: Sequence[Sequence[VT]]) -> Sequence[VT]: ...
def default_collate_fn(batch):
sample = batch[0]
if isinstance(sample, np.ndarray):
batch = np.stack(batch, axis=0)
return batch
elif isinstance(sample, paddle.Tensor):
return paddle.stack(batch, axis=0)
elif isinstance(sample, numbers.Number):
batch = np.array(batch)
return batch
elif isinstance(sample, (str, bytes)):
return batch
elif isinstance(sample, Mapping):
return {
key: default_collate_fn([d[key] for d in batch]) for key in sample
}
elif isinstance(sample, Sequence):
sample_fields_num = len(sample)
if not all(len(sample) == sample_fields_num for sample in iter(batch)):
raise RuntimeError(
"fields number not same among samples in a batch"
)
return [default_collate_fn(fields) for fields in zip(*batch)]
raise TypeError(
"batch data con only contains: tensor, numpy.ndarray, "
f"dict, list, number, but got {type(sample)}"
)
class _Collate_Fn(Protocol):
@overload
def __call__(self, batch: Sequence[npt.NDArray[Any]] | Sequence[numbers.Number]) -> npt.NDArray[Any]: ...
@overload
def __call__(self, batch: Sequence[Tensor]) -> Tensor: ...
@overload
def __call__(self, batch: Sequence[AnyStr]) -> AnyStr: ...
@overload
def __call__(self, batch: Sequence[Mapping[KT, VT]]) -> Mapping[KT, VT]: ...
@overload
def __call__(self, batch: Sequence[Sequence[VT]]) -> Sequence[VT]: ...
fn: _Collate_Fn = default_collate_fn
@SigureMo 帮忙看看这个 _Collate_Fn
的 Protocol 看看是否可行?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是把default_collate_fn在这里又重载了?得这么重复么。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目测没啥问题,检查如果没问题就没问题,只有一些风格的小问题
_Collate_Fn
->_CollateFn
KT
->_K
,VT
->_V
,K 和 V 和 T 一样是常用的泛型变量
关于 _CollateFn
类型要用 Protocol 重写一遍的问题,没有解决方案,前段时间刚吐槽过一次,但就是得写两遍重复的代码(paddle.jit.save
那边就是这样的)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是把default_collate_fn在这里又重载了?得这么重复么。
抱歉,刚看到这里有留言 ... ...
我这里是举例,根据 default_collate_fn
可以推导出 _CollateFn
的写法 ~
只保留 _CollateFn
就可以了 ~ 不需要把 default_collate_fn
再写一边 ~~~ 🤣🤣🤣
python/paddle/io/reader.py
Outdated
worker_init_fn: Callable[[int], None] | ||
dataset: Dataset | ||
feed_list: Sequence[Tensor] | None | ||
places: Sequence[PlaceLike] | list[str] | None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
places: Sequence[PlaceLike] | list[str] | None | |
places: Sequence[PlaceLike] | None |
应该可以不需要 list[str]
了吧 ~ PlaceLike 包括 str 了 ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
另外,DataLoader
的
def __len__(self):
def __iter__(self):
def __call__(self):
需要也标注一下 ~ 不然在 for 循环等地方可能无法确定类型 ~
python/paddle/io/reader.py
Outdated
|
||
from paddle import Tensor | ||
from python.paddle._typing.device_like import PlaceLike | ||
from python.paddle.io.dataloader.dataset import Dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
哪来的 python.
?
python/paddle/io/reader.py
Outdated
@@ -382,25 +404,37 @@ class DataLoader: | |||
please see :code:`paddle.io.IterableDataset` | |||
""" | |||
|
|||
return_list: bool | |||
collate_fn: Callable[[_Collate_Fn_State], None] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目测没啥问题,检查如果没问题就没问题,只有一些风格的小问题
_Collate_Fn
->_CollateFn
KT
->_K
,VT
->_V
,K 和 V 和 T 一样是常用的泛型变量
关于 _CollateFn
类型要用 Protocol 重写一遍的问题,没有解决方案,前段时间刚吐槽过一次,但就是得写两遍重复的代码(paddle.jit.save
那边就是这样的)
python/paddle/io/reader.py
Outdated
batch_size: int = 1, | ||
shuffle: bool = False, | ||
drop_last: bool = False, | ||
collate_fn: Callable[[_CollateFn], None] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(这里是不是 _CollateFn | None
呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
collate_fn: Callable[[_CollateFn], None] | None = None,
?
collate_fn: _CollateFn | None = None
python/paddle/io/reader.py
Outdated
@overload | ||
def default_collate_fn( | ||
batch: Sequence[npt.NDArray[Any]] | Sequence[numbers.Number], | ||
) -> npt.NDArray[Any]: | ||
... | ||
|
||
|
||
@overload | ||
def default_collate_fn(batch: Sequence[Tensor]) -> Tensor: | ||
... | ||
|
||
|
||
@overload | ||
def default_collate_fn(batch: Sequence[AnyStr]) -> AnyStr: | ||
... | ||
|
||
|
||
@overload | ||
def default_collate_fn(batch: Sequence[Mapping[_K, _V]]) -> Mapping[_K, _V]: | ||
... | ||
|
||
|
||
@overload | ||
def default_collate_fn(batch: Sequence[Sequence[_V]]) -> Sequence[_V]: | ||
... | ||
|
||
|
||
def default_collate_fn(batch): | ||
sample = batch[0] | ||
if isinstance(sample, np.ndarray): | ||
batch = np.stack(batch, axis=0) | ||
return batch | ||
elif isinstance(sample, paddle.Tensor): | ||
return paddle.stack(batch, axis=0) | ||
elif isinstance(sample, numbers.Number): | ||
batch = np.array(batch) | ||
return batch | ||
elif isinstance(sample, (str, bytes)): | ||
return batch | ||
elif isinstance(sample, Mapping): | ||
return { | ||
key: default_collate_fn([d[key] for d in batch]) for key in sample | ||
} | ||
elif isinstance(sample, Sequence): | ||
sample_fields_num = len(sample) | ||
if not all(len(sample) == sample_fields_num for sample in iter(batch)): | ||
raise RuntimeError( | ||
"fields number not same among samples in a batch" | ||
) | ||
return [default_collate_fn(fields) for fields in zip(*batch)] | ||
|
||
raise TypeError( | ||
"batch data con only contains: tensor, numpy.ndarray, " | ||
f"dict, list, number, but got {type(sample)}" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不用这里再写一边,如果有需要,在这个函数的文件那里写就行 ~ 🤣
Update
这个函数在 python/paddle/io/dataloader/collate.py
~ 不是公开 API ,先不标注了 ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
python/paddle/io/reader.py
Outdated
... | ||
|
||
|
||
fn: _CollateFn = default_collate_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个也不用 ~ 保留上面的 _CollateFn
就行 ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
其他的暂时没发现啥问题了 ~ 辛苦 ~ 🤟🤟🤟
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
…addlePaddle#65587) --------- Co-authored-by: SigureMo <sigure.qaq@gmail.com>
PR Category
User Experience
PR Types
Improvements
Description
B-01