Profiling Triton

There are a couple of options for profiling a Triton kernel: Proton, which ships with Triton, and NVIDIA's Nsight tools.

Proton

Proton is the profiler that ships with Triton. You can start a session and (optionally) activate/deactivate it around the specific regions you want to profile. You can also annotate functions and scopes with metrics such as FLOPs and bytes accessed.

import torch
import triton.profiler as proton

# A, B, C, M, N, K, grid and the BLOCK_* sizes are set up earlier (not shown)
session = proton.start()  # Start profiling session

bias = torch.rand((256,), device='cuda', dtype=torch.float16)  # Bias vector
flops = 2 * M * N * K
bytes_accessed = A.element_size() * M*K + B.element_size() * K*N + C.element_size() * M*N  # rough bytes
# Attach the metrics to a named scope so the viewer can derive throughput
with proton.scope(f"fused_gemm_bias_relu [M={M}, N={N}, K={K}]", {"flops": flops, "bytes": bytes_accessed}):
    fused_gemm_bias_relu[grid](
        A, B, C, bias,
        M, N, K,
        A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K
    )

proton.finalize()  # Stop profiling and write the results to disk
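The activate/deactivate calls mentioned above are handy for keeping warm-up or autotuning launches out of the profile. Here is a minimal sketch of that pattern, not a definitive recipe. `profile_steady_state` is a hypothetical helper of my own, and it takes the profiler module as an argument so it can be exercised without a GPU. It assumes a session is recording as soon as `start` returns.

```python
def profile_steady_state(profiler, step, warmup=3, iters=10, name="proton"):
    """Record only the timed iterations of `step`, skipping warm-up launches.

    `profiler` is expected to be the `triton.profiler` module; `step` is any
    zero-argument callable that launches the kernel. Both names are
    illustrative and not part of Proton itself.
    """
    session = profiler.start(name)   # assumed: recording is on once the session starts
    profiler.deactivate(session)     # pause so warm-up launches are not recorded
    for _ in range(warmup):
        step()
    profiler.activate(session)       # resume recording for the region of interest
    for _ in range(iters):
        step()
    profiler.deactivate(session)     # stop recording before any teardown work
    profiler.finalize()              # write the profile to disk
    return session
```

With Triton installed this would be called as `profile_steady_state(proton, lambda: fused_gemm_bias_relu[grid](...))` after `import triton.profiler as proton`.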

The resulting proton.hatchet file can be visualized with the built-in viewer:

proton-viewer -m time/ms,tflop/s ./proton.hatchet

0.040 6.645 ROOT
├─ 0.004 nan _ZN2at6native55_GLOBAL__N__11f7a751_22_DistributionUniform_[...]_15PhiloxCudaStateESH_SI_
└─ 0.037 7.284 fused_gemm_bias_relu [M=1024, N=256, K=512]
   └─ 0.037 nan fused_gemm_bias_relu

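Here you can see both the (trimmed!) mangled name of the PyTorch kernel that fills the bias tensor and my custom kernel, which appears both as the annotated scope and as the kernel launched inside it. The tflop/s column is only populated for the scope, because that is where the flops metric was attached; rows without it show nan.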
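As a sanity check on the tflop/s column, the arithmetic can be redone by hand. This is a rough sketch, assuming the viewer derives throughput as flops divided by time and that A, B and C are all float16 (2 bytes per element, like the bias vector). The helper names are mine, not part of Proton.

```python
def gemm_cost(M, N, K, elem_size=2):
    """FLOPs and rough bytes moved for an M x K by K x N matrix multiply.

    elem_size=2 assumes float16 inputs and output.
    """
    flops = 2 * M * N * K                                # one multiply and one add per term
    bytes_accessed = elem_size * (M * K + K * N + M * N)  # read A and B, write C
    return flops, bytes_accessed


def tflops(flops, time_ms):
    """Throughput in TFLOP/s given a duration in milliseconds."""
    return flops / (time_ms * 1e-3) / 1e12


flops, nbytes = gemm_cost(1024, 256, 512)
print(flops)                             # 268435456
print(nbytes)                            # 1835008
print(round(tflops(flops, 0.037), 2))    # about 7.26, using the rounded 0.037 ms
print(round(flops / nbytes, 1))          # about 146.3 FLOPs per byte
```

The small gap between this estimate and the viewer's 7.284 comes from the time being rounded to three decimal places in the table.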

Nsight Compute

NVIDIA also has a good range of tools for looking at performance. Note that you will need to enable access to the GPU performance counters on the device to use them:

NVIDIA Development Tools Solutions – ERR_NVGPUCTRPERM: Permission issue with Performance Counters | NVIDIA Developer

On the off chance you’re doing this on WSL, https://peterchng.com/blog/2024/03/02/profiling-cuda-programs-on-wsl-2/ walks through the setup!

NVIDIA ships Nsight Systems, which tracks broader system-wide metrics, and Nsight Compute, which focuses on profiling the execution of individual kernels. You can run the Nsight Compute CLI (ncu) against a script like so:

ncu -o profile_results python test.py

The tool comes with a nice GUI for inspecting the results. It can show you the PTX or SASS source for the kernels, offers metrics like actively used registers (good for checking on register spilling), and calls out warnings on poor utilization or memory clashes.

Upcoming intra-kernel profiler

[tool][proton] Intra kernel profiling support by fywkevin · Pull Request #4861 · triton-lang/triton

There is an extension coming for Proton that enables profiling within kernels. It reserves a pre-allocated buffer on the device and logs metrics locally, to be read out at the end of execution. The output is a Chrome trace, which can be opened in a wide range of dev tools. While this isn’t merged into mainline yet, you can see an example of the usage in the dev repo.
