CARVIEW |
Navigation Menu
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[Typing][B-34] Add type annotations for python/paddle/amp/auto_cast.py
#66119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件有点麻烦,辛苦再确认一下 ~ 🤟🤟🤟
python/paddle/amp/auto_cast.py
Outdated
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( | ||
DygraphShardingOptimizer, | ||
DygraphShardingOptimizerV2, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要在这里导入了,按照原逻辑来吧 ~
python/paddle/amp/auto_cast.py
Outdated
from typing import Generator | ||
|
||
from paddle import Tensor, dtype | ||
from python.paddle.nn.layer.layers import Layer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from python.paddle.nn.layer.layers import Layer | |
from paddle.nn import Layer |
python/paddle/amp/auto_cast.py
Outdated
): | ||
custom_white_list: _CustomList, | ||
custom_black_list: _CustomList, | ||
level: AMP_LEVEL = 'O1', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
level: AMP_LEVEL = 'O1', | |
level: str = 'O1', |
AMP_LEVEL
有点像枚举 ~
python/paddle/amp/auto_cast.py
Outdated
custom_white_list: _CustomList, | ||
custom_black_list: _CustomList, | ||
level: AMP_LEVEL = 'O1', | ||
dtype: dtype = 'float16', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dtype: dtype = 'float16', | |
dtype: str = 'float16', |
这里是作为 white_list() 的 key,应该只能是 str ~
python/paddle/amp/auto_cast.py
Outdated
""" | ||
Judge whether current custom device support bfloat16 amp. | ||
""" | ||
place = _current_expected_place() | ||
return place.get_device_type() == 'npu' | ||
|
||
|
||
def need_keep_fp32(layer, dtype): | ||
def need_keep_fp32(layer: _LayerList, dtype: dtype) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def need_keep_fp32(layer: _LayerList, dtype: dtype) -> bool: | |
def need_keep_fp32(layer: _LayerList, dtype: str) -> bool: |
后面
elif (layer._dtype == 'float16') or (
(dtype == 'float16')
做字符串比对,应该只能是 str ~
python/paddle/amp/auto_cast.py
Outdated
self._save_dtype = save_dtype | ||
|
||
def __call__(self, state_dict): | ||
def __call__(self, state_dict: dict[Any, Any]) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def __call__(self, state_dict: dict[Any, Any]) -> None: | |
def __call__(self, state_dict: _StateDict) -> None: |
从 layer.register_state_dict_hook
中标注的内容
def register_state_dict_hook(
self, hook: _StateDictHook
) -> HookRemoveHelper:
...
_StateDict = Union[Dict[str, Tensor], typing.OrderedDict[str, Tensor]]
_StateDictHook = Callable[[_StateDict], None]
这里可以把 paddle/nn/layer/layers.py
中的 _StateDict
导进来直接用 ~~~
python/paddle/amp/auto_cast.py
Outdated
_OptimizerBase: TypeAlias = Union[ | ||
paddle.optimizer.Optimizer, | ||
DygraphShardingOptimizer, | ||
DygraphShardingOptimizerV2, | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_OptimizerBase: TypeAlias = Union[ | |
paddle.optimizer.Optimizer, | |
DygraphShardingOptimizer, | |
DygraphShardingOptimizerV2, | |
] | |
class _OptimizerLike(Protocol): | |
def minimize( | |
self, | |
loss: Tensor, | |
startup_program: Program, | |
parameters: list[Tensor], | |
no_grad_set: set[Tensor], | |
) -> tuple[list[Operator], list[tuple[Tensor, Tensor]]]: ... | |
def step(self) -> None: ... | |
def set_state_dict(self, state_dict: dict[str, Tensor]) -> None: ... | |
def clear_grad(self, set_to_zero: bool) -> None: ... |
这里定义一个 Protocol ,然后后面都用这个就行 ~
最好的办法是 Optimizer 和这里的 DygraphShardingOptimizer 都有一个公共父类,既然没有,我们定义一个 Protocol 暂时解决一下吧 ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@enkilee _OptimizerLike
改了一下,减少限制,可以试一下 ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
收到
python/paddle/amp/auto_cast.py
Outdated
def amp_decorate( | ||
models, | ||
optimizers=None, | ||
level='O1', | ||
dtype='float16', | ||
master_weight=None, | ||
save_dtype=None, | ||
master_grad=False, | ||
excluded_layers=None, | ||
): | ||
models: Layer | list[Layer], | ||
optimizers: _OptimizerList | None = None, | ||
level: AMP_LEVEL = 'O1', | ||
dtype: dtype = 'float16', | ||
master_weight: bool | None = None, | ||
save_dtype: dtype | None = None, | ||
master_grad: bool = False, | ||
excluded_layers: Layer | list[Layer] | None = None, | ||
) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以用 overload
_ModelsT = TypeVar("_ModelsT", Layer, List[Layer])
_OptimizersT = TypeVar("_OptimizersT", _OptimizerBase, List[_OptimizerBase])
@overload
def amp_decorate(
models: _ModelsT,
optimizers: _OptimizersT = ...,
level: str = ...,
dtype: str = ...,
master_weight: bool | None = ...,
save_dtype: str | None = ...,
master_grad: bool = ...,
excluded_layers: Layer | list[Layer] | list[type[Layer]] | None= ...,
) -> tuple[_ModelsT, _OptimizersT]: ...
@overload
def amp_decorate(
models: _ModelsT,
optimizers: Literal[None] = ...,
level: str = ...,
dtype: str = ...,
master_weight: bool | None = ...,
save_dtype: str | None = ...,
master_grad: bool = ...,
excluded_layers: Layer | list[Layer] | list[type[Layer]] | None= ...,
) -> _ModelsT: ...
@dygraph_only
def amp_decorate(
models: _ModelsT,
optimizers: _OptimizersT | None = None,
level: str = 'O1',
dtype: str = 'float16',
master_weight: bool | None = None,
save_dtype: str | None = None,
master_grad: bool = False,
excluded_layers: Layer | list[Layer] | list[type[Layer]] | None= None,
) -> tuple[_ModelsT, _OptimizersT] | _ModelsT :
python/paddle/amp/auto_cast.py
Outdated
enable: bool = True, | ||
custom_white_list: _CustomList | None = None, | ||
custom_black_list: _CustomList | None = None, | ||
level: AMP_LEVEL = 'O1', | ||
dtype: dtype = 'float16', | ||
use_promote: bool = True, | ||
) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
enable: bool = True, | |
custom_white_list: _CustomList | None = None, | |
custom_black_list: _CustomList | None = None, | |
level: AMP_LEVEL = 'O1', | |
dtype: dtype = 'float16', | |
use_promote: bool = True, | |
) -> None: | |
enable: bool = True, | |
custom_white_list: _CustomList | None = None, | |
custom_black_list: _CustomList | None = None, | |
level: str = 'O1', | |
dtype: str = 'float16', | |
use_promote: bool = True, | |
) -> ContextManager: |
python/paddle/amp/auto_cast.py
Outdated
models: Layer | list[Layer], | ||
optimizers: _OptimizerList | None = None, | ||
level: AMP_LEVEL = 'O1', | ||
dtype: dtype = 'float16', | ||
master_weight: bool | None = None, | ||
save_dtype: dtype | None = None, | ||
master_grad: bool = False, | ||
excluded_layers: Layer | list[Layer] | None = None, | ||
) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考上面的 amp_decorate
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Typing][B-34] Add type annotations for `python/paddle/amp/auto_cast.py` (PaddlePaddle#66119) --------- Co-authored-by: SigureMo <sigure.qaq@gmail.com>
PR Category
User Experience
PR Types
Improvements
Description
B34