CuTe-DSL

In May Nvidia shipped CuTe‑DSL, the Python library they teased at GTC earlier in the year that mirrors CUTLASS’s C++ tensor‑layout algebra. Then, at the start of June, the ‑dev label disappeared (so presumably it’s production-ready now). The pitch is simple: write speed‑of‑light kernels from the comfort of Python.

Of course, nothing about CUDA is ever really simple. CuTe‑DSL delivers the full CUTLASS experience1, wrapped in an ever-so-slightly more approachable interface.

Getting Cute: Transpose

Matrix transpose felt like a reasonable ‘hello world’: B[j, i] = A[i, j]. The PyTorch version is a one-liner: torch.transpose(input, 0, 1).
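For reference, the semantics fit in a few lines of plain Python; the GPU kernels below are just parallel versions of this loop nest:

```python
# Reference semantics for transpose: B[j][i] = A[i][j].
def transpose_ref(A):
    M, N = len(A), len(A[0])
    return [[A[i][j] for i in range(M)] for j in range(N)]

A = [[1, 2, 3],
     [4, 5, 6]]           # 2x3
B = transpose_ref(A)      # 3x2
assert B == [[1, 4], [2, 5], [3, 6]]
```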

To get a baseline, here is a simple transpose kernel in Triton: tl.load a tile, flip the coordinates, and tl.store it back.

@triton.jit
def triton_transpose_kernel(input_ptr, output_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    # 2D block coordinates
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    # Calculate offsets
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    
    # Load with masking
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    a = tl.load(input_ptr + offs_m[:, None] * N + offs_n[None, :], mask=mask)
    
    # Store transposed (swap coordinates; transpose the tile and its mask to match)
    tl.store(output_ptr + offs_n[:, None] * M + offs_m[None, :], tl.trans(a), mask=tl.trans(mask))

Here’s the same idea in CuTe‑DSL. CuTe‑DSL leans on decorators and Python’s ability to integrate with JITs: anything decorated with @cute.jit runs host-side (on the CPU), while @cute.kernel marks device-side code (on the GPU). There are both AST-based and tracing-based options, depending on the presence of dynamic shapes or control flow.

    @cute.kernel
    def transpose_kernel(self, mA: cute.Tensor, mB: cute.Tensor):
        tidx = cute.arch.thread_idx()[0]
        tidy = cute.arch.thread_idx()[1]
        bidx = cute.arch.block_idx()[0]
        bidy = cute.arch.block_idx()[1]
        # This might all be unnecessary
        # but I was fearful of the compiler
        tile_start_m = cutlass.Int32(0)
        tile_start_n = cutlass.Int32(0)
        global_m = cutlass.Int32(0)
        global_n = cutlass.Int32(0)
        M = cutlass.Int32(0)
        N = cutlass.Int32(0)
        val = cutlass.Float32(0.0)
        # Calculate tile starting positions
        tile_start_m = bidy * self._tile_size
        tile_start_n = bidx * self._tile_size
        # Calculate global coordinates for this thread
        global_m = tile_start_m + tidy
        global_n = tile_start_n + tidx
        # Get matrix dimensions at runtime
        M = mA.shape[0]
        N = mA.shape[1]
        # Bounds checking and transpose operation
        if global_m < M and global_n < N:
            val = mA[global_m, global_n]
            # Transpose: B[n, m] = A[m, n]
            mB[global_n, global_m] = val

What just happened?

  • Thread and block indices come straight from CUDA (thread_idx, block_idx), versus Triton’s block abstraction
  • No explicit loads or stores: CuTe uses overloaded [] to generate them.
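To make the indexing concrete, here is a pure-Python model of the coordinate math the kernel does; the sizes are illustrative values of my own, not from the kernel above:

```python
# Each (block, thread) pair claims exactly one (global_m, global_n) element.
M, N, tile = 5, 7, 4
grid_x = (N + tile - 1) // tile   # blocks along n
grid_y = (M + tile - 1) // tile   # blocks along m

covered = set()
for bidy in range(grid_y):
    for bidx in range(grid_x):
        for tidy in range(tile):
            for tidx in range(tile):
                global_m = bidy * tile + tidy
                global_n = bidx * tile + tidx
                if global_m < M and global_n < N:   # bounds check, as in the kernel
                    covered.add((global_m, global_n))

# Every element is visited exactly once, including the ragged edges.
assert covered == {(m, n) for m in range(M) for n in range(N)}
```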

Launching isn’t a million miles away from Triton:

@cute.jit   # host side
def launch(self, A: cute.Tensor, B: cute.Tensor):
    M, N = A.shape
    grid = ((N + self.T - 1)//self.T,
            (M + self.T - 1)//self.T, 1)
    self.transpose_kernel(A, B).launch(
        grid=grid,
        block=[self.T, self.T, 1],
    )

Because CuTe‑DSL speaks DLPack, you can hand it a PyTorch tensor directly. If you want to cache the conversion, it looks like this:

A_cute = from_dlpack(A).mark_layout_dynamic()

The mark_layout_dynamic call triggers dynamic-shape support and avoids shape specialization. The one place this got a bit funky in my testing was singular leading dimensions: there you need to be more explicit about the shape to satisfy the compiler.

Layouts and Memory

This kernel isn’t really leveraging the fundamental value of CuTe though, which is composable tensor layouts and memory management. CuTe‑DSL exposes the full memory hierarchy: global, shared, register (and tmem for those with blackwells), and lets you tile, copy, and pipeline data between them. Common primitives:

  • make_layout / make_layout_tv: describe how a tensor is laid out.
  • cute.zipped_divide(tensor, tiler): tile a tensor.
  • cute.copy(copy_atom, src, dst, pred=mask): copy between memory spaces (async for global→shared).
  • cute.arch.sync_threads(): explicit barrier.
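The idea underneath make_layout is a (shape, stride) pair that maps a logical coordinate to a linear offset. Here is a toy pure-Python model of that mapping (my own sketch of the concept, not the CuTe API):

```python
# Toy model of a CuTe-style layout: offset = sum(coord[i] * stride[i]).
def offset(coord, stride):
    return sum(c * s for c, s in zip(coord, stride))

A_layout  = ((4, 3), (3, 1))   # 4x3 matrix, row-major
AT_layout = ((3, 4), (1, 3))   # its transpose: swap shape and stride, no data moved

# A[2, 1] and A.T[1, 2] resolve to the same linear offset.
assert offset((2, 1), A_layout[1]) == offset((1, 2), AT_layout[1]) == 7
```

This is why layout algebra is powerful: a transpose is just a relabeling of strides, with zero data movement.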

HGEMMony2

CuTe ships with some example kernels, so I grabbed one — an HGEMM (half-precision, FP16, batched GEMM) — and compared to an example implementation in Triton.

To express the same thing in PyTorch, we can unleash our inner Jeremy Howard and use einsum notation: torch.einsum("mkl,nkl->mnl", a, b). Take L batches of an M×K matrix and L batches of an N×K matrix, and return L batches of an M×N matrix.
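For checking kernels on tiny inputs, the contraction also has a direct pure-Python reference (plain nested lists indexed [m][k][l]; naming is mine):

```python
# Reference semantics: C[m][n][l] = sum_k A[m][k][l] * B[n][k][l]
def batched_gemm_ref(A, B, M, N, K, L):
    return [[[sum(A[m][k][l] * B[n][k][l] for k in range(K))
              for l in range(L)]
             for n in range(N)]
            for m in range(M)]

# Tiny sanity check: M = N = K = 1, L = 2
A = [[[2.0, 3.0]]]   # A[m][k][l]
B = [[[5.0, 7.0]]]   # B[n][k][l]
C = batched_gemm_ref(A, B, 1, 1, 1, 2)
assert C == [[[10.0, 21.0]]]
```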

Here is the Triton:

@triton.jit
def triton_batched_hgemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K, L,
    stride_am, stride_ak, stride_al,
    stride_bn, stride_bk, stride_bl,
    stride_cm, stride_cn, stride_cl,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton batched half-precision GEMM kernel: C[m,n,l] = sum_k A[m,k,l] * B[n,k,l]"""
    pid = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)  # Batch dimension
    
    # Calculate batch offsets
    batch_offset_a = pid_batch * stride_al
    batch_offset_b = pid_batch * stride_bl  
    batch_offset_c = pid_batch * stride_cl
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    
    # Include batch offsets in pointer calculations
    a_ptrs = a_ptr + batch_offset_a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # transpose for GEMM (load B[K, N] pattern)
    b_ptrs = b_ptr + batch_offset_b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    
    # We accumulate into fp32 for higher accuracy.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, mask in K
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # Convert back to FP16 for output
    c = accumulator.to(tl.float16)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + batch_offset_c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

The core loop here is:

  • Divide K into BLOCK_SIZE_K-sized chunks
  • For each, do a masked load (so we don’t bring in garbage at the edges)
  • Accumulate the tile’s dot product into the fp32 result accumulator
  • Advance the A and B pointers
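The masking step can be sketched in plain Python: substitute 0.0 past the edge of K (the other=0.0 in tl.load), and partial tiles contribute nothing wrong to the accumulator:

```python
# Sketch of the masked K-loop over BLOCK_K-sized chunks.
def tiled_dot(a, b, BLOCK_K):
    K = len(a)
    acc = 0.0
    num_tiles = (K + BLOCK_K - 1) // BLOCK_K   # tl.cdiv(K, BLOCK_K)
    for t in range(num_tiles):
        for i in range(BLOCK_K):
            k = t * BLOCK_K + i
            ak = a[k] if k < K else 0.0        # masked load, other=0.0
            bk = b[k] if k < K else 0.0
            acc += ak * bk
    return acc

a = [1.0, 2.0, 3.0, 4.0, 5.0]
b = [1.0, 1.0, 1.0, 1.0, 1.0]
# K=5 doesn't divide BLOCK_K=2, but the result is still exact.
assert tiled_dot(a, b, BLOCK_K=2) == sum(x * y for x, y in zip(a, b))
```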

The CuTe kernel is, uhhh… a bit more involved. The full kernel is several hundred lines long. You can see the source in the Cutlass repo, which demonstrates some cool features like the ability to pass in an epilogue function for fusion.

For now, let’s focus on the main loop. As before, we are looping over tiles of K3 .

The first thing we do is advance the pipeline bookkeeping. The kernel explicitly pipelines transfers from global memory to shared memory to registers, so we need to set wait groups to ensure all the loading has completed before we advance. We’re not actually doing any loading in this section, just prepping the ground:

for k_tile in cutlass.range_dynamic(k_tile_count, unroll=1):
    for k_block in range(num_k_block):
        if k_block == num_k_block - 1:
            tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read]
            tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read]
            cute.arch.cp_async_wait_group(num_smem_stages - 2)
            cute.arch.sync_threads()

Next we kick off a copy from shared memory (smem) to registers (rmem) using cute.copy for the next A and B sub-tiles:

k_block_next = (k_block + 1) % num_k_block  # static
cute.copy(
    tiled_copy_s2r_A,
    tCsA_p[None, None, k_block_next],
    tCrA_copy_view[None, None, k_block_next],
)
cute.copy(
    tiled_copy_s2r_B,
    tCsB_p[None, None, k_block_next],
    tCrB_copy_view[None, None, k_block_next],
)

Finally, we interleave the transfers of the next A and B tiles from global to shared memory with the actual GEMM operation (the equivalent of tl.dot). I will trust the folks at Nvidia that this is an optimal pattern. The pred= argument is the equivalent of masking in Triton.

if k_block == 0:
    if k_tile + num_smem_stages - 1 < k_tile_count:
        cute.copy(
            tiled_copy_B,
            tBgB[None, None, None, k_tile_index],
            tBsB[None, None, None, smem_pipe_write],
            pred=tBpB,
        )
    k_tile_index = k_tile_index + 1
    cute.arch.cp_async_commit_group()
    smem_pipe_write = smem_pipe_read
    smem_pipe_read = smem_pipe_read + 1

    if smem_pipe_read == num_smem_stages:
        smem_pipe_read = 0

The pipelining is explicit, which is nice for debuggability and optimization, but very manual.
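The index bookkeeping is easy to get wrong, so here is a pure-Python model of the rotation in that snippet: the write slot takes over the slot we just finished reading. The initial values follow the usual prologue convention (read at stage 0, write at the last stage); that initialization is my assumption, not shown in the excerpt:

```python
# Model of the circular smem pipeline indices from the loop above.
num_smem_stages = 3
smem_pipe_read, smem_pipe_write = 0, num_smem_stages - 1  # assumed prologue state
history = []
for k_tile in range(6):
    history.append((smem_pipe_read, smem_pipe_write))
    smem_pipe_write = smem_pipe_read       # reuse the slot we just drained
    smem_pipe_read = smem_pipe_read + 1
    if smem_pipe_read == num_smem_stages:  # wrap around
        smem_pipe_read = 0

# Slots rotate through all stages, and we never read the slot being written.
assert all(r != w and r < num_smem_stages for r, w in history)
```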

Debugging Tips

export CUTE_DSL_LOG_TO_CONSOLE=1
export CUTE_DSL_LOG_LEVEL=10   # up to 100
export CUTE_DSL_PRINT_IR=1     # dump MLIR
  • cute.printf() gives you a GPU‑side printf.
  • Kernels are aggressively cached; rm ~/.cache/cutedsl if things look stale.
  • Multiple @cute.jit host functions in the same Python scope can confuse MLIR (mainly for launching kernels).
  • The control‑flow rules are strict: no return inside a kernel; initialize everything.

If you’re exploring GPU kernels for the first time, I strongly recommend starting with Triton. When you need to really get into the weeds, or want to reuse CUTLASS building blocks, it’s great to have CuTe‑DSL as an option in Python (provided you’re comfortable spelunking in GPU internals).

  1. I spent a lot of time holding it wrong. Arguably, still holding it wrong. ↩︎
  2. No one knows what it means, but it’s provocative ↩︎
  3. Note the explicit unroll tag. When you really want #pragma but can’t. ↩︎
