In May Nvidia shipped CuTe‑DSL, the Python library they teased at GTC earlier in the year that mirrors CUTLASS’s C++ tensor‑layout abstractions. Then, at the start of June, the ‑dev label disappeared (so presumably it’s production-ready now). The pitch is simple: write speed‑of‑light kernels from the comfort of Python.
Of course, nothing about CUDA is ever really simple. CuTe‑DSL gives the full Cutlass experience1, wrapped in an ever so slightly more approachable interface.
Getting Cute: Transpose
Matrix transpose felt like a reasonable ‘hello world’: B[j, i] = A[i, j]. The PyTorch version is simple: torch.transpose(input, 0, 1).
To get a baseline, here is a simple transpose kernel in Triton. We tl.load, flip the coordinates and tl.store it back.
@triton.jit
def triton_transpose_kernel(input_ptr, output_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    # 2D block coordinates
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Calculate offsets
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Load with masking
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    a = tl.load(input_ptr + offs_m[:, None] * N + offs_n[None, :], mask=mask)
    # Store transposed (swap coordinates)
    tl.store(output_ptr + offs_n[:, None] * M + offs_m[None, :], a, mask=mask)
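The index arithmetic here is easy to sanity-check on the CPU. Below is a NumPy sketch (not Triton; the function name and loop harness are my own) of what a single (pid_m, pid_n) program instance does, assuming a row-major input:

```python
import numpy as np

def transpose_tile(A, B, pid_m, pid_n, BLOCK_SIZE):
    """CPU model of one program instance: read a tile of A and
    write it transposed into B, masking out-of-bounds elements."""
    M, N = A.shape
    offs_m = pid_m * BLOCK_SIZE + np.arange(BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + np.arange(BLOCK_SIZE)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    for i, j in zip(*np.nonzero(mask)):
        B[offs_n[j], offs_m[i]] = A[offs_m[i], offs_n[j]]

# Run every tile of the grid, exactly as the launch loop would.
M, N, T = 5, 7, 4
A = np.arange(M * N, dtype=np.float32).reshape(M, N)
B = np.zeros((N, M), dtype=np.float32)
for pm in range(-(-M // T)):        # ceil(M / T)
    for pn in range(-(-N // T)):    # ceil(N / T)
        transpose_tile(A, B, pm, pn, T)
assert np.array_equal(B, A.T)
```

Deliberately odd dimensions (5 and 7 with a tile size of 4) exercise the masked edge tiles.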
Here’s the same idea in CuTe‑DSL. CuTe leverages decorators and Python’s ability to integrate with JITs. Anything decorated with @jit runs host-side (on the CPU), while @kernel is used for device-side code (on the GPU). There are both AST-based and tracing-based options, depending on the presence of dynamic shapes or control flow.
@cute.kernel
def transpose_kernel(self, mA: cute.Tensor, mB: cute.Tensor):
    tidx = cute.arch.thread_idx()[0]
    tidy = cute.arch.thread_idx()[1]
    bidx = cute.arch.block_idx()[0]
    bidy = cute.arch.block_idx()[1]

    # This might all be unnecessary
    # but I was fearful of the compiler
    tile_start_m = cutlass.Int32(0)
    tile_start_n = cutlass.Int32(0)
    global_m = cutlass.Int32(0)
    global_n = cutlass.Int32(0)
    M = cutlass.Int32(0)
    N = cutlass.Int32(0)
    val = cutlass.Float32(0.0)

    # Calculate tile starting positions
    tile_start_m = bidy * self._tile_size
    tile_start_n = bidx * self._tile_size
    # Calculate global coordinates for this thread
    global_m = tile_start_m + tidy
    global_n = tile_start_n + tidx
    # Get matrix dimensions at runtime
    M = mA.shape[0]
    N = mA.shape[1]
    # Bounds checking and transpose operation
    if global_m < M and global_n < N:
        val = mA[global_m, global_n]
        # Transpose: B[n, m] = A[m, n]
        mB[global_n, global_m] = val
What just happened?
- Thread and block indices come straight from CUDA (thread_idx, block_idx), vs the Triton block abstraction.
- No explicit loads or stores: CuTe uses overloaded [] to generate them.
Launching isn’t a million miles away from Triton:
@cute.jit  # host side
def launch(self, A: cute.Tensor, B: cute.Tensor):
    M, N = A.shape
    grid = ((N + self.T - 1) // self.T,
            (M + self.T - 1) // self.T, 1)
    self.transpose_kernel(A, B).launch(
        grid=grid,
        block=[self.T, self.T, 1],
    )
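The grid calculation is the usual ceiling division: enough T×T tiles to cover the matrix, with the edge tiles handled by the bounds check inside the kernel. A quick standalone check (grid_for is my own helper name):

```python
# Ceiling division: how many T x T tiles cover an M x N matrix.
# Mirrors the grid tuple in the launch function above.
def grid_for(M, N, T):
    return ((N + T - 1) // T, (M + T - 1) // T, 1)

# A 1000 x 500 matrix with 32 x 32 tiles needs a 16 x 32 grid:
print(grid_for(1000, 500, 32))  # (16, 32, 1)
# An exact fit needs no extra tiles:
print(grid_for(64, 64, 32))     # (2, 2, 1)
```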
Because CuTe‑DSL speaks DLPack, you can hand it a PyTorch tensor directly. If you wanted to cache the conversion, it looks like this:
A_cute = from_dlpack(A).mark_layout_dynamic()
mark_layout_dynamic triggers the dynamic-shape support and avoids shape specialization. The one place where this went a bit funky in my testing was dealing with singular leading dimensions: there you need to be more explicit about the shape to satisfy the compiler.
Layouts and Memory
This kernel isn’t really leveraging the fundamental value of CuTe though, which is composable tensor layouts and memory management. CuTe‑DSL exposes the full memory hierarchy: global, shared, register (and tmem for those with Blackwells), and lets you tile, copy, and pipeline data between them. Common primitives:
- make_layout / make_layout_tv: describe how a tensor is laid out.
- cute.zipped_divide(tensor, tiler): tile a tensor.
- cute.copy(src_layout, dst_layout, pred=mask): async copy.
- cute.arch.sync_threads(): explicit barrier.
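To build intuition for the tiling primitive, here is a pure-NumPy analogue of what zipped_divide conceptually produces (this is NOT the CuTe API, just an illustration with my own function name, and it assumes the tile divides the shape evenly): the result is indexed first by position within a tile, then by which tile.

```python
import numpy as np

def zipped_divide_like(x, tile):
    """NumPy analogue of cute.zipped_divide: reshape an (M, N) array so
    the first two axes index within a tile and the last two index which tile."""
    tm, tn = tile
    M, N = x.shape
    assert M % tm == 0 and N % tn == 0, "illustration assumes even tiling"
    # (M, N) -> (M//tm, tm, N//tn, tn) -> (tm, tn, M//tm, N//tn)
    return x.reshape(M // tm, tm, N // tn, tn).transpose(1, 3, 0, 2)

x = np.arange(8 * 6).reshape(8, 6)
tiles = zipped_divide_like(x, (4, 3))
print(tiles.shape)  # (4, 3, 2, 2): a 4x3 tile, in a 2x2 grid of tiles
# tiles[:, :, i, j] is the (i, j)-th 4x3 tile of x
assert np.array_equal(tiles[:, :, 0, 0], x[:4, :3])
assert np.array_equal(tiles[:, :, 1, 1], x[4:, 3:])
```

In real CuTe this is all done with layout algebra over strides rather than data movement, which is what makes the compositions cheap.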
HGEMMony2
CuTe ships with some example kernels, so I grabbed one — an HGEMM (half-precision, FP16, batched GEMM) — and compared to an example implementation in Triton.
To express the same thing in PyTorch, we can unleash our inner Jeremy Howard and use einsum notation: torch.einsum("mkl,nkl->mnl", a, b). Take L batches of an MxK matrix and L batches of an NxK matrix, and return L batches of an MxN matrix.
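Since both operands put K second, each batch is an A @ B.T product over the trailing L axis. A quick NumPy check of the contraction (numpy.einsum accepts the same subscript notation as torch.einsum):

```python
import numpy as np

M, N, K, L = 4, 3, 5, 2
rng = np.random.default_rng(0)
a = rng.random((M, K, L), dtype=np.float32)
b = rng.random((N, K, L), dtype=np.float32)

c = np.einsum("mkl,nkl->mnl", a, b)

# Equivalent loop over batches: C[:, :, l] = A[:, :, l] @ B[:, :, l].T
ref = np.stack([a[:, :, l] @ b[:, :, l].T for l in range(L)], axis=-1)
assert c.shape == (M, N, L)
assert np.allclose(c, ref)
```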
Here is the Triton:
@triton.jit
def triton_batched_hgemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K, L,
    stride_am, stride_ak, stride_al,
    stride_bn, stride_bk, stride_bl,
    stride_cm, stride_cn, stride_cl,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton batched half-precision GEMM kernel: C[m,n,l] = sum_k A[m,k,l] * B[n,k,l]"""
    pid = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)  # Batch dimension
    # Calculate batch offsets
    batch_offset_a = pid_batch * stride_al
    batch_offset_b = pid_batch * stride_bl
    batch_offset_c = pid_batch * stride_cl
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Include batch offsets in pointer calculations
    a_ptrs = a_ptr + batch_offset_a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # transpose for GEMM (load B[K, N] pattern)
    b_ptrs = b_ptr + batch_offset_b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # We accumulate into fp32 for higher accuracy.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, mask in K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Convert back to FP16 for output
    c = accumulator.to(tl.float16)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + batch_offset_c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
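One subtlety worth pulling out is the GROUP_SIZE_M arithmetic: it reorders the flat pid so that consecutive program instances walk down a column of output tiles within a group of GROUP_SIZE_M tile-rows, which improves L2 reuse. A pure-Python model of that mapping (grouped_order is my own name; the body is copied from the kernel above):

```python
def grouped_order(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    """Model of the pid -> (pid_m, pid_n) 'grouped ordering' swizzle."""
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may be ragged if GROUP_SIZE_M doesn't divide num_pid_m
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# With a 4x4 grid of output tiles and groups of 2 tile-rows, the first
# 8 pids sweep the first two tile-rows column by column:
order = [grouped_order(p, 4, 4, 2) for p in range(8)]
print(order)
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
```

Compare that to the naive row-major order, where the same 8 pids would each touch a fresh column of B.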
The core loop here is:
- Divide into block_size_k chunks of K
- For each, do a masked load (so when we hit the edges we don’t bring in garbage)
- Do the dot product for the tile into an accumulator for the result matrix
- Advance the A and B pointers
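That loop structure is easy to model on the CPU. Here is a NumPy sketch for one output tile (my own harness, not Triton): chunk K, zero-fill the ragged final chunk the way the masked load does, and accumulate in fp32 from fp16 inputs:

```python
import numpy as np

def tiled_matmul_k(a, b, BLOCK_K):
    """One output tile: loop over K in BLOCK_K chunks, zero-padding
    (masking) the ragged final chunk, accumulating in float32."""
    M, K = a.shape
    _, N = b.shape
    acc = np.zeros((M, N), dtype=np.float32)
    for k0 in range(0, K, BLOCK_K):
        # masked load: out-of-range K elements contribute 0.0
        a_blk = np.zeros((M, BLOCK_K), dtype=np.float32)
        b_blk = np.zeros((BLOCK_K, N), dtype=np.float32)
        width = min(BLOCK_K, K - k0)
        a_blk[:, :width] = a[:, k0:k0 + width]
        b_blk[:width, :] = b[k0:k0 + width, :]
        acc += a_blk @ b_blk          # the tl.dot step
    return acc.astype(np.float16)     # convert back to fp16 for output

rng = np.random.default_rng(0)
a = rng.random((8, 10)).astype(np.float16)
b = rng.random((10, 4)).astype(np.float16)
out = tiled_matmul_k(a, b, BLOCK_K=4)  # K=10 is not a multiple of 4
ref = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)
assert np.allclose(out, ref, atol=1e-2)
```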
The CuTe kernel is, uhhh… a bit more involved. The full kernel is several hundred lines long. You can see the source in the Cutlass repo, which demonstrates some cool features like the ability to pass in an epilogue function for fusion.
For now, let’s focus on the main loop. As before, we are looping over tiles of K3.
The first thing we do is advance the pointers. The kernel is doing explicit pipelining of transfers from global memory, to shared memory, to registers, so we need to set wait groups to ensure all the loading has completed before we advance. We’re not actually doing loading in this section, just prepping the ground:
for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1):
    for k_block in range(num_k_block):
        if k_block == num_k_block - 1:
            tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
            tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
            cute.arch.cp_async_wait_group(num_smem_stages - 2)
            cute.arch.sync_threads()
Next we kick off a copy from shared memory (smem) to registers (rmem) using cute.copy for the (future) A and B tiles.
k_block_next = (k_block + 1) % num_k_block  # static
cute.copy(
    tiled_copy_s2r_A,
    tCsA_p[None, None, k_block_next],
    tCrA_copy_view[None, None, k_block_next],
)
cute.copy(
    tiled_copy_s2r_B,
    tCsB_p[None, None, k_block_next],
    tCrB_copy_view[None, None, k_block_next],
)
Finally, we interleave the transfers of the next A and B tiles from global to shared memory with the actual GEMM operation (the equivalent of tl.dot). I will trust the folks at Nvidia that this is an optimal pattern. The pred= in there is the equivalent of masking in Triton.
if k_block == 0:
    if k_tile + num_smem_stages - 1 < k_tile_count:
        cute.copy(
            tiled_copy_B,
            tBgB[None, None, None, k_tile_index],
            tBsB[None, None, None, smem_pipe_write],
            pred=tBpB,
        )
        k_tile_index = k_tile_index + 1
    cute.arch.cp_async_commit_group()
    smem_pipe_write = smem_pipe_read
    smem_pipe_read = smem_pipe_read + 1
    if smem_pipe_read == num_smem_stages:
        smem_pipe_read = 0
The pipelining is explicit, which is nice for debuggability and optimization, but very manual.
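Stripped of the copies, the smem_pipe_read / smem_pipe_write bookkeeping is just a circular buffer over num_smem_stages shared-memory slots: each iteration reuses the slot we just finished reading as the next write target. A pure-Python trace of the rotation (my own harness; it assumes, as in the excerpt above, that a prologue has pre-filled the buffer so writing starts at the last slot):

```python
def pipe_indices(num_smem_stages, k_tile_count):
    """Trace the (read, write) slot pair for each K tile, mirroring
    the circular-buffer update at the bottom of the main loop."""
    smem_pipe_read, smem_pipe_write = 0, num_smem_stages - 1
    trace = []
    for _ in range(k_tile_count):
        trace.append((smem_pipe_read, smem_pipe_write))
        # the update from the kernel: write into the slot we just read,
        # then advance the read pointer, wrapping at num_smem_stages
        smem_pipe_write = smem_pipe_read
        smem_pipe_read = smem_pipe_read + 1
        if smem_pipe_read == num_smem_stages:
            smem_pipe_read = 0
    return trace

print(pipe_indices(3, 5))
# [(0, 2), (1, 0), (2, 1), (0, 2), (1, 0)]
```

The write pointer always trails the read pointer by one slot, so global-to-shared transfers for future tiles can be in flight while the current tile is being consumed.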
Debugging Tips
export CUTE_DSL_LOG_TO_CONSOLE=1
export CUTE_DSL_LOG_LEVEL=10   # up to 100
export CUTE_DSL_PRINT_IR=1     # dump MLIR
- cute.printf() gives you a GPU‑side printf.
- Kernels are aggressively cached; rm ~/.cache/cutedsl if things look stale.
- Multiple @cute.jit host functions in the same Python scope can confuse MLIR (mainly for launching kernels).
- The control‑flow rules are strict: no return inside a kernel; initialize everything.
If you’re exploring GPU kernels for the first time, I strongly recommend starting with Triton. When you need to really get into the weeds, or want to reuse CUTLASS building blocks, it’s great to have CuTe‑DSL as an option in Python (provided you’re comfortable spelunking in GPU internals).