# [Distributed] Improve efficiency of NaN checker #135414
## Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135414
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit ea492a5 with merge base 5f57be7. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Base tensor is guaranteed to have 16-byte alignment, but a view into it does not have to be 🤔
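For illustration, a minimal sketch of the check this alignment concern implies before using 16-byte vectorized loads; the helper name is hypothetical, not the PR's code.

```cuda
#include <cstddef>
#include <cstdint>

// Hypothetical helper: how many leading elements would have to be checked
// one-by-one before a view's data pointer reaches a 16-byte boundary and
// vectorized 16-byte loads become safe. Assumes sizeof(T) divides 16 and the
// pointer is at least sizeof(T)-aligned.
template <typename T>
__host__ __device__ size_t misalignedPrefixElems(const T* ptr) {
  uintptr_t addr = reinterpret_cast<uintptr_t>(ptr);
  size_t rem = addr % 16;  // 0 means already 16-byte aligned
  return rem == 0 ? 0 : (16 - rem) / sizeof(T);
}
```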
I think, given the increased complexity of the kernel, it would be good to add a test that more carefully checks for cases where the NaN detector misses a NaN.
Given we can't realistically afford a test that loops through the indices of a large tensor and sets each value to NaN exhaustively, do you think it makes sense to do some combination of (a) exhaustive testing on a small-to-medium tensor that is still large enough to exercise both the unrolled and suffix loops, and (b) a test that sets a random index to NaN, so that across many repetitions of CI we could expect a 'flaky' signal if we are missing certain values?
Re (a): yes, I can add a test that shmoos through small-to-medium sizes (and data types). Re (b): yep, I think the existing tests can be modified to support the randomness.
@wconstab Test modified to cover a wider size range.
One more thing (could be a separate PR): we are still missing fp8, IIUC. We should definitely cover this. If it's convenient to add in this PR, it might make sense as yet one more case of the kernel template.
@wconstab Can it be in a separate PR, for cleanliness reasons?
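For context on what fp8 coverage might look like, a rough byte-level sketch; the masks reflect the usual float8_e4m3fn / float8_e5m2 encodings and are an assumption for illustration, not code from this PR.

```cuda
#include <cstdint>

// float8_e4m3fn has no infinities; its only NaN bit patterns are 0x7F and
// 0xFF, i.e. all non-sign bits set.
__host__ __device__ inline bool isNaN_e4m3fn(uint8_t b) {
  return (b & 0x7F) == 0x7F;
}

// float8_e5m2 follows IEEE conventions: exponent all ones (mask 0x7C) with a
// non-zero mantissa (mask 0x03) is NaN.
__host__ __device__ inline bool isNaN_e5m2(uint8_t b) {
  return (b & 0x7C) == 0x7C && (b & 0x03) != 0;
}
```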
test/distributed/test_c10d_nccl.py
```python
# randomly pick an nan element
i = random.randint(0, nan_tensor.size(0) - 1)
j = random.randint(0, nan_tensor.size(1) - 1)
nan_tensor[i, j] = float("nan")
index = (i,) * len(size)
```
nit: this appears to only put NaN values on the (i, i) diagonal. What about something like this?
`index = tuple([randint(...) for _ in range(len(size))])`
Thanks, adopted.
```cuda
// EltPerPack would be greater than 8 if falling in this case.

template <typename T, int EltPerPack>
struct CheckBytePack {
```
IIUC this generalized kernel would only be used for float8? I guess in a later PR you would possibly replace it with a specialized one too?
Yes
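For reference, a minimal sketch of what the generic element-wise fallback could look like; the `BytePack` layout, the device-side assert, and the float conversion are illustrative assumptions, not the PR's exact code.

```cuda
#include <cassert>
#include <cmath>

struct BytePack { unsigned long long u64[2]; };  // 16 bytes

// Generic fallback: reinterpret the 16-byte pack as EltPerPack elements of T
// and test each one.
template <typename T, int EltPerPack>
struct CheckBytePack {
  static __device__ void check(BytePack* tmp) {
    const T* elts = reinterpret_cast<const T*>(tmp);
#pragma unroll
    for (int i = 0; i < EltPerPack; i++) {
      // Assumes T is convertible to float; the real error path may differ.
      assert(!isnan(static_cast<float>(elts[i])));
    }
  }
};
```

A float8 specialization would presumably drop the per-element conversion and test bit patterns of the packed words instead.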
```cuda
int nWorkers = blockDim.x * gridDim.x;
// First load values from global memory into tmp buffer
#pragma unroll 8
for (int j = 0; j < UNROLL; j++) {
```
Hm, does `checkChunk` get called with a different ptr offset for each thread?
Below at line 134, `checkChunk` is called like this: `checkChunk<T>(ptr + offset);` where `offset` accounts for the different offsets of different threads.
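To make the data flow concrete, here is a sketch of how the chunk check and its caller could fit together, based on the fragments quoted in this thread (UNROLL = 8, per-thread interleaved offsets); the names and exact indexing are assumptions, not the PR's source.

```cuda
#include <cstddef>

struct BytePack { unsigned long long u64[2]; };  // 16 bytes

// Assumed to be the per-pack check discussed above.
template <typename T, int EltPerPack>
struct CheckBytePack { static __device__ void check(BytePack* tmp); };

constexpr int UNROLL = 8;

// Each call checks UNROLL packs for one thread, interleaved with the other
// threads: pack j sits nWorkers packs after pack j-1. Loads are issued first
// into a local buffer, then checked, so the loads can be in flight together.
template <typename T>
__device__ void checkChunk(const BytePack* ptr) {
  size_t nWorkers = (size_t)blockDim.x * gridDim.x;
  BytePack tmp[UNROLL];
#pragma unroll 8
  for (int j = 0; j < UNROLL; j++) {  // first loop: loads only
    tmp[j] = ptr[j * nWorkers];
  }
#pragma unroll 8
  for (int j = 0; j < UNROLL; j++) {  // second loop: checks only
    CheckBytePack<T, sizeof(BytePack) / sizeof(T)>::check(&tmp[j]);
  }
}

// Caller: every thread starts at its own pack index, so `ptr + offset` is
// different per thread, and each iteration advances past the whole grid's
// 8-pack chunk. Whatever is left falls through to the single-pack "slow loop"
// quoted below.
template <typename T>
__device__ void checkBody(const BytePack* ptr, size_t sizeInBP) {
  size_t nWorkers = (size_t)blockDim.x * gridDim.x;
  size_t offset = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
  for (; offset + (UNROLL - 1) * nWorkers < sizeInBP; offset += nWorkers * UNROLL) {
    checkChunk<T>(ptr + offset);
  }
}
```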
```cuda
// We just do regular load and check
for (; offset < sizeInBP; offset += blockDim.x * gridDim.x) {
  BytePack tmp = ptr[offset];
  CheckBytePack<T, sizeof(BytePack)/sizeof(T)>::check(&tmp);
```
Confused: if we are sure we have enough data left for one call to `CheckBytePack<T, B/T>`, doesn't that also mean we have enough data for a faster call to `CheckBytePack`?
Do you mean why not a call to `checkChunk`? The reason is that `checkChunk` checks 8 `BytePack`s in one call, while `CheckBytePack` checks 1 `BytePack`. This slow loop here accounts for the case where we don't have 8 `BytePack`s left.
Oh, yes, I got confused between the two. I think this makes sense.
So in summary, I think the algorithm is:
- Pre: always < 1 `BytePack`, since = 1 would imply it's already 16B-aligned; so don't even use `CheckBytePack`, just do a local element-wise check
- Body: process chunks of 8 `BytePack`s (e.g. 8 * 16 = 128B chunks) per call
- Tail: since alignment is now guaranteed, just process the remaining (N < 8) 16B `BytePack`s individually
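For completeness, a compact end-to-end sketch of those three phases under the same assumptions; the names, the device-side assert, and the exact indexing are illustrative, not the PR's source.

```cuda
#include <cassert>
#include <cstddef>
#include <cstdint>

struct BytePack { unsigned long long u64[2]; };  // 16 bytes

template <typename T>
__device__ void checkElt(T v) {
  assert(!isnan(static_cast<float>(v)));  // illustrative error path
}

template <typename T>
__device__ void checkPack(BytePack pack) {
  const T* e = reinterpret_cast<const T*>(&pack);
#pragma unroll
  for (int k = 0; k < (int)(sizeof(BytePack) / sizeof(T)); k++) checkElt(e[k]);
}

template <typename T>
__global__ void checkForNaN(const T* data, size_t numElts) {
  constexpr int UNROLL = 8;
  constexpr int EltPerPack = sizeof(BytePack) / sizeof(T);
  size_t tid = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
  size_t nWorkers = (size_t)blockDim.x * gridDim.x;

  // Pre: fewer than EltPerPack leading elements until the pointer hits a
  // 16-byte boundary; check them one by one (assumes sizeof(T) divides 16).
  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
  size_t head = addr % 16 ? (16 - addr % 16) / sizeof(T) : 0;
  if (head > numElts) head = numElts;
  for (size_t i = tid; i < head; i += nWorkers) checkElt(data[i]);

  const BytePack* ptr = reinterpret_cast<const BytePack*>(data + head);
  size_t sizeInBP = (numElts - head) / EltPerPack;

  // Body: chunks of UNROLL packs per thread (8 * 16B = 128B), interleaved
  // across the grid.
  size_t offset = tid;
  for (; offset + (UNROLL - 1) * nWorkers < sizeInBP; offset += nWorkers * UNROLL) {
#pragma unroll
    for (int j = 0; j < UNROLL; j++) checkPack<T>(ptr[offset + j * nWorkers]);
  }

  // Tail: alignment is guaranteed, but fewer than UNROLL packs remain in this
  // thread's stride; check them one pack at a time.
  for (; offset < sizeInBP; offset += nWorkers) checkPack<T>(ptr[offset]);

  // Trailing elements that don't fill a whole pack are checked individually.
  for (size_t i = head + sizeInBP * EltPerPack + tid; i < numElts; i += nWorkers) {
    checkElt(data[i]);
  }
}
```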
I might be understanding wrong, but if "our perf is on-par with torch ops", why not just use `torch.any(torch.isnan(x))`?
Good question. I had the same question too. So the reasoning goes like this (backward):
Re why "we need to stop communication from spreading NaNs", here is a view from @wconstab:
Another flavor on this is, we could use it if we could easily modify it to trap() on NaN, instead of asynchronously producing a bool tensor that someone (who?) has to check (when?). We definitely don't want to do a CUDA synchronize after each NaN check and check it on the CPU side.
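A minimal sketch of the trap-style alternative described here, i.e. aborting from inside the kernel instead of asynchronously producing a bool tensor; this is illustrative only, not the checker's actual behavior.

```cuda
#include <cstddef>
#include <cstdio>

// Grid-stride scan that aborts the kernel as soon as a NaN is seen; the
// stream then reports an error, so no host-side synchronize-and-inspect of a
// result tensor is needed.
__global__ void trapOnNaN(const float* data, size_t numElts) {
  size_t stride = (size_t)blockDim.x * gridDim.x;
  for (size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x; i < numElts; i += stride) {
    if (isnan(data[i])) {
      printf("NaN found at index %llu\n", (unsigned long long)i);
      __trap();  // device-side abort; subsequent CUDA calls on this context report an error
    }
  }
}
```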
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Some customers would like to run the NaN checks on the fly, so we are improving its efficiency.

## Benchmarking
Allreduce of 2G floats, with `TORCH_NCCL_NAN_CHECK=1`.
Red kernel: ncclAllreduce. Blue kernel: NaN check.
<img width="1093" alt="Screenshot 2024-09-06 at 10 00 05 PM" src="https://github.com/user-attachments/assets/5501bc31-024f-4115-adb2-dd66eb4025d3">

## Comparison with torch ops
Let's say a user manually checks for NaNs with the following torch op before the all-reduce:
```
torch.any(torch.isnan(x))
```
<img width="1091" alt="Screenshot 2024-09-06 at 10 14 53 PM" src="https://github.com/user-attachments/assets/1f8b5f63-c955-4612-bb96-241b6c69959b">
So our perf is on par with torch ops.

## Changes
- Load from vidmem using "big packs" of 16 bytes
- Bump `blockDim.x` from 256 to 512
- Separate loads and checks into two loops, each of 8 iterations
- Unroll the loops
- Templated functions for checking NaN in a "big pack" based on dtype

Special thanks to @jbachan from NCCL!

Pull Request resolved: pytorch#135414
Approved by: https://github.com/wconstab
cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
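Tying the Changes list together, a hypothetical host-side launch under the 512-thread block size mentioned above; `checkForNaN` and `launchNaNCheck` are illustrative names, not the PR's API.

```cuda
#include <algorithm>
#include <cstddef>
#include <cuda_runtime.h>

// Kernel as sketched earlier in this thread (hypothetical name/signature).
template <typename T>
__global__ void checkForNaN(const T* data, size_t numElts);

// Hypothetical launch: 512 threads per block (bumped from 256 per the PR),
// grid size capped because the kernel's grid-stride loops cover any remainder.
inline void launchNaNCheck(const float* devPtr, size_t numElts, cudaStream_t stream) {
  const size_t blockSize = 512;
  const size_t maxBlocks = 1024;
  int numBlocks = (int)std::max<size_t>(
      1, std::min<size_t>((numElts + blockSize - 1) / blockSize, maxBlocks));
  checkForNaN<float><<<numBlocks, (int)blockSize, 0, stream>>>(devPtr, numElts);
}
```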