Natalia Gimelshein added CUDA support for nonzero_static to PyTorch the other day. It's a feature JAX has had for a while that lets you avoid a bit of data-dependent annoyance.
I most often see nonzero pop up in logs: it underlies a lot of boolean mask operations and torch.where. A downside of torch.nonzero is that we don't know the size of the returned tensor ahead of time. The shape is data dependent, which causes pain for the compiler and, when running on an accelerator, requires a device-to-host sync. This can significantly slow down otherwise fast operations.
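To make the data-dependent shape concrete, here is a small sketch on CPU tensors. The tensor names are made up for illustration. Two inputs of identical shape produce outputs of different shapes, and the same holds for boolean-mask indexing and for torch.where when it is called with only a condition.

```python
import torch

# Two inputs with the same shape but a different number of nonzero entries.
a = torch.tensor([0, 1, 0, 1])
b = torch.tensor([1, 1, 1, 0])

shape_a = torch.nonzero(a).shape
shape_b = torch.nonzero(b).shape
print(shape_a)  # torch.Size([2, 1])
print(shape_b)  # torch.Size([3, 1])

# Boolean-mask indexing and condition-only torch.where behave the same way:
# the output length depends on the values, not just on the input shape.
vals = torch.arange(4)
masked_a = vals[a.bool()]            # tensor([1, 3])
(where_b,) = torch.where(b.bool())   # tensor([0, 1, 2])
print(masked_a, where_b)
```

On a GPU, the output size has to be read back to the host before the result tensor can be allocated, which is where the sync comes from.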
nonzero_static overcomes this by letting you supply a size for the output: if the actual result is smaller it is padded (with fill_value, which defaults to -1), and if it is larger it is truncated. The PR linked above enables that for CUDA. You can't seamlessly use it with bool masks or torch.where, but you can easily replace the call with one to nonzero_static.
To see the difference, imagine we have a big tensor where we have some sense of the output shape, for example searching for a one-hot index like this:
import torch
# Set device
device = torch.device("cuda:0")
# Generate a large tensor of 0s with exactly one 1 at a specific index on the GPU
size = 100_000_000
large_tensor = torch.zeros(size, device=device, dtype=torch.long)
large_tensor[size // 2] = 1 # Place a single 1 in the middle of the tensor
# Use CUDA events to time the work on the GPU itself
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
ones_indices = torch.nonzero(large_tensor)
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"torch.nonzero() execution time: {elapsed_time_ms:.2f} ms")
print("Indices of 1s:", ones_indices.cpu())
That gives us (on my laptop):
torch.nonzero() execution time: 177.58 ms
Indices of 1s: tensor([[50000000]])
If we modify it to use nonzero_static:
import torch
# Set device
device = torch.device("cuda:0")
# Generate a large tensor of 0s with exactly one 1 at a specific index on the GPU
size = 100_000_000
large_tensor = torch.zeros(size, device=device, dtype=torch.long)
large_tensor[size // 2] = 1 # Place a single 1 in the middle of the tensor
# Use CUDA events to time the work on the GPU itself
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
ones_indices = torch.nonzero_static(large_tensor, size=1)
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"torch.nonzero_static() execution time: {elapsed_time_ms:.2f} ms")
print("Indices of 1s:", ones_indices.cpu())
torch.nonzero_static() execution time: 78.53 ms
Indices of 1s: tensor([[50000000]])
A very nice improvement: a bit over 2x faster on this example!