pcard-85872
AMP training with the BF16 dtype in dygraph mode with PIR enabled fails with the following error (FP16 is unaffected): ValueError: (InvalidArgument) The type of data we are trying to retrieve (float32) does not match the type of data (bfloat16) currently contained in the container.
The cause is that, with PIR enabled (`FLAGS_enable_pir_api=True`), the `_is_dtype_fp16_or_bf16` method of the optimizer base class `Optimizer` misidentifies BF16 (FP16 is handled correctly): the dtype is actually BF16, but the method reports that it is not.
This in turn breaks the branching logic in the optimizer's `_add_moments_pows`, `_create_accumulators`, and `_append_optimize_op` methods. `_add_moments_pows` creates the `moment1`, `moment2`, `beta1_pow`, and `beta2_pow` tensors with dtype BF16 when they should be FP32, which triggers the type mismatch error above when the adam/adamw kernel (adamw_kernel.cu) executes `beta1_pow.data<MPDType>()`.
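The failure chain described above can be modelled in a few lines of plain Python. This is only a sketch for illustration: it does not use Paddle, and every name in it (`Tensor`, `accumulator_dtype`, `adam_step`, the two predicate functions) is hypothetical. It assumes that the kernel's `MPDType` is float32 for FP16/BF16 parameters, as the error message suggests, and it does not attempt to reproduce why the real check misfires under PIR.

```python
# Simplified, hypothetical model of the bug; none of these names are Paddle APIs.
LOW_PRECISION = {"float16", "bfloat16"}


def is_fp16_or_bf16_correct(dtype: str) -> bool:
    """Intended behaviour: both FP16 and BF16 are recognised."""
    return dtype in LOW_PRECISION


def is_fp16_or_bf16_buggy(dtype: str) -> bool:
    """Models the reported behaviour: FP16 is recognised, BF16 is not."""
    return dtype == "float16"


def accumulator_dtype(param_dtype: str, is_low_precision) -> str:
    """Optimizer state for low-precision parameters should be kept in FP32."""
    return "float32" if is_low_precision(param_dtype) else param_dtype


class Tensor:
    def __init__(self, dtype: str):
        self.dtype = dtype

    def data(self, expected_dtype: str) -> "Tensor":
        # Mimics a typed accessor that rejects a mismatching dtype.
        if expected_dtype != self.dtype:
            raise ValueError(
                f"The type of data we are trying to retrieve ({expected_dtype}) "
                f"does not match the type of data ({self.dtype}) "
                "currently contained in the container."
            )
        return self


def adam_step(param_dtype: str, is_low_precision) -> str:
    """Create beta1_pow the way the optimizer would, then read it like the kernel."""
    beta1_pow = Tensor(accumulator_dtype(param_dtype, is_low_precision))
    # Assumption: the kernel's compute type is float32 for FP16/BF16 parameters.
    mp_dtype = "float32" if param_dtype in LOW_PRECISION else param_dtype
    beta1_pow.data(mp_dtype)
    return beta1_pow.dtype


if __name__ == "__main__":
    print(adam_step("bfloat16", is_fp16_or_bf16_correct))
    print(adam_step("float16", is_fp16_or_bf16_buggy))
    try:
        adam_step("bfloat16", is_fp16_or_bf16_buggy)
    except ValueError as err:
        print("ValueError:", err)
```

With the buggy predicate, FP16 still gets FP32 accumulators, while BF16 accumulators are created as BF16 and the typed read fails, which matches the asymmetry reported in this PR.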
Your PR has been submitted. Thank you for contributing to this open-source project!
Please watch for the results of the automated CI tests that follow. See the Paddle CI Manual for details.
PR Category
User Experience
PR Types
Bug fixes