[MPS][BE] Do not create 4 instances of FUSED_ADAM_OPS
#141090
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141090
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: There is 1 currently active SEV. If your PR is affected, please view it below.
❌ 1 New Failure, 2 Unrelated Failures
As of commit 34e00fc with merge base 0443398:
NEW FAILURE - The following job has failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge -f "Mac builds + lint looks fine"
Merge started
Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Pull Request resolved: #141092
Approved by: https://github.com/Skylion007, https://github.com/kulinseth
ghstack dependencies: #141089, #141090
Instead of calling the `REGISTER_FUSED_ADAM_OP` macro with 7 parameters 16 times, use 4 type-parameter macros for each op and then one macro to define the quartet of ops: Adam, AdamW and their grad functions.

Pull Request resolved: #141103
Approved by: https://github.com/kulinseth
ghstack dependencies: #141089, #141090, #141092
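The actual macro definitions live in the MPS fused-optimizer sources; as a rough, hypothetical illustration of this kind of consolidation (none of the macro names, op names, dtypes, or parameter lists below are the real ones), nested macros can stamp out all of the per-dtype registrations from a single top-level invocation:

```cpp
#include <cstdio>

// Inner macro: registers a single (op, dtype) pair.
#define REGISTER_FUSED_OP(op, dtype) \
  void register_##op##_##dtype() { std::printf("registered %s<%s>\n", #op, #dtype); }

// Per-op macro: expands the inner macro for every supported dtype.
#define REGISTER_FUSED_OP_ALL_DTYPES(op) \
  REGISTER_FUSED_OP(op, float)           \
  REGISTER_FUSED_OP(op, half)            \
  REGISTER_FUSED_OP(op, bfloat)          \
  REGISTER_FUSED_OP(op, double)

// Top-level macro: emits the quartet of op variants in one go,
// replacing 16 separate macro invocations with a single one.
#define REGISTER_FUSED_ADAM_OPS()                  \
  REGISTER_FUSED_OP_ALL_DTYPES(fused_adam)         \
  REGISTER_FUSED_OP_ALL_DTYPES(fused_adam_amsgrad) \
  REGISTER_FUSED_OP_ALL_DTYPES(fused_adamw)        \
  REGISTER_FUSED_OP_ALL_DTYPES(fused_adamw_amsgrad)

REGISTER_FUSED_ADAM_OPS()

int main() {
  // Two of the 16 generated registration functions.
  register_fused_adam_float();
  register_fused_adamw_amsgrad_bfloat();
  return 0;
}
```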
For MacOS14+

Running the following script (adapted from the one mentioned in #127242):

```python
import torch
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark
import itertools

def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
    fn(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        foreach=False,
        capturable=False,
        fused=fused,
        amsgrad=amsgrad,
        beta1=0.9,
        beta2=0.99,
        lr=1e-3,
        weight_decay=.0,
        eps=1e-5,
        maximize=False,
        grad_scale=None,
        found_inf=None,
    )
    torch.mps.synchronize()


device, dtype = "mps", torch.bfloat16
results = []
for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]):
    print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
    params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
    max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else []
    state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)]
    fn = adamw.adamw if adamWflag else adam.adam

    for fused in [True, False]:
        t = benchmark.Timer(
            stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
            label=f'Fused Adam on {device} using {dtype}',
            sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
            globals=locals(),
            description=f"Fused: {fused}",
        ).blocked_autorange(min_run_time=5)
        results.append(t)

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```

produces the following results on an M4 Pro running MacOS 15:

```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
                                                                    |  Fused: True  |  Fused: False
1 threads: ----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10        |    283    |    2810
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10       |    277    |    2430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10       |    285    |    2400
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10      |    278    |    2250
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10       |    504    |    2700
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10      |    478    |    2600
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10      |    506    |    2500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10     |    482    |    2300
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10     |   2089    |    4190
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10    |   1940    |    3800
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10    |   2100    |    3770
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10   |   1950    |    3600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50        |    842    |   14000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50       |    835    |   11800
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50       |    845    |   11700
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50      |    855    |   11000
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50       |   1410    |   14000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50      |   1350    |   12000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50      |   1400    |   12000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50     |   1340    |   11000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50     |   9767    |   20400
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50    |   8991    |   18600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50    |   9803    |   18300
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50   |   9070    |   17600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100       |   1600    |   27000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100      |   1600    |   24100
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100      |   1600    |   23500
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100     |   1600    |   21800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100      |   2740    |   26000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100     |   2580    |   24000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100     |   2730    |   25000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100    |   2600    |   23000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100    |  19350    |   39000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100   |  17780    |   37300
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100   |  19400    |   37000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100  |  17900    |   35500

Times are in microseconds (us).
```

Pull Request resolved: #141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #141089, #141090, #141092, #141103
Defining `static char shaderSource[]` in the header instantiates it in every translation unit that includes it. Solved the problem by renaming `static auto getCPLState(const std::string&)` into `auto getFusedAdamCPLState(const std::string&)` and instantiating it only once, which resulted in a 500K reduction in binary size (and perhaps even more in runtime footprint).

I.e. before
```
% ls -lak lib/libtorch_cpu.dylib
-rwxr-xr-x  1 malfet  staff  183357744 Nov 19 17:58 lib/libtorch_cpu.dylib
```
and after
```
% ls -lak lib/libtorch_cpu.dylib
-rwxr-xr-x  1 malfet  staff  183357120 Nov 19 17:57 lib/libtorch_cpu.dylib
```

Pull Request resolved: pytorch#141090
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#141089
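For readers less familiar with the underlying C++ pitfall, here is a minimal, self-contained sketch of the before/after pattern. Only `shaderSource` and `getFusedAdamCPLState` come from the change above; `PipelineState`, the caching logic, and everything else are stand-ins, not the actual PyTorch sources.

```cpp
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a compiled Metal compute pipeline state object.
struct PipelineState {
  std::string source;
};

// BEFORE (in the header): every .mm/.cpp file that includes the header gets
// its own copy of the shader text and its own static cache, so both are
// duplicated in the final binary, once per translation unit.
static const char shaderSource[] = "/* ... Metal kernel source ... */";
static PipelineState& getCPLState(const std::string& name) {
  static std::unordered_map<std::string, PipelineState> cache;  // one per TU!
  return cache.try_emplace(name, PipelineState{shaderSource}).first->second;
}

// AFTER (header): declaration only; no data, nothing to duplicate.
PipelineState& getFusedAdamCPLState(const std::string& name);

// AFTER (a single .mm file): the shader text and the cache are defined exactly
// once, so they appear exactly once in libtorch_cpu.dylib.
PipelineState& getFusedAdamCPLState(const std::string& name) {
  static const char kShaderSource[] = "/* ... Metal kernel source ... */";
  static std::unordered_map<std::string, PipelineState> cache;
  return cache.try_emplace(name, PipelineState{kShaderSource}).first->second;
}

int main() {
  (void)getCPLState("fused_adam");           // before: per-TU copy
  (void)getFusedAdamCPLState("fused_adam");  // after: single shared instance
  return 0;
}
```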
Stack from ghstack (oldest at bottom):
- #141090 [MPS][BE] Do not create 4 instances of FUSED_ADAM_OPS