Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.
PR Category
Auto Parallel
PR Types
New features
Description
`param.main_grad`
- In-place `param.main_grad` replaces the old `master_grad` in auto dy.
- The grad is saved or cast to fp32 and accumulated into `param.main_grad` with an inplace `add_` (sketched below).
- Enable by setting `export FLAGS_enable_inplace_master_grad=1`.
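A minimal sketch of the accumulation idea, not Paddle's internal implementation; `FakeParam` and `accumulate_main_grad` are made-up names for illustration only:

```python
import paddle

# Illustrative only: keep one fp32 accumulator ("main_grad") per parameter
# and fold each low-precision grad into it with an inplace add_.
class FakeParam:
    def __init__(self):
        self.main_grad = None  # fp32 accumulator, created lazily


def accumulate_main_grad(param, grad):
    """Cast grad to fp32 and accumulate it into param.main_grad in place."""
    if param.main_grad is None:
        # First accumulation: materialize the fp32 buffer from this grad.
        param.main_grad = grad.astype('float32')
    else:
        # Later accumulations reuse the same buffer via an inplace add_,
        # so no separate master_grad copy is kept.
        param.main_grad.add_(grad.astype('float32'))


param = FakeParam()
for _ in range(2):  # e.g. gradient accumulation steps
    fp16_grad = paddle.randn([4, 4]).astype('float16')
    accumulate_main_grad(param, fp16_grad)
print(param.main_grad.dtype)  # paddle.float32
```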
`tensor_fusion`
- `tensor_fusion` groups params and grads into contiguous `param_storage` and `grad_storage`.
- `grad_storage` is used for the grads' `reduce_scatter` comm.
- `param_storage` is used for the params' `all_gather` comm.
- Supports non-uniform partitioning of params and grads across GPUs.
- Each step fetches the non-uniform params and grads from `param_storage` and `grad_storage` using `view_slice` (see the first sketch below).
- Non-uniform `grad_clip` requires calling `all_reduce` manually to collect `global_norm_var` (see the second sketch below).
- Enable by setting `export FLAGS_enable_tensor_fusion=1`.
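A minimal, illustrative sketch of the fusion idea under simplified assumptions: plain `concat` and slicing stand in for the real `param_storage` / `grad_storage` buffers and the `view_slice` op, so the slices below are copies rather than shared-memory views.

```python
import numpy as np
import paddle

# Illustrative only: pack several tensors into one contiguous buffer and
# recover each one as a reshaped slice.
params = [paddle.randn([3, 4]), paddle.randn([5]), paddle.randn([2, 2])]

# Fused storage: flatten everything into a single 1-D buffer.
param_storage = paddle.concat([p.flatten() for p in params])

# Record (offset, numel, shape) so each tensor can be sliced back out; the
# same bookkeeping also supports non-uniform partitioning across ranks.
offsets, cursor = [], 0
for p in params:
    offsets.append((cursor, p.size, p.shape))
    cursor += p.size

# view_slice-style recovery of each tensor from the fused buffer.
views = [param_storage[o:o + n].reshape(s) for o, n, s in offsets]
assert all(np.allclose(v.numpy(), p.numpy()) for v, p in zip(views, params))
```

And a hedged sketch of the manual global-norm reduction for the non-uniform grad clip, assuming it runs under `paddle.distributed.launch`; this is not the PR's exact code:

```python
import paddle
import paddle.distributed as dist

# Assumes launch via `python -m paddle.distributed.launch ...`.
dist.init_parallel_env()

# Pretend this is the local, non-uniformly sized shard of fused fp32 grads.
local_grad_shard = paddle.randn([1024])

# Each rank only sees its own shard, so the squared norms must be summed
# across ranks explicitly before taking the square root.
global_norm_var = paddle.sum(paddle.square(local_grad_shard))
dist.all_reduce(global_norm_var, op=dist.ReduceOp.SUM)
global_norm = paddle.sqrt(global_norm_var)
```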
`sharding_overlap`
- Overlap the `reduce_scatter` comm for grads with grad computation in bwd.
- Overlap the `all_gather` comm for params with opt computation (see the sketch below).
- Enable by setting `export FLAGS_enable_tensor_fusion=1`.
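A hedged illustration of the overlap scheduling, assuming a `paddle.distributed.launch` context; a non-blocking `all_reduce` stands in for the fused `reduce_scatter` / `all_gather` traffic, so this shows the idea rather than the PR's implementation:

```python
import paddle
import paddle.distributed as dist

# Assumes launch via `python -m paddle.distributed.launch ...`.
dist.init_parallel_env()

# Grad buckets: later layers finish backward first.
bucket_late = paddle.randn([1 << 20])   # grads already produced
bucket_early = paddle.randn([1 << 20])  # backward still "running"

# Kick off communication for the finished bucket without blocking.
task = dist.all_reduce(bucket_late, sync_op=False)

# Remaining backward compute overlaps with the in-flight communication.
bucket_early = bucket_early * 2.0  # stand-in for real backward kernels

# Wait for the communication only when its result is actually needed,
# e.g. right before the optimizer step.
task.wait()
```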
Note: non-uniform `tensor_fusion` changes the order of `add` in `grad_clip`, which introduces a small loss diff.
Convergence results on llama7b, 1NC8, sharding8, 50,000 steps.
【TODO】Add strategy config in auto-dy, like hand-dy (`fleet.init(strategy)`) and auto-static (`to_static(strategy)`).
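For reference, a hedged sketch of the existing hand-dy style that the TODO points at (run under a distributed launch); the exact auto-dy strategy API is still to be designed, and the degrees below are placeholders:

```python
import paddle.distributed.fleet as fleet

# Hand-dy style the TODO refers to: behavior is configured through a
# DistributedStrategy passed to fleet.init. Degrees below are placeholders.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 8,
}
fleet.init(is_collective=True, strategy=strategy)
```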
Pcard-70448