
Group GEMM

Group GEMM Operation

For Group GEMM we are given a set, or "group", of GEMM operations that may have different shapes. The kernel must compute each GEMM and store each result in its own output buffer. Shown below are the benchmarks we will use to measure our kernel:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]}
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]}
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]}

'g' gives us the number of individual GEMM problems to be computed. The lists 'm', 'n', and 'k', each of length 'g', give the problem shape of each GEMM: for each i in [0, g-1], GEMM i has M = m[i], N = n[i], and K = k[i]. As input to the kernel we are also given lists of the addresses of the A, B, and C buffers for each GEMM. These lists are stored in host memory.
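To pin down the semantics before any optimization, here is a minimal host-side C++ reference for what the kernel computes. It assumes plain FP32 row-major buffers in place of the NVFP4 data and ignores the scale factors entirely. The names GemmProblem and reference_group_gemm are hypothetical. The only thing taken from the real problem is the (m, k) @ (n, k).T -> (m, n) layout that the baseline below uses.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One GEMM in the group: C (m x n) = A (m x k) @ B (n x k)^T.
// Plain FP32 row-major buffers; NVFP4 packing and scale factors are left out.
struct GemmProblem {
    int m, n, k;
    const float* a;  // m x k
    const float* b;  // n x k
    float* c;        // m x n
};

// Reference semantics: every problem is independent and writes only its own C.
void reference_group_gemm(const std::vector<GemmProblem>& group) {
    for (const GemmProblem& p : group) {
        assert(p.m > 0 && p.n > 0 && p.k > 0);
        for (int i = 0; i < p.m; ++i) {
            for (int j = 0; j < p.n; ++j) {
                float acc = 0.0f;
                for (int kk = 0; kk < p.k; ++kk) {
                    acc += p.a[static_cast<std::size_t>(i) * p.k + kk] *
                           p.b[static_cast<std::size_t>(j) * p.k + kk];
                }
                p.c[static_cast<std::size_t>(i) * p.n + j] = acc;
            }
        }
    }
}
```

Because each problem reads and writes only its own buffers, nothing forces the GEMMs to run one after another. That independence is what the later kernels exploit.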

Kernel Iteration and Code Walkthroughs

Now let's iterate through the different kernels I developed and examine the failures and successes.

- V0 (submission_v1.py) -

As with the previous 3 kernels the baseline solution is to iterate through all of the GEMM problems and compute them one at a time with a separate kernel launch for each, as shown in the code below:

    result_tensors = []
    for i, (
        (a_ref, b_ref, c_ref),
        (sfa_ref, sfb_ref),
        (m, n, k, l),
    ) in enumerate(
        zip(
            abc_tensors,
            sfasfb_tensors,
            problem_sizes,
        )
    ):
        for l_idx in range(l):
            # Convert the scale factor tensor to blocked format
            scale_a = to_blocked(sfa_ref[:, :, l_idx])
            scale_b = to_blocked(sfb_ref[:, :, l_idx])
            # (m, k) @ (n, k).T -> (m, n)
            res = torch._scaled_mm(
                a_ref[:, :, l_idx].view(torch.float4_e2m1fn_x2),
                b_ref[:, :, l_idx].transpose(0, 1).view(torch.float4_e2m1fn_x2),
                scale_a.cuda(),
                scale_b.cuda(),
                bias=None,
                out_dtype=torch.float16,
            )
            c_ref[:, :, l_idx] = res
        result_tensors.append((c_ref))
    return result_tensors

Again we hit the same bottlenecks we saw in previous kernels that used PyTorch's torch._scaled_mm. Here we serialize across the GEMMs in the group: each GEMM gets its own kernel launch, one after another, so we never compute multiple GEMMs at the same time even though the hardware could very well do that. The penalty from this serialization and the repeated launches is far worse than before, because we call torch._scaled_mm once per GEMM rather than just once or twice as in previous kernels. This multiplies the overhead by 'g'.

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 47318.6us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 24315.7us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 5488.6us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} -> 2805.5us

- V1 (submission_v2.py) -

As in the dual GEMM kernel, we switch to our PTX GEMM implementation and build on it. The first problem we have to solve is how to handle many different GEMMs within a single kernel launch. As we will see later on, there are multiple approaches to passing multiple GEMMs to a kernel and mapping CTAs to the correct result tiles. In this first iteration we take the following, arguably simplest, approach:

Create lists of pointers and sizes. In this kernel we pack each list into a PyTorch tensor to pass to the inline C++ kernel wrapper (an inefficiency we will address in later kernels):

    A_ptrs = torch.tensor([a.data_ptr() for (a,b,c) in abc_tensors], dtype=torch.uint64, device='cpu')
    B_ptrs = torch.tensor([b.data_ptr() for (a,b,c) in abc_tensors], dtype=torch.uint64, device='cpu')
    C_ptrs = torch.tensor([c.data_ptr() for (a,b,c) in abc_tensors], dtype=torch.uint64, device='cuda')
    SFA_ptrs = torch.tensor([sfa.data_ptr() for (sfa,sfb) in sfasfb_tensors_reordered], dtype=torch.uint64, device='cuda')
    SFB_ptrs = torch.tensor([sfb.data_ptr() for (sfa,sfb) in sfasfb_tensors_reordered], dtype=torch.uint64, device='cuda')
    M_sizes = torch.tensor([m for (m,n,k,_) in problem_sizes], dtype=torch.int32, device='cpu')
    # OPT: uniform in benchmark, pass as constant?
    N_sizes = torch.tensor([n for (m,n,k,_) in problem_sizes], dtype=torch.int32, device='cpu')
    K_sizes = torch.tensor([k for (m,n,k,_) in problem_sizes], dtype=torch.int32, device='cpu')

We then use these metadata arrays to create lists of tensor maps and result tile counts to pass to the kernel.

    // Constants
    constexpr int M_TILE_SIZE = 128;
    constexpr int K_MMA_SIZE = 64;
    constexpr int NUM_WARPS = 6;

    // Configurables
    constexpr int N_TILE_SIZE = 64;
    constexpr int K_TILE_SIZE = 256;
    constexpr bool SWIZZLE = true;
    constexpr int PIPE_STAGES = 6;

    constexpr CUtensorMapSwizzle SWIZZLE_TYPE = SWIZZLE ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;

    CUtensorMap* A_tmaps = (CUtensorMap*) malloc(G * sizeof(CUtensorMap));
    CUtensorMap* B_tmaps = (CUtensorMap*) malloc(G * sizeof(CUtensorMap));
    int* block_totals = (int*) malloc(G * sizeof(int));

    // Allocate on device
    CUtensorMap* d_A_tmaps;
    CUtensorMap* d_B_tmaps;
    int* d_block_totals;
    cudaMalloc(&d_A_tmaps, G * sizeof(CUtensorMap));
    cudaMalloc(&d_B_tmaps, G * sizeof(CUtensorMap));
    cudaMalloc(&d_block_totals, G * sizeof(int));

    int* d_M_data;
    int* d_N_data;
    int* d_K_data;
    cudaMalloc(&d_M_data, G * sizeof(int));
    cudaMalloc(&d_N_data, G * sizeof(int));
    cudaMalloc(&d_K_data, G * sizeof(int));

    uint64_t* A_ptrs_data = A_ptrs.data_ptr<uint64_t>();
    uint64_t* B_ptrs_data = B_ptrs.data_ptr<uint64_t>();
    int* M_data = M_sizes.data_ptr<int>();
    int* N_data = N_sizes.data_ptr<int>();
    int* K_data = K_sizes.data_ptr<int>();

    for (int i = 0; i < G; ++i) {
        block_totals[i] = CEIL_DIV(M_data[i], M_TILE_SIZE) * CEIL_DIV(N_data[i], N_TILE_SIZE) + (i > 0 ? block_totals[i-1] : 0);
        tma_3d_map_ab<M_TILE_SIZE, K_TILE_SIZE, SWIZZLE_TYPE>::init(cuTensorMapEncodeTiled, &A_tmaps[i], reinterpret_cast<void*>(A_ptrs_data[i]), M_data[i], K_data[i]);
        tma_3d_map_ab<N_TILE_SIZE, K_TILE_SIZE, SWIZZLE_TYPE>::init(cuTensorMapEncodeTiled, &B_tmaps[i], reinterpret_cast<void*>(B_ptrs_data[i]), N_data[i], K_data[i]);
    }

    // Copy to device
    cudaMemcpy(d_A_tmaps, A_tmaps, G * sizeof(CUtensorMap), cudaMemcpyHostToDevice);
    cudaMemcpy(d_B_tmaps, B_tmaps, G * sizeof(CUtensorMap), cudaMemcpyHostToDevice);
    cudaMemcpy(d_block_totals, block_totals, G * sizeof(int), cudaMemcpyHostToDevice);

    cudaMemcpy(d_M_data, M_data, G * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_N_data, N_data, G * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_K_data, K_data, G * sizeof(int), cudaMemcpyHostToDevice);

    auto kernel_inst = nvfp4_group_gemm_kernel<M_TILE_SIZE, N_TILE_SIZE, M_TILE_SIZE, N_TILE_SIZE, K_TILE_SIZE, M_TILE_SIZE, N_TILE_SIZE, K_MMA_SIZE, SWIZZLE, PIPE_STAGES, NUM_WARPS>;

    cudaFuncSetAttribute(
        kernel_inst,
        cudaFuncAttributePreferredSharedMemoryCarveout,
        cudaSharedmemCarveoutMaxShared  // Maximum shared memory
    );

    constexpr int threads = WARP_SIZE * NUM_WARPS;
    dim3 grid_dim(block_totals[G-1]);
    kernel_inst<<<grid_dim, threads>>>(reinterpret_cast<__half**>(C_ptrs.data_ptr()), d_A_tmaps, d_B_tmaps, reinterpret_cast<const uint8_t**>(SFA_ptrs.data_ptr()), reinterpret_cast<const uint8_t**>(SFB_ptrs.data_ptr()), d_M_data, d_N_data, d_K_data, d_block_totals, G);

Since we use TMA to transfer data from GMEM into the kernel, we need tensor maps for all 'g' A and B matrices. As discussed in [TMA Details], tensor maps are opaque structures, so we must rely on the CUDA driver API (cuTensorMapEncodeTiled) to create them on the host. In other words, we can't do anything clever like creating the tensor maps on device and parallelizing that work across CTAs; however, we will see in later kernels how to achieve a similar effect using certain Blackwell hardware features. For now we explicitly create and initialize all A and B tensor maps on the host before launching the kernel. We malloc the host arrays A_tmaps and B_tmaps, and cudaMalloc (the device-memory equivalent of malloc) the arrays d_A_tmaps and d_B_tmaps that the tensor maps are copied into.

In addition to creating the A and B tensor maps, we also need to track how many result tiles each of the 'g' GEMMs contributes. We do this with a running (inclusive prefix) sum stored in the block_totals array: block_totals[i] holds the total number of M_TILE_SIZE x N_TILE_SIZE tiles in GEMM[0] through GEMM[i].
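Here is a small host-side sketch of the block_totals computation from the wrapper loop above, using M_TILE_SIZE = 128 and N_TILE_SIZE = 64 from the V1 constants. The functions compute_block_totals and ceil_div are hypothetical stand-ins for the loop body and the CEIL_DIV macro.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the CEIL_DIV macro.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// block_totals[i] = number of result tiles in GEMM[0] .. GEMM[i] (inclusive).
std::vector<int> compute_block_totals(const std::vector<int>& m_sizes,
                                      const std::vector<int>& n_sizes,
                                      int m_tile, int n_tile) {
    assert(m_sizes.size() == n_sizes.size());
    std::vector<int> totals(m_sizes.size());
    int running = 0;
    for (std::size_t i = 0; i < m_sizes.size(); ++i) {
        running += ceil_div(m_sizes[i], m_tile) * ceil_div(n_sizes[i], n_tile);
        totals[i] = running;  // includes GEMM[i] itself
    }
    return totals;
}
```

For the first benchmark (M = [80, 176, 128, 72, 64, 248, 96, 160], N = 4096), this gives running totals of 64, 192, 256, 320, 384, 512, 576, 704. V1 therefore launches a 704-CTA grid.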

Any data in host memory that the kernel needs and that isn't passed as a kernel parameter must be explicitly copied from host memory to device memory. There are a number of ways to do this; see [Host/Device Data Transfer Details] for a deeper dive into this topic. Here we use cudaMemcpy, which copies the data across the host-device interconnect (e.g. PCIe) into device memory (HBM). We transfer the A and B tensor maps, block_totals, and the M, N, and K sizes.

Each CTA handles a single result tile, so we launch as many CTAs as there are result tiles (block_totals[G-1] holds that total).

Once inside the kernel, the last remaining step that differentiates this kernel from a plain GEMM is having each CTA figure out which GEMM it is computing and where its result tile sits within that GEMM.

    int group = 0;
    while (blockIdx.x >= block_totals[group]) group++;
    const int group_idx = blockIdx.x - (group > 0 ? block_totals[group - 1] : 0);
    const int n_tiles = CEIL_DIV(N_sizes[group], TD_CTA_N);
    const int row_idx = group_idx / n_tiles;
    const int col_idx = group_idx % n_tiles;
    const int m_off = row_idx * TD_CTA_M;
    const int n_off = col_idx * TD_CTA_N;

Each CTA's index (blockIdx.x) gets mapped onto the flattened array of result tiles, determining which GEMM and position within that GEMM this CTA is responsible for computing.

CTA → (group, tile) mapping diagram:

  Groups (G = 3 example, different M and N per group):

              ┌───┬───┐
              │ 0 │ 1 │
    group 0   ├───┼───┤    (2 × 2 result tiles)
              │ 2 │ 3 │
              └───┴───┘
              ┌───┬───┬───┐
              │ 4 │ 5 │ 6 │
    group 1   ├───┼───┼───┤  (2 × 3 result tiles)
              │ 7 │ 8 │ 9 │
              └───┴───┴───┘
              ┌────┬────┐
    group 2   │ 10 │ 11 │  (1 × 2 result tiles)
              └────┴────┘

  Flattened CTA index space:
    flat_idx :  0  1  2  3   4  5  6  7  8  9    10 11
                └──g0───┘    └──────g1─────────┘  └─g2─┘

  block_totals (running sum):
    block_totals[0] = 4   (g0 contributes 4 tiles)
    block_totals[1] = 10  (g0+g1 contribute 10)
    block_totals[2] = 12  (g0+g1+g2 contribute 12)

  Per-CTA resolution from blockIdx.x:
    1) linear/binary search block_totals to find group
    2) local_idx = blockIdx.x − block_totals[group − 1]   (0 when group = 0)
    3) row = local_idx / n_tiles[group]
       col = local_idx % n_tiles[group]
    4) m_off = row * M_TILE;  n_off = col * N_TILE

  Total CTA grid launched = block_totals[G − 1].
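The per-CTA resolution steps can be sketched on the host in C++. This version uses the binary-search option from step 1 (std::upper_bound over the inclusive prefix sums) rather than the linear scan in the kernel snippet. TileCoord and resolve_tile are hypothetical names, and the 128 x 64 tile size matches the V1 constants.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct TileCoord {
    int group;  // which GEMM in the group
    int m_off;  // row offset of the tile inside that GEMM
    int n_off;  // column offset of the tile inside that GEMM
};

// std::upper_bound returns the first prefix sum strictly greater than tile_idx,
// which is exactly the group whose tile range contains tile_idx.
TileCoord resolve_tile(int tile_idx, const std::vector<int>& block_totals,
                       const std::vector<int>& n_tiles, int m_tile, int n_tile) {
    assert(!block_totals.empty() && tile_idx >= 0 && tile_idx < block_totals.back());
    auto it = std::upper_bound(block_totals.begin(), block_totals.end(), tile_idx);
    int group = static_cast<int>(it - block_totals.begin());
    int local = tile_idx - (group > 0 ? block_totals[group - 1] : 0);
    int row = local / n_tiles[group];
    int col = local % n_tiles[group];
    return {group, row * m_tile, col * n_tile};
}
```

With the diagram's example (block_totals = {4, 10, 12}, n_tiles = {2, 3, 2}), flat index 7 resolves to group 1, row 1, column 0. For G = 8, a linear scan is only a handful of comparisons, so the binary search mostly pays off for larger groups.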

The rest of the kernel is the same as our GEMM kernel.

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 217.3us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 208.6us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 155.0us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} -> 149.7us

- V2 (submission_v3.py) -

Before we start analyzing bottlenecks, we can lay the foundation for an algorithmic improvement: moving from a non-persistent to a persistent kernel. In a non-persistent kernel, each CTA is assigned a fixed amount of work, and the kernel's total work is divided among however many CTAs that requires. This means many waves of CTAs get launched on the GPU, and there is no way to explicitly overlap work from one wave with the next. The name "non-persistent" refers to the fact that once a CTA has finished its predetermined amount of work, it exits and is replaced by a different CTA that performs the next set of work. In many cases a non-persistent kernel is the correct approach; however, group GEMM can benefit substantially from a persistent kernel. A persistent kernel, as the name implies, keeps its CTAs resident on the hardware across waves of work. Rather than completing one fixed portion of the output, a CTA can be responsible for multiple portions of the kernel's result.

For example, let's say we are computing a group GEMM that has 148*2 = 296 total output tiles. Our Blackwell hardware has 148 SMs, and for the purposes of this example let's assume our kernel demands enough resources that we can only launch 1 CTA per SM. In a non-persistent kernel we would launch 296 CTAs, where each CTA computes 1 output tile. In a persistent kernel we would launch only 148 CTAs, and each CTA would compute 2 output tiles. This particular kind of persistent kernel is called a "static" persistent kernel because we statically define which output tiles each CTA will compute. There is a way to do persistent kernels "dynamically", which we will discuss in later iterations of this kernel.
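The static schedule is simple enough to write down directly. The sketch below uses a hypothetical helper, not the kernel code. It lists the tiles one CTA visits when tiles are handed out round-robin with a stride equal to the grid size, which is the same stride the work-tile loop uses.

```cpp
#include <cassert>
#include <vector>

// Tiles visited by one CTA in a static persistent schedule:
// cta, cta + grid, cta + 2*grid, ... while below total_tiles.
std::vector<int> static_tiles_for_cta(int cta, int grid, int total_tiles) {
    assert(grid > 0 && cta >= 0 && cta < grid);
    std::vector<int> tiles;
    for (int t = cta; t < total_tiles; t += grid) tiles.push_back(t);
    return tiles;
}
```

In the 296-tile example, CTA 0 gets tiles {0, 148}. With the 704 tiles of the first benchmark (assuming V1's 128 x 64 tiles and one CTA per SM), CTAs 0-111 get five tiles and CTAs 112-147 get four. On the last pass over the tiles, 36 CTAs therefore sit idle.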

Persistent kernels can offer a number of benefits, including data residency (relevant data can stay resident in L1/SMEM across work tiles), reduced block deposit overhead, and potential cross-tile pipeline overlap. In this kernel we mainly benefit from the latter two. Block depositing refers to the process of launching a CTA (threadblock) onto a physical hardware unit (SM) once enough resources are available. Cross-tile pipeline overlap, as the following diagram shows, means we can overlap setup for the next work tile with the epilogue of the previous tile:

Non-persistent vs persistent overlap diagram:

  Non-persistent (one CTA per tile, CTA exits after one tile):

    SM-resident:
      CTA_A: [deposit][setup][load][mma][store][exit]
                                                     [deposit][setup][load][mma][store][exit]   ← CTA_B
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                                     full deposit + setup paid again per tile

  Static persistent (CTA stays resident, loops over multiple tiles):

    SM-resident:
      CTA_A: [deposit][setup][load tile 0][mma 0][store 0][load 1][mma 1][store 1][load 2][mma 2][store 2]
                            └──── tile 0 ────┘└──── tile 1 ────┘└──── tile 2 ────┘
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                              tile-N+1 setup/load can be moved
                                              into the previous tile's epilogue
                                              (cross-tile pipelining)

  Persistent kernels gain (a) one block-deposit per kernel instead of
  one per tile, (b) data residency for cluster/SMEM/L1 across tiles,
  and (c) ability to overlap tile N+1's prologue with tile N's
  epilogue once the right TMEM/SMEM double buffering is in place
  (see V13, V16, V22).

In this kernel we only implement the basic skeleton for a static persistent kernel. The major change is to wrap the TMA, MMA, and Epilogue warp sections in an outer work tile loop as shown:

    // Work-tile loop: static persistent schedule with a stride of gridDim.x
    for (int tile_idx = blockIdx.x; tile_idx < block_totals[G-1]; tile_idx += gridDim.x) {
        // Grab correct data pointers, offsets, and TMA descriptors for this work tile
        int group = 0;
        while (tile_idx >= block_totals[group]) group++;
        int group_idx = tile_idx - (group > 0 ? block_totals[group - 1] : 0);
        int n_tiles = CEIL_DIV(N_sizes[group], TD_CTA_N);
        int row_idx = group_idx / n_tiles;
        int col_idx = group_idx % n_tiles;
        int m_off = row_idx * TD_CTA_M;
        int n_off = col_idx * TD_CTA_N;

        const CUtensorMap* tmap_a = &A_tmaps[group];
        const CUtensorMap* tmap_b = &B_tmaps[group];
        const uint8_t* sfa_gmem_base = sfa_ptrs[group];
        const uint8_t* sfb_gmem_base = sfb_ptrs[group];

        // ... TMA / MMA / epilogue work for this tile, as in the GEMM kernel ...
    }

Each CTA loops over its share of the work tiles (also called output tiles) with a stride of gridDim.x. For each tile it first computes the tile's location (as the previous kernel did during setup) and then proceeds with the normal GEMM algorithm. Notice that this is a new stage of work that can be included in the software pipeline and overlapped between tiles.

The static persistent kernel does come with a drawback: we tie specific sets of work tiles to a single CTA. If a CTA happens to complete its work more slowly while other CTAs/SMs are free, we lose some hardware utilization and parallelism, because the remaining work has been programmed to run sequentially on the slower CTA. In this barebones static persistent implementation we also don't overlap the "compute next work tile" phase with other phases of the software pipeline. As a result, the kernel is slightly slower than the non-persistent V1 kernel on three of the four benchmarks.

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 215.6us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 214.1us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 167.8us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} -> 162.9us

- V3 (submission_persistent_kernel.py) -

At this point in kernel development I decided to insert CUDA events and timers to get a full timing breakdown of my kernel. In other words, I measured how much time is spent on setup versus inside the kernel itself. You can view the debug and timing harnesses in the code itself. I discovered what should be fairly intuitive for small problem shapes: a large portion of the total runtime goes to setup and teardown rather than to the GPU kernel. Setup includes allocating the necessary host and device buffers (via malloc and cudaMalloc). Teardown involves releasing those resources via free and cudaFree, as well as device synchronization.

When designing a kernel it's also important to consider factors external to the kernel itself. In this case we note that the kernel (including setup and teardown) will be called many times in a row for benchmarking and correctness tests. It isn't necessary for us to perform the memory allocation / deallocation steps each time, since that memory can simply be re-used until the overall program exits.

KEY CHANGES from submission_v3.py (V2):
1. PersistentBuffers struct holds all device/host buffers as static globals
2. Buffers are only allocated when capacity is insufficient (first call or if G grows)
3. Uses pinned host memory (cudaMallocHost) for faster H2D transfers
4. Removes cudaFree() from the hot path
5. Driver function pointer is cached after first lookup
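Items 1 and 2 boil down to a grow-only allocation policy. Below is a hypothetical, host-only sketch of that policy. PersistentBuffer is not the author's PersistentBuffers struct, and std::malloc/std::free stand in for cudaMalloc/cudaMallocHost and cudaFree so that the capacity logic can run without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Grow-only scratch buffer meant to live as a static global across calls.
// std::malloc/std::free are stand-ins for the CUDA allocators.
struct PersistentBuffer {
    void* ptr = nullptr;
    std::size_t capacity = 0;  // bytes currently allocated
    int allocations = 0;       // how often the allocator was actually hit

    PersistentBuffer() = default;
    PersistentBuffer(const PersistentBuffer&) = delete;             // owns ptr
    PersistentBuffer& operator=(const PersistentBuffer&) = delete;
    ~PersistentBuffer() { std::free(ptr); }

    // Reallocate only when the request exceeds the current capacity.
    void* ensure(std::size_t bytes) {
        if (bytes > capacity) {
            std::free(ptr);
            ptr = std::malloc(bytes);
            assert(ptr != nullptr);
            capacity = bytes;
            ++allocations;
        }
        return ptr;
    }
};
```

Kept as a static global, a buffer like this hits the allocator only on the first call and whenever G grows. Every other call reuses the same pointer.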

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 162.1us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 158.3us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 118.6us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} -> 116.3us

- V4 (submission_v4.py) -

This kernel is a separate branch off of submission_v3.py (V2). Instead of eliminating allocation / deallocation overhead, we focus on reducing the amount of problem shape data we need to track and transfer. Notice that in the benchmark problem shapes, the values of N and K are constant across groups, so we don't need to keep track of G different N and K values. Working with a single value for N and K reduces the amount of data we pass through our kernel wrapper (and the number of temporary torch tensors we need to create).

The drawback here is that we lose some generality. However, high-performance GPU libraries are usually structured this way anyway. For example, a call to CUTLASS GEMM goes through many compile-time and runtime decision trees before landing on a hyper-specialized routine built to perform very well for that particular problem configuration. In the context of a general group GEMM library, this "fixed N/K" kernel would be called whenever N and K are invariant across a group GEMM call; otherwise a different kernel supporting variable N and/or K would be used.
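The runtime decision described above could look like the following sketch. ProblemShape, GroupGemmKernel, and select_kernel are hypothetical names. The point is only that the fixed N/K path is chosen when every GEMM in the group shares N and K, with a general kernel as the fallback.

```cpp
#include <cassert>
#include <vector>

struct ProblemShape { int m, n, k; };

enum class GroupGemmKernel { FixedNK, VariableNK };

// True when every GEMM in the group shares the same N and K.
bool has_uniform_nk(const std::vector<ProblemShape>& shapes) {
    for (const ProblemShape& s : shapes)
        if (s.n != shapes.front().n || s.k != shapes.front().k) return false;
    return true;
}

// Take the specialized fixed-N/K path only when it is valid for this call.
GroupGemmKernel select_kernel(const std::vector<ProblemShape>& shapes) {
    assert(!shapes.empty());
    return has_uniform_nk(shapes) ? GroupGemmKernel::FixedNK
                                  : GroupGemmKernel::VariableNK;
}
```

All four benchmark configurations take the FixedNK path. A group whose N values differ would fall back to the general kernel.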

We see significant improvement in this kernel over V2 (though not as much as V3 saw over V2).

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 180.4us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 170.6us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 129.2us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} -> 123.5us

- V5 (submission_v4_1.py) -

One additional potential bottleneck in the setup code is the creation of the temporary PyTorch tensors shown below:

    A_ptrs = torch.tensor([a.data_ptr() for (a,b,c) in abc_tensors], dtype=torch.uint64, device='cpu')
    B_ptrs = torch.tensor([b.data_ptr() for (a,b,c) in abc_tensors], dtype=torch.uint64, device='cpu')
    C_ptrs = torch.tensor([c.data_ptr() for (a,b,c) in abc_tensors], dtype=torch.uint64, device='cuda')
    SFA_ptrs = torch.tensor([sfa.data_ptr() for (sfa,sfb) in sfasfb_tensors_reordered], dtype=torch.uint64, device='cuda')
    SFB_ptrs = torch.tensor([sfb.data_ptr() for (sfa,sfb) in sfasfb_tensors_reordered], dtype=torch.uint64, device='cuda')
    M_sizes = torch.tensor([m for (m,n,k,_) in problem_sizes], dtype=torch.int32, device='cpu')

In general, creating objects can be very costly for latency-sensitive code that runs on the order of microseconds. This is because constructing most complex objects (including PyTorch tensors) involves heap allocations and some setup code. For that reason, it's best to avoid creating objects unless you need them for some functionality or they enable a larger performance gain elsewhere.

In submission_v4_1.py we avoid creating PyTorch tensors by passing in a C++ standard library vector of tuples (the standard library is very well optimized):

    const std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>& abc_tensors

This lets us access the pointers to the A, B, and C matrices without first wrapping them in PyTorch tensor objects. However, as the benchmark results show, performance doesn't really change. This held when I tested it with all of the later kernels we'll dive into, so in practice the two approaches cost about the same. The first approach creates PyTorch tensors and passes them to the kernel wrapper. The second converts Python lists into standard library vectors and passes those. This surprised me, and I'm still not 100% certain why it happens, but here is my hypothesis:

The first approach builds two small CPU tensors of raw uint64 pointers:

    A_ptrs = torch.tensor([a.data_ptr() for (a,b,c) in abc_tensors], dtype=torch.uint64, device='cpu')
    B_ptrs = torch.tensor([b.data_ptr() for (a,b,c) in abc_tensors], dtype=torch.uint64, device='cpu')

For G = 8 these are 8 × 8 = 64 bytes each. The C++ extension receives them as plain torch::Tensor objects and calls .data_ptr<uint64_t>(), which is essentially just a pointer dereference.

In the second approach, the raw Python list of tuples is passed directly to a C++ function with the signature:

    nvfp4_group_gemm(const std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>& abc_tensors, ...)

Crossing the Python/C++ boundary this way forces pybind11 to:
1. Iterate the Python list
2. Unpack each Python tuple into 3 elements
3. Type-check and convert each Python object to a torch::Tensor (refcount bump, metadata copy, type validation)
4. Heap-allocate a std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>

Thus, in either case we end up creating temporary PyTorch objects, either explicitly or implicitly through pybind11 (the library that interfaces Python with our C++ code).

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 180.0us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 170.1us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 130.6us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} -> 124.2us

- V6 (submission_v4_2.py) -

In light of the analysis done in the previous kernel, how we pass the data from Python to C++ doesn't really matter. Thus, for the sake of cleanliness we stick to using the standard library and pybind11 to pass the pointers to the C++ kernel wrapper. At this point we've consolidated to a single kernel launch, explored persistent kernels, and eliminated buffer allocation/de-allocation overheads.

Now we analyze the next bottleneck. The last time-consuming item remaining in the setup code is the creation of the tensor maps. Each GEMM requires 2 or 3 tensor maps (A and B, plus C if stores also use TMA), so for G = 8 that's 16-24 tensor maps that need to be allocated on the host, initialized, and copied to device memory before the kernel launches. At this point in kernel development I didn't know how long tensor map initialization took relative to the memory allocation itself. We will revisit that later, but for now we explore a hardware feature that keeps the tensor map setup overhead from scaling linearly with the number of groups.

This feature is called tensormap.replace.

    tensormap.replace.mode.field1{.ss}.b1024.type  [addr], new_val;
    tensormap.replace.mode.field2{.ss}.b1024.type  [addr], ord, new_val;
    tensormap.replace.mode.field3{.ss}.b1024.type  [addr], new_val;

    .mode    = { .tile }
    .field1  = { .global_address, .rank }
    .field2  = { .box_dim, .global_dim, .global_stride, .element_stride  }
    .field3  = { .elemtype,  .interleave_layout, .swizzle_mode, .swizzle_atomicity, .fill_mode }
    .ss      = { .global, .shared::cta }
    .type    = { .b32, .b64 }

PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-tensormap-replace

The tensormap.replace instruction replaces the field, specified by .field qualifier, of the tensor-map object at the location specified by the address operand addr with a new value. The new value is specified by the argument new_val.

This instruction allows individual CTAs to modify a tensor map so that it targets a specific segment of data in an arbitrary matrix. Since tensor maps are opaque structures, and many of their attributes stay constant anyway, we still create template tensor maps on the host side. These host-side templates are passed to the kernel as grid constant parameters, and each CTA modifies its own copy in parallel so that it loads the appropriate data.

With this approach we no longer allocate and initialize a separate set of tensor maps per group on the host, and tensor map creation is parallelized across CTAs. This keeps the kernel setup time from scaling linearly with the number of GEMM groups.

The main drawback of using tensormap.replace this way is device memory usage. Each CTA needs to manage and update its own tensor maps, and a tensor map used for a TMA transfer must reside in global memory (or be a kernel parameter or constant). It's not possible to, for example, load and modify a tensor map in SMEM and then use that SMEM tensor map for a TMA transfer. Thus we need to allocate room for 2 or 3 tensor maps for each CTA we launch. In non-persistent kernels this can be very expensive if you launch a large number of CTAs; with static persistent kernels, however, the allocation overhead is greatly reduced (an unforeseen advantage of static persistent kernels).
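To put rough numbers on the scratch space, the sketch below assumes that sizeof(CUtensorMap) is 128 bytes. The opaque driver struct is sixteen 64-bit words, which matches the 128-byte copies in the fence code below. The sketch compares one CTA per SM on 148 SMs against one CTA per output tile, using the 704 tiles that V1's 128 x 64 tiling produces for the first benchmark. tmap_scratch_bytes is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>

// Assumed size of the opaque CUtensorMap (16 x 64-bit words); hard-coded so
// the sketch compiles without cuda.h.
constexpr std::size_t kTensorMapBytes = 128;

// GMEM scratch for per-CTA tensor maps: 2 per CTA for A/B, 3 if C also uses TMA.
constexpr std::size_t tmap_scratch_bytes(std::size_t num_ctas, std::size_t maps_per_cta) {
    return num_ctas * maps_per_cta * kTensorMapBytes;
}

// Static persistent: one CTA per SM on a 148-SM part.
static_assert(tmap_scratch_bytes(148, 2) == 37888, "persistent, A/B maps");
// Non-persistent: one CTA per output tile (704 tiles, first benchmark).
static_assert(tmap_scratch_bytes(704, 2) == 180224, "non-persistent, A/B maps");
```

The persistent footprint scales with the number of SMs rather than with the number of output tiles, so it stays fixed as the problem grows.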

Let's walk through the key pieces of code that implement the tensor map replace scheme.

    tma_3d_map_ab<M_TILE_SIZE, K_TILE_SIZE, SWIZZLE_TYPE>::init(cuTensorMapEncodeTiled_fn, &tmap_a_temp, nullptr, M_TILE_SIZE, K);
    tma_3d_map_ab<N_TILE_SIZE, K_TILE_SIZE, SWIZZLE_TYPE>::init(cuTensorMapEncodeTiled_fn, &tmap_b_temp, nullptr, N, K);

First we create template tensor maps for A and B (if we used TMA for stores, we would need an additional template for the C matrix). Notice that we leave the address as null and pass M_TILE_SIZE as a placeholder for A's M dimension, because both will need to change depending on which group a CTA is computing a result tile for.

    cudaMalloc(&d_tmaps, 2 * NUM_CTAS * sizeof(CUtensorMap));

As mentioned we need space in GMEM for each CTA to store the modified tensor maps specific to that CTA.

    // GMEM cache for CTA specific tensor maps
    CUtensorMap* g_A_tmap = d_tmaps + 2*blockIdx.x;
    CUtensorMap* g_B_tmap = g_A_tmap + 1;

    __shared__ CUtensorMap local_A_tmap; // ISSUE: Do these need to be aligned?
    __shared__ CUtensorMap local_B_tmap;

    if (warp_id == 0 && elect_one_sync()) {
        local_A_tmap = tmap_a_temp;
        local_B_tmap = tmap_b_temp;
    }

Inside the kernel itself, we first compute this CTA's pointers into the modified tensor map storage in GMEM (g_A_tmap and g_B_tmap). Next we allocate space in SMEM to hold the A and B tensor map templates. We use these SMEM templates to build the modified tensor maps for each work tile, so the templates only need to be loaded from GMEM once. That load is done by the assignments local_A_tmap = tmap_a_temp; and local_B_tmap = tmap_b_temp;.

    bool update_group = (tile_idx >= next_group_start);

update_group is simply a bool that indicates whether the current work tile resides in a new GEMM group. If it does that means we need to update our tensor maps to properly reflect the addresses and shapes of the matrices.
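To see how often the update path actually runs, here is a host-side C++ sketch of the update_group check for one CTA in the static schedule. It assumes something not shown in the source: that next_group_start holds block_totals[group] for the current group and starts at 0, so the first tile always triggers an update. count_tmap_updates is a hypothetical helper.

```cpp
#include <cassert>
#include <vector>

// Counts tensor map rewrites for one CTA under a static round-robin schedule.
// Assumption: next_group_start == block_totals[current group], initially 0.
int count_tmap_updates(int cta, int grid, const std::vector<int>& block_totals) {
    assert(!block_totals.empty() && grid > 0);
    int group = 0;
    int next_group_start = 0;
    int updates = 0;
    for (int tile_idx = cta; tile_idx < block_totals.back(); tile_idx += grid) {
        bool update_group = (tile_idx >= next_group_start);
        if (update_group) {
            while (tile_idx >= block_totals[group]) ++group;  // advance to owning group
            next_group_start = block_totals[group];
            ++updates;  // the kernel would run tensormap.replace + fences here
        }
    }
    return updates;
}
```

With the first benchmark's 704 tiles (again assuming 128 x 64 tiles) and a stride of 148, CTA 0 lands in a new group on every one of its five tiles, so for that CTA the update runs once per tile.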

    if (update_group) {
        sfa_gmem_base = groups[group].sfa_addr;
        sfb_gmem_base = groups[group].sfb_addr;

        if (elect_one_sync()) {
            // Adjust M-dim value
            asm volatile(
                "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], %1, %2;"
                :
                : "r"(local_A_tmap_addr), "n"(1), "r"(groups[group].M)
            );

            // Update base addresses
            asm volatile(
                "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;"
                :
                : "r"(local_A_tmap_addr), "l"(groups[group].A_addr)
            );

            asm volatile(
                "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;"
                :
                : "r"(local_B_tmap_addr), "l"(groups[group].B_addr)
            );
        }

        __syncwarp();

        asm volatile(
            "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;"
            :
            : "l"(g_A_tmap), "r"(local_A_tmap_addr)
        );
        asm volatile(
            "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
            :
            : "l"(g_A_tmap)
        );

        asm volatile(
            "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;"
            :
            : "l"(g_B_tmap), "r"(local_B_tmap_addr)
        );
        asm volatile(
            "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
            :
            : "l"(g_B_tmap)
        );
    }

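Inside the TMA warp, if we need to update the tensor maps we perform three tensormap.replace operations: one to update matrix A's M dimension (since M varies across groups) and two to update the base addresses of the A and B matrices. Only one elected thread executes the replace instructions, but the whole warp must then reconverge (__syncwarp()) because tensormap.cp_fenceproxy carries the .sync.aligned qualifiers: every thread in the warp has to execute the same instruction. The warp then issues one tensormap.cp_fenceproxy per modified tensor map (two in total), each followed by a fence.proxy.tensormap::generic.acquire on the same GMEM address.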

tensormap.cp_fenceproxy.cp_qualifiers.fence_qualifiers.sync.aligned  [dst], [src], size;

.cp_qualifiers    = { .global.shared::cta }
.fence_qualifiers = { .to_proxy::from_proxy.release.scope }
.to_proxy::from_proxy  = { .tensormap::generic }
.scope            = { .cta, .cluster, .gpu , .sys }

PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-tensormap-cp-fenceproxy

The tensormap.cp_fenceproxy instructions perform the following operations in order:

1) Copies data of size specified by the size argument, in bytes, from the location specified by the address operand src in shared memory to the location specified by the address operand dst in the global memory, in the generic proxy.

2) Establishes a uni-directional proxy release pattern on the ordering from the copy operation to the subsequent access performed in the tensormap proxy on the address dst.

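In short, the copy plus release publishes the updated descriptor in GMEM to the tensormap proxy, and the fence.proxy.tensormap::generic.acquire that follows each copy is the matching acquire. Together they ensure that subsequent TMA operations using g_A_tmap and g_B_tmap see the fully updated tensor map rather than a stale or partially written one.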

The rest of the kernel operates as before, just with the updated per-CTA tensor maps. With this modification we see a large jump in performance, due to the reduced setup overhead and to tensor map initialization now being parallelized across CTAs.

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 94.0us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 96.2us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 35.8us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} ->  31.0us

- V7 (submission_v4_3.py) -

At this point we've done most of what we can to eliminate setup and teardown overhead within our kernel. We will find out later that this is only a half-truth: we've made most of the optimizations that matter for large problem shapes (i.e. large M, N, K, G), but for small problem shapes there are other ways to make setup and teardown even more efficient. For now we turn back to the kernel itself and the further optimizations we can make on top of those built up over the previous two kernels (GEMM and Dual GEMM).

Due to compute resource constraints I wasn't able to collect all of the NCU statistics I would have liked throughout the optimization process; however, I was able to use mbarrier stall statistics to estimate which section of the kernel was the bottleneck. The TMA and MMA stages account for a significant share of the runtime (visible as stalls on the epilogue mbarrier), but we've already heavily optimized both over the past two kernels. The other large contributor to runtime is the epilogue, and that is where a number of optimization tricks remain unapplied. We start in V7 by using TMA to store the result tiles to GMEM instead of standard register file -> GMEM stores.

We've already covered the specifics of TMA transfers from GMEM -> SMEM. The reverse direction looks nearly identical with a few key exceptions.

TMA SMEM -> GMEM variant:

// shared::cta -> global

cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism{.level::cache_hint}
                                   [tensorMap, tensorCoords], [srcMem] {, cache-policy}

.dst =                  { .global }
.src =                  { .shared::cta }
.dim =                  { .1d, .2d, .3d, .4d, .5d }
.completion_mechanism = { .bulk_group }
.load_mode =            { .tile, .tile::scatter4, .im2col_no_offs }
.level::cache_hint =    { .L2::cache_hint }

The key differences to note in this variant are pointer locations and the completion mechanism. In this variant we provide the SMEM address as the srcMem argument. Instead of using mbarriers to track the completion of this transfer, a different mechanism called bulk_group is used. The bulk group mechanism, as the name suggests, is a way to track the completion of a bulk amount of data transfers. This is very similar to the tcgen05.commit instruction covered in the GEMM kernel which allowed for waiting on a group of tcgen05.mma instructions to complete. In this case we use a different commit instruction than tcgen05.commit:

cp.async.bulk.commit_group

PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-bulk-tensor-copy-completion

The cp.async.bulk.commit_group instruction creates a new per-thread bulk async-group and batches into it all prior cp{.reduce}.async.bulk{.prefetch}{.tensor} instructions that satisfy both of the following conditions:

1) they use the bulk_group based completion mechanism, and

2) they were initiated by the executing thread but not yet committed to any bulk async-group.

Once we've created a bulk group using cp.async.bulk.commit_group we need to then wait on the completion of all the instructions included in that bulk group. The instruction used for this is the following:

cp.async.bulk.wait_group{.read} N;

PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group

Causes the executing thread to wait until only N or fewer of the most recent bulk async-groups are pending and all the prior bulk async-groups committed by the executing threads are complete. For example, when N is 0, the executing thread waits on all the prior bulk async-groups to complete.

The optional .read modifier indicates that the waiting only has to last until all the bulk async operations in the specified bulk async-groups have completed:

1) reading from the tensormap, and

2) reading from their source locations.

Including the .read modifier frees the waiting thread earlier: it only has to wait until the source data has been read out of SMEM (and the tensor map used for the transfer has been read), not until the writes have landed in GMEM. That said, in my experiments using .read or not made very little difference in performance.
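The wait_group semantics are easy to get backwards, so here is a tiny host-side model of them (a sketch, not PTX): each commit closes one group, and wait_group(n) returns only once everything except the n most recent groups is complete. BulkGroupModel and its methods are hypothetical names; the model conservatively treats the n most recent groups as still pending, even though the hardware may already have finished them.

```cpp
#include <algorithm>
#include <cassert>
#include <deque>

// Host-side model of per-thread bulk async-groups (not real PTX).
class BulkGroupModel {
public:
    // cp.async.bulk.commit_group: close all uncommitted copies into a new group.
    int commit_group() {
        pending_.push_back(next_id_);
        return next_id_++;
    }
    // cp.async.bulk.wait_group n: on return, every group except the n most
    // recent is complete. The model retires exactly those older groups and
    // conservatively keeps the n most recent ones pending.
    void wait_group(int n) {
        assert(n >= 0);
        while (static_cast<int>(pending_.size()) > n) pending_.pop_front();
    }
    bool is_pending(int id) const {
        return std::find(pending_.begin(), pending_.end(), id) != pending_.end();
    }
    int num_pending() const { return static_cast<int>(pending_.size()); }

private:
    std::deque<int> pending_;  // oldest group at the front
    int next_id_ = 0;
};
```

After three commits, wait_group(1) guarantees the first two groups are done while the newest may still be in flight, and wait_group(0) drains everything.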

To implement TMA stores in our kernel, we add the same tensor map machinery we've already discussed, this time for a C matrix tensor map. Then we make the following changes to the epilogue code:

__shared__ alignas(128) __half c_smem[C_SMEM_TILESZ];

First we reserve a tile the size of a result tile in SMEM to stage the results for the TMA transfer. Then we use the same epilogue structure we've had all along to fill out the staging SMEM tile.

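// Make this thread's generic-proxy SMEM writes visible to the async proxy (TMA),
// then ensure all epilogue warps have finished writing before issuing the TMA store
asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
asm volatile("bar.sync 2, %0;" :: "r"(WARP_SIZE * (NUM_WARPS - 2)) : "memory");
if (warp_id == 0 && elect_one_sync()) {
    tcgen05_2dtma_s2g_c(c_smem_ptr, g_C_tmap, m_off, n_off, CacheHintSm100::EVICT_NORMAL);
    asm volatile("cp.async.bulk.commit_group;");
    // .read would also suffice here (only reuse of c_smem matters); it made little difference in my tests
    asm volatile("cp.async.bulk.wait_group 0;");
}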

Once all epilogue warps have written their results for the current work tile to SMEM, a single elected thread issues the TMA store, commits it as its own bulk async-group, and waits for that group to complete. Then we can loop and handle the next work tile.

The work in the MMA and TMA warps remains the same.

Our work tiles are large, and the group GEMM problem shapes produce more total result data than the previous kernels, so moving that data from the SMs to GMEM with TMA speeds up the epilogue stores considerably. The first two shapes drop from 94.0us to 58.0us and from 96.2us to 51.6us.

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 58.0us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 51.6us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 24.7us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} ->  20.9us

- V8 (submission_v4_4.py) -

Before diving too deep into further kernel optimizations there is another hardware feature that could potentially improve the performance of the kernel: Cluster Launch Control (CLC).

Cluster Launch Control (CLC; see [CLC Details]) is a hardware feature introduced in the Blackwell architecture that allows for dynamic persistent kernels. These kernels are persistent in a way very similar to static persistent kernels where CTAs can stay resident on an SM and compute multiple output tiles; however, they differ in that the work each CTA ends up computing is determined dynamically at runtime rather than being defined statically.

The idea is that we launch the same number of CTAs as if each CTA were going to compute a single result tile. However, instead of exiting once it has finished its own result tile, each CTA checks whether any CTAs have yet to be launched on the GPU and, if so, takes that CTA's work and computes the result itself. This has the same effect as a static persistent kernel (avoiding CTA launch overhead and letting the next tile's computation overlap with the current one), but avoids the issues we ran into by statically assigning fixed sets of result tiles to particular CTAs.
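The load-balancing effect can be illustrated with a host-side C++ analogy. This is not the hardware mechanism: a std::atomic counter stands in for clusterlaunchcontrol.try_cancel, and unlaunched CTAs are handed out in index order, whereas the real hardware only promises that each canceled CTA id goes to exactly one requester. LaunchQueue and simulate_clc are hypothetical names, and tile costs are made-up numbers.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Each not-yet-launched CTA id is handed to exactly one requester;
// -1 means there is nothing left to steal.
struct LaunchQueue {
    std::atomic<int> next_unlaunched;
    int num_ctas;
    LaunchQueue(int first_unlaunched, int n) : next_unlaunched(first_unlaunched), num_ctas(n) {}
    int try_cancel() {
        const int id = next_unlaunched.fetch_add(1);
        return id < num_ctas ? id : -1;
    }
};

// tile_cost[i] is the time to compute tile i, which originally belongs to CTA i.
// Only `resident` CTAs fit on the GPU. Returns owner[i]: the resident CTA that
// actually computed tile i.
std::vector<int> simulate_clc(const std::vector<int>& tile_cost, int resident) {
    assert(resident > 0);
    const int num_ctas = static_cast<int>(tile_cost.size());
    const int first_wave = std::min(resident, num_ctas);
    LaunchQueue queue(first_wave, num_ctas);
    std::vector<int> owner(num_ctas, -1);

    // Event = (finish time, (resident CTA, tile)); earliest finish pops first.
    using Event = std::pair<long, std::pair<int, int>>;
    std::priority_queue<Event, std::vector<Event>, std::greater<Event>> running;
    for (int cta = 0; cta < first_wave; ++cta) running.push({tile_cost[cta], {cta, cta}});

    while (!running.empty()) {
        const Event ev = running.top();
        running.pop();
        const long finish = ev.first;
        const int cta = ev.second.first;
        const int tile = ev.second.second;
        owner[tile] = cta;
        // Finished a tile: try to steal the next unlaunched CTA's tile.
        const int stolen = queue.try_cancel();
        if (stolen >= 0) running.push({finish + tile_cost[stolen], {cta, stolen}});
    }
    return owner;
}
```

With tile costs {5, 1, 1, 1, 1, 1} and two resident CTAs, CTA 0 is stuck on the expensive tile, so CTA 1 computes tiles 1 through 5 and everything finishes at time 5. A static round-robin split (tiles 0, 2, 4 versus 1, 3, 5) would finish at time 7.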

Now we review the code that implements CLC:

CLC needs a place in SMEM to store its results; we allocate that here:

__shared__ uint4 clc_result; // result for CLC cancel requests
int clc_result_addr = static_cast<int>(__cvta_generic_to_shared(&clc_result));

Next, CLC attempts to grab a work tile that hasn't started yet. A helpful mental model is to think of all the CTAs (technically all clusters, but in our kernel these are currently the same thing) as a pool of work, or result, tiles to be computed, with all active CTAs continuously trying to "steal" tiles from this pool. That's why CLC is described as "work stealing": you take a tile from the pool that originally belonged to another CTA that hasn't launched yet.

At a more technical level we first ensure that the code isn't simultaneously trying to write to and read from the clc_result buffer in SMEM via the async and generic proxies. CLC writes the results of a "steal" operation via the async proxy to clc_result, the SMEM buffer, and the kernel reads that result using the generic proxy. For more details on memory proxies see that subsection in [Memory Consistency Details]. To ensure there isn't any read/write overlap causing stale data reads or corrupted data reads we insert a proxy fence between generic and async proxies.

Next we issue this CTA's next work-steal attempt via clusterlaunchcontrol.try_cancel.

The clusterlaunchcontrol.try_cancel instruction requests atomically cancelling the launch of a cluster that has not started running yet. It asynchronously writes an opaque response to shared memory indicating whether the operation succeeded or failed. The completion of the asynchronous operation is tracked using the mbarrier completion mechanism at .cluster scope. On success, the opaque response contains the ctaid of the first CTA of the canceled cluster; no other successful response from other clusterlaunchcontrol.try_cancel operations from the same grid will contain that id.

In "cancelling" another cluster's work, the executing cluster is essentially stealing that work tile. This is an asynchronous instruction, so we track its completion with an mbarrier, similar to tcgen05 instructions.

Lastly, we arm the CLC mbarrier to expect the number of bytes that CLC writes to clc_result via the async proxy, sizeof(uint4) = 16 bytes. As mentioned, on success this response contains the ctaid of the cancelled (stolen) work.

// Async cancellation request
asm volatile("fence.proxy.async::generic.acquire.sync_restrict::shared::cluster.cluster;");
if (warp_id == 0 && elect_one_sync()) {
    asm volatile(
        "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [%0], [%1];"
        :
        : "r"(clc_result_addr), "r"(mbar_addr_clc)
    );
    mbar_arrive_expect(mbar_addr_clc, sizeof(uint4));
}

clusterlaunchcontrol.try_cancel.async{.space}.completion_mechanism{.multicast::cluster::all}.b128 [addr], [mbar];

.completion_mechanism = { .mbarrier::complete_tx::bytes };
.space = { .shared::cta };

PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel

Wait until the asynchronous CLC cancel (steal) has completed.

// look at result of stealing next tile with CLC
mbar_wait(mbar_addr_clc, clc_phase);

Extract the ctaid from the opaque struct written by CLC.

tile_idx = clc_check_steal(clc_result_addr);

__device__ inline int clc_check_steal(const int clc_result_addr) {
    int res = -1;
    asm volatile(
        "{\n"
        ".reg .pred P1;\n"
        ".reg .b128 handle;\n"
        "ld.shared.b128 handle, [%1];\n"
        "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, handle;\n"
        "@!P1 bra.uni DONE;\n" // cancel failed: no more work tiles to steal, leave res = -1
        "clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 %0, handle;\n" // res = stolen blockIdx.x
        "DONE:\n"
        "}"
        : "+r"(res)
        : "r"(clc_result_addr)
        : "memory"
    );
    return res;
}

This routine hinges on the clusterlaunchcontrol.query_cancel.is_canceled operation:

clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 pred, try_cancel_response;

clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {xdim, ydim, zdim, _},  try_cancel_response;

clusterlaunchcontrol.query_cancel.get_first_ctaid{::dimension}.b32.b128 reg, try_cancel_response;

::dimension = { ::x, ::y, ::z };

PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel

Instruction clusterlaunchcontrol.query_cancel can be used to decode opaque response written by instruction clusterlaunchcontrol.try_cancel. After loading response from clusterlaunchcontrol.try_cancel instruction into 16-byte register it can be further queried using clusterlaunchcontrol.query_cancel instruction as follows: clusterlaunchcontrol.query_cancel.is_canceled.pred.b128: If the cluster is canceled successfully, predicate p is set to true; otherwise, it is set to false. If the request succeeded, the instruction clusterlaunchcontrol.query_cancel.get_first_ctaid extracts the CTA id of the first CTA in the canceled cluster. By default, the instruction returns a .v4 vector whose first three elements are the x, y and z coordinate of first CTA in canceled cluster. The contents of the 4th element are unspecified. The explicit .get_first_ctaid::x, .get_first_ctaid::y, or .get_first_ctaid::z qualifiers can be used to extract individual x, y or z coordinates into a 32-bit register.

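So clc_check_steal declares a 128-bit register for the raw response and a predicate register, loads the try_cancel response from SMEM into the 128-bit register, then uses query_cancel to decode that opaque response. It returns the x index of the first CTA of the stolen cluster (our stolen tile index), or -1 if the cancel failed and no work is left.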

The only thing we've changed in this iteration is moving from a static persistent kernel, where each CTA's work is fixed in advance by its block index, to a dynamic persistent kernel, where each CTA's work is decided at runtime. For these problem shapes, whatever we gain from dynamic persistence is outweighed by the overhead of the CLC mechanism itself: three of the four shapes get slower than V7. However, CLC could still become advantageous over static scheduling if we manage to better overlap work across tiles, which could hide the additional CLC latency.

Benchmark Results:

{'g': 8, 'k': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168], 'm': [80, 176, 128, 72, 64, 248, 96, 160], 'n': [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096]}  -> 61.0us
{'g': 8, 'k': [2048, 2048, 2048, 2048, 2048, 2048, 2048, 2048], 'm': [40, 76, 168, 72, 164, 148, 196, 160], 'n': [7168, 7168, 7168, 7168, 7168, 7168, 7168, 7168]} -> 54.3us
{'g': 2, 'k': [4096, 4096], 'm': [192, 320], 'n': [3072, 3072]} -> 24.4us
{'g': 2, 'k': [1536, 1536], 'm': [128, 384], 'n': [4096, 4096]} ->  22.0us

From here on we transition from large-step structural changes (TMA, tensormap.replace, CLC) to finer-grained pipelining and scheduling work. The benchmark scores below use the same four problem shapes shown at the top of this section; for brevity later entries refer to them as shape A, B, C, and D:

  shape A: g=8, k=7168, m=[80,176,128,72,64,248,96,160], n=4096
  shape B: g=8, k=2048, m=[40,76,168,72,164,148,196,160], n=7168
  shape C: g=2, k=4096, m=[192,320], n=3072
  shape D: g=2, k=1536, m=[128,384], n=4096

- V9 (submission_v4_5.py) -

Coming off V8, the dominant remaining cost inside the kernel itself is the epilogue. The CLC experiment in V8 actually regressed slightly compared to V7's static persistent kernel, so we leave the CLC machinery in place but stop tuning it for now (V12 later removes it) and attack the epilogue directly. Two ideas land here:

(1) Chunked, double-buffered epilogue. Previously each work tile filled one contiguous N-wide SMEM staging tile of size TD_SMEM_M x TD_SMEM_N, then issued a single TMA store of that whole tile. The drawback is that all of the TMEM->reg->half-conversion->SMEM work has to finish before the single TMA store can be issued, and the store has to drain before the next tile's epilogue can begin. We split the epilogue along the N axis into chunks of size OUT_N_CHUNK=32 and ping-pong between two SMEM buffers:

constexpr int C_CHUNK_SMEM_TILESZ = TD_SMEM_M * OUT_N_CHUNK;
__shared__ alignas(128) __half c_smem[C_CHUNK_SMEM_TILESZ * 2];

for (int chunk = 0; chunk < TD_MMA_N / OUT_N_CHUNK; chunk++) {
    int buf = chunk & 0x1;

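    // Wait for the previous TMA store using THIS buffer to have read SMEM.
    // (Chunks 0 and 1 skip the wait, which assumes the previous tile's stores
    // were already drained, e.g. with wait_group.read 0, before this loop.)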
    if (chunk >= 2) {
        if (warp_id == 0 && elect_one_sync()) {
            asm volatile("cp.async.bulk.wait_group.read 1;");
        }
        asm volatile("bar.sync 2, %0;" :: "r"(WARP_SIZE * (NUM_WARPS - 2)) : "memory");
    }

    // TMEM -> regs -> half2 -> SMEM (this chunk only)
    for (int sub_m = 0; sub_m < rows_per_warp / 16; sub_m++) {
        tcgen05_ld<16, 256, OUT_N_CHUNK / 8>(results,
            tmem_addr_result + (((warp_id * rows_per_warp) + sub_m * 16) << 16)
                             + chunk * OUT_N_CHUNK);
        asm volatile("tcgen05.wait::ld.sync.aligned;");
        // ... half2 packing and writes into c_smem[buf]
    }

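    // Make this thread's generic-proxy SMEM writes visible to the TMA (async proxy),
    // then wait for all epilogue warps before the elected thread issues the store
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    asm volatile("bar.sync 2, %0;" :: "r"(WARP_SIZE * (NUM_WARPS - 2)) : "memory");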

    if (warp_id == 0 && elect_one_sync()) {
        tcgen05_2dtma_s2g_c(c_smem_ptr + (C_CHUNK_SMEM_TILESZ * 2 * buf),
                            g_C_tmap, m_off, n_off + chunk * OUT_N_CHUNK,
                            CacheHintSm100::EVICT_NORMAL);
        asm volatile("cp.async.bulk.commit_group;");
    }
}

The key PTX is the cp.async.bulk.wait_group.read N form. The .read modifier was discussed but not actually used in V7 (submission_v4_3.py); this is the first iteration where it pays off. The condition we need at the top of each chunk is "the previous use of *this SMEM buffer* has finished reading SMEM", not "the previous TMA store has landed in GMEM". .read releases the SMEM buffer as soon as the TMA engine has consumed it, while the GMEM write may still be in flight in the L2 / HBM path. Waiting with N = 1 leaves the most recent store (chunk i+1, in the other buffer) pending and only requires chunk i's store to have finished reading before chunk i+2 overwrites its buffer. The result is that chunk i+1's TMEM->SMEM work overlaps in time with chunk i's SMEM->GMEM TMA transfer.
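The buffer-reuse argument can be checked with a small host-side model of the chunk loop. It is a sketch under assumptions (one commit per chunk, buffers used round-robin, a fresh loop per tile), and buffer_reuse_is_safe is a hypothetical helper: before chunk c rewrites buffer c % num_bufs, the store issued num_bufs chunks earlier must no longer be pending after wait_group.read.

```cpp
#include <cassert>
#include <deque>

// Host-side model of a multi-buffered epilogue loop (a sketch, not kernel code).
// Chunk c writes SMEM buffer c % num_bufs and then issues and commits one TMA
// store. Before reusing a buffer, the loop executes wait_group.read wait_n, which
// guarantees all but the wait_n most recent groups have finished reading SMEM.
// Returns true if no chunk overwrites a buffer whose previous store may still be
// reading it.
bool buffer_reuse_is_safe(int num_chunks, int num_bufs, int wait_n) {
    assert(num_bufs > 0 && wait_n >= 0);
    std::deque<int> pending;  // chunks whose stores may still be reading SMEM
    for (int chunk = 0; chunk < num_chunks; ++chunk) {
        if (chunk >= num_bufs) {
            while (static_cast<int>(pending.size()) > wait_n) pending.pop_front();
            const int prev_user = chunk - num_bufs;  // last chunk that used this buffer
            for (int p : pending) {
                if (p == prev_user) return false;
            }
        }
        // ... TMEM -> registers -> SMEM writes for this chunk would go here ...
        pending.push_back(chunk);  // TMA store issued + cp.async.bulk.commit_group
    }
    return true;
}
```

With two buffers (and, say, 8 chunks), N = 1 is the loosest wait that is still safe; N = 2 would let chunk 2 overwrite buffer 0 while chunk 0's store may still be reading it. The same model suggests a third buffer would allow N = 2, at the cost of more SMEM.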

Two new tcgen05_ld template specializations (<16, 256, 2> and <16, 256, 4>) are added because a chunk size of 32 means we now need to load 2 or 4 256-bit blocks at a time instead of the previous 8/16/32.

(2) Smaller wins on the host side. h_groups is moved from malloc to cudaMallocHost, which makes it pinned (page-locked) memory, and the H2D transfer is switched from cudaMemcpy to cudaMemcpyAsync. Pinned memory lets the copy engine DMA directly from the buffer instead of the driver first staging it through an intermediate pinned buffer, and the async form lets the descriptor transfer overlap with the rest of the host-side launch path (driver work, kernel parameter marshalling) instead of blocking. The two changes go together: a cudaMemcpyAsync from pageable memory still has to be staged by the driver before the call returns, so it overlaps far less than one from pinned memory.

Benchmark Results:

shape A -> 48.9us, shape B -> 43.0us, shape C -> 15.6us, shape D -> 13.2us

- V10 (submission_v5.py) -

V10 is a focused micro-optimization on top of V9 that changes when the C tensor map gets patched on group transitions. In V9, when a CTA crosses a group boundary, the TMA warp patches all three tensor maps back to back (tensormap.replace followed by tensormap.cp_fenceproxy for each of A, B, and C) before it can issue the first TMA load for the new group. But the C tensor map is not needed until much later, when the epilogue begins. Moving its update into the epilogue path lets the fenceproxy round trip for C overlap with the in-flight A/B loads and the tcgen05.mma instructions themselves.

To keep the diff readable, V10 first factors the repeated tensormap idioms into inline helper functions:

__device__ inline void tmap_update(int local_tmap_addr, uint64_t addr) {
    asm volatile(
        "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;"
        :: "r"(local_tmap_addr), "l"(addr)
    );
}
__device__ inline void tmap_update(int local_tmap_addr, uint64_t addr, int M) {
    // address + M-dim update for A's varying M
    ...
}
__device__ inline void tmap_fence_proxy(CUtensorMap* g_tmap, int local_tmap_addr) {
    asm volatile(
        "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;"
        :: "l"(g_tmap), "r"(local_tmap_addr)
    );
    asm volatile(
        "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
        :: "l"(g_tmap)
    );
}

The TMA warp now only patches A and B at the top of a new group; the epilogue warps own the C patch:

// in the TMA warp's group-change branch (executed by the whole warp)
if (elect_one_sync()) {
    tmap_update(local_A_tmap_addr, groups[group].A_addr, groups[group].M);
    tmap_update(local_B_tmap_addr, groups[group].B_addr);
}
__syncwarp();
tmap_fence_proxy(g_A_tmap, local_A_tmap_addr);  // .sync.aligned: all 32 lanes must execute it
tmap_fence_proxy(g_B_tmap, local_B_tmap_addr);

// in the epilogue, just before mbar_wait(mbar_addr_epi, ...)
if (update_group && warp_id == 0) {
    if (elect_one_sync()) {
        tmap_update(local_C_tmap_addr, groups[group].C_addr, groups[group].M);
    }
    __syncwarp();
    tmap_fence_proxy(g_C_tmap, local_C_tmap_addr);  // whole warp, as above
}

tensormap.cp_fenceproxy is not free: it copies the 128-byte descriptor from SMEM to GMEM with release semantics toward the tensormap proxy, and the acquire fence that follows has to wait for it, so its cost is roughly that of a small GMEM round trip. Moving even one of the three off the critical path before the first TMA load saves that latency on every group boundary. The savings show up on the shapes where group transitions are frequent enough to be visible (shape B drops from 43.0us to 42.0us; shape C from 15.6us to 15.3us). Shape A doesn't move because the kernel's bottleneck on that shape is not group transitions.

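Two additional small changes: the host-side launch is reordered so that the cudaMemcpyAsync of the group descriptors is queued *before* the expensive cuTensorMapEncodeTiled template calls, letting the copy overlap with the driver-side encode work. The NVCC build flags also drop -lineinfo and add --relocatable-device-code=false. Whole-program (non-relocatable) compilation can give ptxas tighter register allocation and instruction scheduling than separate compilation, but it is already nvcc's default, so the flag only matters if -rdc=true was being set elsewhere; dropping -lineinfo removes line-number metadata and should not by itself change the generated code.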

Benchmark Results:

shape A -> 48.9us, shape B -> 42.0us, shape C -> 15.3us, shape D -> 13.2us

- V11 (submission_v5_iter.py) -

V11 attempts inter-tile TMA pipelining: instead of draining the K-stage pipeline at the end of each work tile and refilling it from scratch at the top of the next, keep the pipeline continuously running across tile boundaries. Each CTA computes a sequence of (m_off, n_off) tiles, so the K-direction TMA loads for tile i+1 can start as soon as the consumer has released a stage from tile i.

The state machine needed for this is a monotonically increasing glob_k_off counter that runs across tiles, plus a first_tile flag that gates the pipe-fill prologue:

// outside the work-tile loop
int glob_k_off = 0;
bool first_tile = true;

// inside the work-tile loop, TMA warp
if (first_tile) {
    // Standard pipe-fill: pre-issue PIPE_STAGES TMA loads
    for (int s = 0; s < PIPE_STAGES; s++) {
        tma_load_stage(/* k_off= */ s * TD_SMEM_K, s);
        glob_k_off += TD_SMEM_K;
    }
    k_off = PIPE_STAGES * TD_SMEM_K;
    first_tile = false;
} else {
    k_off = 0;  // jump straight into the steady-state loop
}

for (; k_off < K; k_off += TD_SMEM_K) {
    int stage = (glob_k_off / TD_SMEM_K) % PIPE_STAGES;
    int wait_phase =
        (((glob_k_off / TD_SMEM_K) / PIPE_STAGES) - 1) & 0x1;
    mbar_wait(mbar_addr_mma + stage * 8, wait_phase);
    tma_load_stage(k_off, stage);
    glob_k_off += TD_SMEM_K;
}

The MMA warp uses the same glob_k_off so that its mbarrier phase tracking lines up. The per-tile mbar_init re-init loop is removed and mbar_addr_epi becomes a normal two-phase mbarrier (epi_phase ^= 1) instead of being zeroed every tile.
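Here is a host-side C++ model of the producer-side bookkeeping just described, assuming every tile has the same number of K-blocks and that K covers at least PIPE_STAGES blocks; producer_schedule and StageSlot are hypothetical names. Because it counts K-blocks across all tiles, the stage index carries over tile boundaries instead of restarting at 0.

```cpp
#include <cassert>
#include <vector>

struct StageSlot {
    int stage;       // SMEM stage this K-block is loaded into
    int wait_phase;  // parity to wait on before overwriting the stage; -1 = no wait (pipe fill)
};

// Host-side model of V11-style producer bookkeeping (illustrative only).
// `it` plays the role of glob_k_off / TD_SMEM_K: it counts K-blocks across every
// tile the CTA has processed, so the stage ring never resets at a tile boundary.
std::vector<StageSlot> producer_schedule(int num_tiles, int k_blocks_per_tile, int pipe_stages) {
    assert(pipe_stages > 0 && k_blocks_per_tile >= pipe_stages);
    std::vector<StageSlot> slots;
    int it = 0;
    for (int tile = 0; tile < num_tiles; ++tile) {
        for (int kb = 0; kb < k_blocks_per_tile; ++kb, ++it) {
            const int stage = it % pipe_stages;
            // The first pipe_stages loads fill empty stages without waiting. Later,
            // load `it` reuses a stage after its (it / pipe_stages - 1)-th release by
            // the consumer (counting from 0), hence that phase parity.
            const int wait_phase = (it < pipe_stages) ? -1 : ((it / pipe_stages) - 1) & 1;
            slots.push_back({stage, wait_phase});
        }
    }
    return slots;
}
```

With PIPE_STAGES = 3 and 4 K-blocks per tile, the second tile's first load lands in stage 1 and waits on parity 0, which is what the glob_k_off formulas give.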

In theory this hides a full pipe-fill of TMA latency per tile boundary. In practice, on these small problem sizes the win is eaten by the new serial dependency it introduces: every tile's MMA now starts by waiting on a TMA load that was issued only one stage earlier, so there is less slack in the pipeline at the worst time. The measured numbers regress slightly on every shape (49.4 / 42.5 / 16.3 / 13.3us vs. V10's 48.9 / 42.0 / 15.3 / 13.2). I keep V11 in the lineage because the cross-tile pipelining machinery becomes useful later when paired with multi-CTA TMEM double buffering.

Benchmark Results:

shape A -> 49.4us, shape B -> 42.5us, shape C -> 16.3us, shape D -> 13.3us

- Failed Experiments around V11 -

submission_v5_1.py (hang). An aggressive restructure where the TMA warp becomes a fully-featured "admin" warp: it does the CLC try_cancel, the binary-search group lookup, and writes the resolved (group, m_off, n_off) into SMEM slots for the MMA and epilogue warps to consume, gated by new mbarriers mbar_addr_tile_ready, mbar_addr_mma_done, and mbar_addr_epi_done. The mbarrier arrival counts get subtly inconsistent (most barriers use count=1 but mma_done is count=2), and with PIPE_STAGES = 1 there is no slack for the producer-consumer handshake to recover from any phase mismatch. The kernel hangs.

submission_v5_iter_2.py (correctness errors). Same producer-consumer architecture as v5_1 but layered on top of v5_iter. Also embeds the groups table as __grid_constant__ GroupDescs (by-value, no GMEM load). The shape of the bug is that the TMA warp updates the per-tile metadata SMEM slots while the MMA and epilogue warps are still computing on the *previous* tile's offsets — n_off in particular is used inside SFB's TMEM address as (n_off % 128) / 32, and the release fence (fence.proxy.async::generic.release.sync_restrict::shared::cta.cluster) does not order writes against all the consumers, so warps read torn offsets.

submission_v5_iter_static.py (correctness errors). Strips CLC entirely and reverts to a static persistent loop, but keeps the v11-style glob_k_off cross-tile pipelining and adds a single-arrival mbar_addr_epi_done for an MMA<-epi handshake. The handshake is broken: there are NUM_WARPS - 2 = 4 epilogue warps but the mbarrier was initialized with arrival count 1, so the MMA warp races ahead and overwrites TMEM result space before all epilogue warps finish reading.
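The arrival-count bug in the last experiment can be reproduced with a minimal host-side model of mbarrier arrival counting. This is a sketch that ignores transaction counts and real concurrency; MbarrierModel is a hypothetical name, and test_wait mirrors the parity convention used with mbarrier.try_wait.parity (waiting on parity p succeeds once the phase with that parity has completed).

```cpp
#include <cassert>

// Minimal host-side model of mbarrier arrival counting (no tx-count, no real
// concurrency). A phase completes when `expected` arrivals have been recorded;
// the phase parity then flips and the pending count resets for the next phase.
struct MbarrierModel {
    int expected;
    int pending;
    int phase = 0;  // parity of the phase currently in progress

    explicit MbarrierModel(int count) : expected(count), pending(count) { assert(count > 0); }

    void arrive() {
        if (--pending == 0) {
            phase ^= 1;
            pending = expected;
        }
    }
    // Waiting on parity p succeeds once the phase with parity p has completed,
    // i.e. once the barrier has moved on to the other parity.
    bool test_wait(int parity) const { return phase != parity; }
};
```

With an arrival count of 1, the first epilogue warp to arrive completes the phase and releases the MMA warp while the other three may still be reading TMEM. Each further arrival also completes a phase, so after all four arrivals the parity is back where it started.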

- Failed Experiment: submission_v5_precomp_tma.py -

A side experiment to test the hypothesis that the in-kernel tensormap.replace machinery is too expensive and that pre-encoding one full tensormap per group on the host would be cheaper:

// host side
for (int i = 0; i < G; i++) {
    tma_3d_map_ab<...>::init(..., &h_a_tmaps[i], A_ptrs[i], M[i], K);
    tma_3d_map_ab<...>::init(..., &h_b_tmaps[i], B_ptrs[i], N,   K);
    tma_2d_map_c_init<...>(..., &h_c_tmaps[i], C_ptrs[i], M[i], N);
}
cudaMemcpyAsync(d_a_tmaps, h_a_tmaps, G * sizeof(CUtensorMap), ...);
// ... similar for b and c

// device side, on group change
const CUtensorMap* a_tmap = &d_a_tmaps[group];

The kernel-side win is real (no more tensormap.replace or cp_fenceproxy per group change), but the host-side cost grows linearly in G: one cuTensorMapEncodeTiled call per matrix per group, plus a larger H2D transfer. On these problem sizes the host overhead dominates, and the measured times become 51.8 / 46.5 / 24.6 / 19.3us. This is worth noting because the eventual winning design (V19 onward) does exactly this kind of per-group host-side encoding. It only works there because the host-side cost is amortized by other improvements made in parallel, and because by then the kernel-side savings from removing tensormap.replace start to outweigh the host cost.

- V12 (submission_v5_static.py) -

Steps back from CLC entirely. CLC's value is dynamic load balancing across heterogeneous M sizes, but on these small problem shapes (G=2 or G=8) the variance per CTA is small, while every tile still pays for a clusterlaunchcontrol.try_cancel plus an mbarrier wait on its result. V12 launches NUM_CTAS = 148 (one per SM on B200) and runs a classic static persistent loop:

for (int tile_idx = blockIdx.x;
     tile_idx < total_tiles;
     tile_idx += gridDim.x) {
    // ... resolve group, compute, store
}

Two configuration knobs flip with this change: N_TILE_SIZE grows from 128 to 256, and PIPE_STAGES drops from 5 to 3 (a larger N tile means more SMEM per stage). The larger N tile increases arithmetic intensity per tile and reduces the total tile count, which helps shapes where the kernel is MMA-bound (shape B drops from 42.5us to 38.1us) but hurts shapes where the small total work means the loss of cross-tile TMA prefetch matters more than the per-tile gain (shape D regresses from 13.3us to 15.2us). Shapes A and C regress as well, from 49.4us to 54.7us and from 16.3us to 19.1us.
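The tile-count effect can be checked with a bit of arithmetic. The sketch below assumes a 128-row M tile, ceil-division on M and N per group, and round-robin assignment over 148 CTAs; total_tiles and waves are hypothetical helpers, not code from the kernel.

```cpp
#include <cassert>
#include <vector>

// Output tiles for a grouped GEMM where every group shares N (hypothetical
// helper). Assumes each group is tiled independently with ceil-division.
int total_tiles(const std::vector<int>& m, int n, int tile_m, int tile_n) {
    assert(tile_m > 0 && tile_n > 0);
    const int tiles_n = (n + tile_n - 1) / tile_n;
    int tiles = 0;
    for (int mi : m) tiles += ((mi + tile_m - 1) / tile_m) * tiles_n;
    return tiles;
}

// Rounds of work when tiles are dealt round-robin to num_ctas persistent CTAs.
int waves(int tiles, int num_ctas) { return (tiles + num_ctas - 1) / num_ctas; }
```

Under these assumptions, going from 128-wide to 256-wide N tiles cuts shape B from 728 to 364 tiles (5 waves down to 3 on 148 CTAs). Shape D drops from 128 to 64 tiles, so fewer than half of the 148 SMs get any work at all.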

The kernel ends each tile with a hard __syncthreads() + per-tile mbar_init reinit loop, which is what makes the next change necessary.

Benchmark Results:

shape A -> 54.7us, shape B -> 38.1us, shape C -> 19.1us, shape D -> 15.2us

- Failed Experiment: submission_v5_static_overlap.py -

Tries to let the next tile's TMA + MMA warps start work while the current tile's epilogue warps are still draining. Adds a new mbar mbar_addr_epi_done for the epi -> MMA back-edge, splits the per-tile barrier so only the TMA and MMA warps participate in the mbar reinit, and lets epilogue warps fall through to the next iteration without participating in any inter-tile sync:

if (warp_id == TMA_WARP || warp_id == MMA_WARP) {
    asm volatile("bar.sync 3, %0;" :: "r"(WARP_SIZE * 2));
    if (warp_id == TMA_WARP && elect_one_sync()) {
        for (int i = 0; i < PIPE_STAGES * 2; i++) mbar_init(...);
        asm volatile("fence.mbarrier_init.release.cluster;");
    }
    asm volatile("bar.sync 3, %0;" :: "r"(WARP_SIZE * 2));
}
if (warp_id == MMA_WARP) {
    mbar_wait(mbar_addr_epi_done, epi_done_phase);
    epi_done_phase ^= 1;
}

The kernel faults. The root cause is that there is still only one TMEM result region: epilogue(tile N) is reading from tmem_addr_result while MMA(tile N+1) issues tcgen05.mma ... enable_input_d=0 against the same address. Because enable_input_d=0 ignores the previous contents of D, this overwrites the accumulator before the epilogue has finished reading it. The phase bookkeeping on mbar_addr_epi is also fragile, because the epilogue warps don't participate in the mbar reinit but still wait on that mbar's phase. The fix needs a TMEM ping-pong, which V13 provides.

- V13 (submission_v5_static_overlap_2.py) -

Same overlap idea as the failed kernel above, but with the missing piece in place: two TMEM result regions and two pairs of (epi, epi_done) mbarriers, indexed by a per-warp tmem_buf that toggles each tile.

tcgen05_alloc_tmem<1>(tmem_addr_ptr, TD_MMA_N * 2 * 2);  // doubled

const int tmem_result_ptrs[2] = {
    tmem_addr_base,
    tmem_addr_base + 2 * TD_MMA_N
};
const int tmem_sfa_ptrs[2] = {
    tmem_result_ptrs[0] + TD_MMA_N,
    tmem_result_ptrs[1] + TD_MMA_N
};
const int tmem_sfb_ptrs[2] = {
    tmem_sfa_ptrs[0] + (TD_MMA_M / 32) * (TD_SMEM_K / 64),
    tmem_sfa_ptrs[1] + (TD_MMA_M / 32) * (TD_SMEM_K / 64)
};

int tmem_buf = 0;
int epi_phase[2]      = {0, 0};
int epi_done_phase[2] = {0, 0};
bool first_tile       = true;  // no epi_done back-edge has fired yet

// MMA warp's per-tile prologue, gated by first_tile so the very first
// iteration doesn't wait on a back-edge that hasn't fired yet
if (warp_id == MMA_WARP) {
    tmem_buf ^= 1;
    if (!first_tile) {
        mbar_wait(mbar_addr_epi_done + tmem_buf * 8,
                  epi_done_phase[tmem_buf]);
        epi_done_phase[tmem_buf] ^= 1;
    } else {
        first_tile = false;
    }
}

// MMA emits results to the current tmem_buf
// Epilogue reads from the same tmem_buf, signals epi_done at the end

With two TMEM result halves, MMA(tile N+1) writes one half while epilogue(tile N) reads the other; the long TMEM->reg->half->SMEM->TMA chain now overlaps with the K-loop of the next tile. Configuration returns to N=128 / PIPE=5 because the doubled TMEM allocation forces TD_MMA_N=128 (TMEM column budget is exhausted at 512 cols).
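The TMEM budget claim can be checked directly. TMEM has 512 32-bit columns per SM, and V13 spaces its two buffers 2 * TD_MMA_N apart, with TD_MMA_N result columns followed by the scale-factor region. The sketch below only reproduces that arithmetic; the helper names are illustrative.

```cpp
#include <cassert>

constexpr int TMEM_COLS = 512;  // TMEM columns per SM (also the tcgen05.alloc maximum)

// Columns used by V13's layout: two buffers, each TD_MMA_N result columns
// followed by a TD_MMA_N-wide region that holds SFA/SFB.
constexpr int pingpong_cols(int td_mma_n) { return td_mma_n * 2 * 2; }

constexpr bool fits(int td_mma_n) { return pingpong_cols(td_mma_n) <= TMEM_COLS; }

static_assert(fits(128), "N=128 uses exactly the full 512 columns");
static_assert(!fits(256), "N=256 would need 1024 columns");
```

This is why doubling the buffers pushes TD_MMA_N back down from V12's 256 to 128.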

A small but important PTX detail used throughout V12+ is the *named barrier* form of bar.sync:

bar.sync     a {, b};
bar.arrive   a, b;

  a — barrier resource (0..15)
  b — number of threads participating

PTX link: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar

CUDA's __syncthreads() compiles to bar.sync 0 and synchronizes every thread in the CTA. The two-operand form bar.sync N, M synchronizes only M threads on barrier resource N, where N is a small integer (0-15) and M must be a multiple of the warp size. This is how the kernel synchronizes *only the epilogue warps* without involving the TMA and MMA warps, e.g.:

asm volatile("bar.sync 2, %0;" :: "r"(WARP_SIZE * (NUM_WARPS - 2)) : "memory");

waits for WARP_SIZE * (NUM_WARPS - 2) threads — i.e. the four epilogue warps. Different barrier resources are used for different role-groups (barrier 2 for epilogue, barrier 3 for TMA+MMA in V12+) so that one role group can synchronize among itself while another is independently running.
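The independence of barrier resources can be illustrated with a deliberately simplified, single-threaded model of the counting rule. Each of the 16 resources has its own arrival counter, which releases once the expected count is reached and then resets. Real hardware counts arrivals in warp units and actually blocks the waiting threads; this model just counts one arrival per call, and its type and member names are hypothetical.

```cpp
#include <array>
#include <cassert>

// Simplified model of PTX named barriers (bar.sync a, b): 16 independent
// resources, each releasing once `count` arrivals have accumulated.
struct NamedBarriers {
    std::array<int, 16> arrived{};  // arrivals in the current phase, per resource
    std::array<int, 16> phase{};    // completed phases, per resource

    // One arrival on resource `id`, which expects `count` participants.
    // Returns true if this arrival completes the barrier and releases everyone.
    bool arrive(int id, int count) {
        assert(count % 32 == 0);  // PTX requires a multiple of the warp size
        if (++arrived[id] == count) {
            arrived[id] = 0;
            ++phase[id];
            return true;
        }
        return false;
    }
};
```

In the usage sequence, the 64 TMA+MMA threads complete barrier 3 while barrier 2 is still waiting for its last epilogue thread. Neither role group sees the other's arrivals.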

Benchmark Results:

shape A -> 45.9us shape B -> 35.5us shape C -> 15.3us shape D -> 13.2us

- V14 (sub_v6.py) -

V14 (the start of the v6 family) makes a host-side optimization that pays off everywhere: stop transferring the GroupDesc array via cudaMemcpyAsync, and instead pass it by value as a __grid_constant__ kernel argument. Kernel parameters are limited to 32,764 bytes in total (CUDA 12.1+; 4KB before), and GroupDescs fits well under that bound:

struct GroupDescs {
    GroupDesc groups[MAX_G];   // MAX_G = 8
};

void nvfp4_group_gemm(...) {
    GroupDescs gd = {};
    for (int i = 0; i < G; i++) {
        gd.groups[i].A_addr   = A_ptrs[i];
        gd.groups[i].B_addr   = B_ptrs[i];
        gd.groups[i].C_addr   = C_ptrs[i];
        gd.groups[i].sfa_addr = sfa_ptrs[i];
        gd.groups[i].sfb_addr = sfb_ptrs[i];
        gd.groups[i].M        = M_sizes[i];
    }
    kernel_inst<<<NUM_CTAS, threads>>>(gd, tmap_a_temp, ...);
}

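Kernel arguments live in a constant memory bank shared by all CTAs of the launch. On the device, reading the descriptors goes through the constant cache; on the host, they travel inside the launch parameters instead of through a separate copy. Marking the parameter __grid_constant__ additionally lets the kernel take its address (needed when a tensormap inside it is handed to TMA by pointer) without the compiler making a per-thread local copy. This eliminates one cudaMemcpyAsync and the static cudaMallocHost / cudaMalloc pair entirely.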

V14 also trims the per-tile sync back to a single mbar_addr_epi_done and temporarily drops TMEM ping-ponging (a single TMEM result region again) to keep the diff isolated to the host-side change. NUM_CTAS is now min(total_tiles, 148) to avoid launching CTAs that have no work.

Benchmark Results:

shape A -> 43.6us shape B -> 33.5us shape C -> 14.3us shape D -> 10.5us

- V15 (sub_v6_1.py) -

Re-introduces the cross-tile TMA pipelining from V11 — this time on top of the cleaner static-persistent scheduling of V14. The state machine is the same glob_k_off running counter, with the first-tile prefetch and steady-state pipe consumption:

int glob_k_off = 0;
bool first_tile = true;
constexpr int prefetch_stages = PIPE_STAGES;

for (int tile_idx = blockIdx.x; tile_idx < total_tiles;
     tile_idx += gridDim.x) {
    // ... resolve tile ...

    if (warp_id == TMA_WARP) {
        if (first_tile) {
            for (int s = 0; s < prefetch_stages; s++) {
                tma_load_stage(/* k_off= */ s * TD_SMEM_K, s);
            }
            k_off = TD_SMEM_K * prefetch_stages;
            glob_k_off = k_off;
            first_tile = false;
        } else {
            k_off = 0;
        }
        for (; k_off < K; k_off += TD_SMEM_K) {
            int stage = (glob_k_off / TD_SMEM_K) % PIPE_STAGES;
            int wait_phase =
                (((glob_k_off / TD_SMEM_K) / PIPE_STAGES) - 1) & 0x1;
            mbar_wait(mbar_addr_mma + stage * 8, wait_phase);
            tma_load_stage(k_off, stage);
            glob_k_off += TD_SMEM_K;
        }
    }
    // MMA warp uses the same glob_k_off
}
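The index math in this loop can be pulled out into a small host-side model, with it = glob_k_off / TD_SMEM_K being the running K-step count across every tile this CTA has processed. PIPE_STAGES = 5 is assumed from the N=128 / PIPE=5 configuration, and the struct and function names are illustrative.

```cpp
#include <cassert>

constexpr int PIPE_STAGES = 5;  // assumed: the N=128 / PIPE=5 configuration

struct StageSlot { int stage; int wait_phase; };

// Producer: which SMEM stage to refill, and which phase of that stage's
// "MMA done" mbarrier to wait on first. Only used for it >= PIPE_STAGES;
// the first PIPE_STAGES loads are the prefetch and wait on nothing.
StageSlot producer_slot(int it) {
    return { it % PIPE_STAGES, ((it / PIPE_STAGES) - 1) & 1 };
}

// Consumer (MMA warp): phase of the "TMA full" mbarrier for the same step.
int consumer_phase(int it) { return (it / PIPE_STAGES) & 1; }
```

For example, if each tile spans 8 K steps (K = 2048 with an assumed TD_SMEM_K of 256), tile 1 begins at it = 8, which maps to stage 3, not stage 0. Because both warps derive stage and phase from the same running counter, they stay in agreement across the tile boundary without any per-tile reset.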

The MMA warp consumes stages using the same glob_k_off, so the TMA warp can start issuing tile i+1's loads as soon as stages free up. When tile i finishes its K-loop, the MMA warp's first mbar_wait for tile i+1 finds a stage that is already filled. Together with an alignment fix (alignas(128) on the SMEM CUtensorMap) and a fence.proxy.async.shared::cta before each TMA store, the only measurable change is a sub-microsecond win on shape B (33.5 -> 33.0us); shapes A, C and D are unchanged.

The proxy fence is needed because the epilogue writes SMEM through the generic proxy (ordinary half2 stores) while the TMA engine reads SMEM through the async proxy; without an explicit proxy fence the TMA can read stale SMEM. The proxy framework itself is touched on in summary.txt's V6 (tensormap proxy via tensormap.cp_fenceproxy) and V8 (async vs generic proxy in CLC) entries, and listed under the Memory Consistency Details glossary section.

NOTE: This was a long and arduous debug. The root cause was non-fatal memory corruption from a faulty synchronization architecture, so the bug showed up as a non-deterministic correctness error. The PTX docs don't call out this race explicitly. At first glance, a generic st.shared followed by a TMA read of the same SMEM looks like an ordinary RAW hazard that the hardware should order on its own. It is not: the store and the TMA read go through different proxies, so an explicit proxy fence is required.

Benchmark Results:

shape A -> 43.6us shape B -> 33.0us shape C -> 14.3us shape D -> 10.5us

- V16 (sub_v6_2.py) -

Re-introduces TMEM ping-ponging from V13, stacked on V15's cross-tile TMA pipelining:

tcgen05_alloc_tmem<1>(tmem_addr_ptr, TD_MMA_N * 2 * 2);

const int tmem_result_ptrs[2] = {tmem_addr_base, tmem_addr_base + 2 * TD_MMA_N};
const int tmem_sfa_ptrs[2]    = { tmem_result_ptrs[0] + TD_MMA_N,
                                  tmem_result_ptrs[1] + TD_MMA_N };
// ... two epi / epi_done mbars

// MMA tail: signal current tmem_buf, wait for the previous tmem_buf
tcgen05_commit(mbar_addr_epi + 8 * tmem_buf);
// ...
mbar_wait(mbar_addr_epi_done + tmem_buf * 8,
          epi_done_phase[tmem_buf]);

Now MMA(tile N+1) can write into one TMEM half while epilogue(tile N) drains the other, *and* TMA loads for tile N+1 are already in flight from the V15 cross-tile pipelining. All three phases (TMA, MMA, epilogue) overlap across tile boundaries. Shapes A and B improve by about 2us (shape B 33.0 -> 31.1us), shape C by 0.4us (14.3 -> 13.9us), and shape D not at all. The smaller shapes benefit less because each CTA processes only a few tiles, leaving few tile boundaries for the overlap to hide.
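Why the gain depends on tiles per CTA can be seen from an idealized per-CTA timing model. Each tile costs `mma` units of MMA work and `epi` units of epilogue work. With one TMEM buffer they serialize; with two, MMA(i+1) overlaps epilogue(i). This model is hypothetical: it ignores TMA, launch costs and the tail of the wave.

```cpp
#include <algorithm>
#include <cassert>

// One TMEM result buffer: MMA of tile i+1 waits for epilogue i to finish.
double serialized(int tiles, double mma, double epi) {
    return tiles * (mma + epi);
}

// Two TMEM buffers (ping-pong): in steady state each tile costs the slower
// of the two phases; only the first MMA and the last epilogue are exposed.
double pingpong(int tiles, double mma, double epi) {
    if (tiles == 0) return 0.0;
    return mma + (tiles - 1) * std::max(mma, epi) + epi;
}
```

The saving is (tiles - 1) * min(mma, epi), so a CTA that processes a single tile gains nothing, and the benefit grows with the number of tile boundaries.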

Benchmark Results:

shape A -> 41.8us shape B -> 31.1us shape C -> 13.9us shape D -> 10.5us

- V17 (sub_v6_2_3.py) -

The largest single-step improvement in the v6 family. Three changes stack:

(1) Cluster of two CTAs with TMA multicast of A. The kernel launches with __cluster_dims__(2, 1, 1) and only the leader CTA issues the A load:

__device__ inline void tcgen05_3dtma_g2s_ab_multicast(
        int smem_ptr, const CUtensorMap* tmap,
        int m_off, int k_off, int mbar_addr,
        uint16_t cluster_mask) {
    asm volatile(
        "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
        "complete_tx::bytes.multicast::cluster.cta_group::1.L2::"
        "cache_hint [%0], [%1, {%2, %3, 0}], [%4], %5, %6;"
        :
        : "r"(smem_ptr), "l"(tmap), "r"(k_off), "r"(m_off),
          "r"(mbar_addr), "h"(cluster_mask), "l"(L2_CACHE_HINT)
    );
}

The full multicast-TMA instruction syntax (with .multicast::cluster listed as an option) is already shown in the basic GEMM section. The key new behavior versus the non-multicast form: when .multicast::cluster is present the instruction takes an extra ctaMask operand (the "h"(cluster_mask) above) — a 16-bit bitfield where bit i set means "deliver this TMA's data to CTA i in the cluster, into the same SMEM offset, and arrive on each of those CTAs' mbarriers". The TMA engine reads from HBM/L2 exactly once and fans out via the cluster interconnect (DSMEM path), so we pay HBM bandwidth for one CTA's worth of A and get both CTAs filled. The complete_tx arrival happens on every receiving CTA, which is why every receiving CTA's mbar_init arrival count has to be bumped from 1 to CLUSTER_SIZE — otherwise the consumer side under-counts arrivals and mbar_wait never returns.

2-CTA cluster + multicast TMA on A diagram:

                              ┌────────────────┐
                              │   GMEM (A)     │
                              └────────┬───────┘
                                       │  (one HBM read)
                                       ▼
                              ┌────────────────┐
                              │   L2 / TMA     │
                              │     engine     │
                              └───┬────────┬───┘
              .multicast::cluster │        │ .multicast::cluster
                                  ▼        ▼
                        ┌─────────────┐  ┌─────────────┐
                        │   CTA 0     │  │   CTA 1     │
                        │  A_smem[i]  │  │  A_smem[i]  │  (identical
                        │             │  │             │   SMEM offset)
                        │ mbar_tma[s] │  │ mbar_tma[s] │  (arrives on
                        │  arrive: 2  │  │  arrive: 2  │   each CTA)
                        └─────┬───────┘  └──────┬──────┘
                              │ load own B      │ load own B
                              ▼                 ▼
                        ┌─────────────┐  ┌─────────────┐
                        │ MMA on      │  │ MMA on      │
                        │ A × B(0,:)  │  │ A × B(1,:)  │
                        └─────────────┘  └─────────────┘

  HBM bandwidth on A is halved (one fetch, two consumers). Both
  consumers must bump mbar arrival count to CLUSTER_SIZE=2 so that the
  consumer-side mbar_wait sees the expected number of arrivals.
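The 16-bit ctaMask operand is a plain bitfield over cluster ranks. The sketch below builds the all-ones mask for a cluster and lists which CTAs a given mask delivers to. The number of receiving CTAs (the popcount of the mask) is CLUSTER_SIZE = 2 for the mask used here, matching the arrival counts described above. The helper names are illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mask that delivers to every CTA of a cluster of `cluster_size` CTAs:
// bit i set => rank i receives the data at the same SMEM offset and gets
// the complete_tx on its own mbarrier.
constexpr std::uint16_t full_cluster_mask(int cluster_size) {
    return static_cast<std::uint16_t>((1u << cluster_size) - 1u);
}

// Cluster ranks selected by a ctaMask, in increasing order.
std::vector<int> receiving_ranks(std::uint16_t mask) {
    std::vector<int> ranks;
    for (int r = 0; r < 16; ++r)
        if (mask & (1u << r)) ranks.push_back(r);
    return ranks;
}
```

For the 2-CTA cluster the mask is 0b11, so both CTA 0 and CTA 1 receive A.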

The corresponding tcgen05.commit uses the multicast scope so a single MMA completion releases the shared mbarrier on every CTA in the cluster (the non-multicast form is the one shown in the basic GEMM kernel):

tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;

Two paired mbarrier-side details follow from the multicast:

(a) The basic-GEMM mbarrier.arrive.expect_tx helper is scoped to .shared::cta, which can only arrive on a barrier local to this CTA. With multicast, the producing CTA must signal an mbarrier that the *receiving* CTAs see, so the helper is upgraded to .shared::cluster — the mbarrier address is interpreted as a DSMEM address that may live on a peer CTA, and the arrival routes through the cluster interconnect.

(b) On a group transition the producer warp emits the standard fence.mbarrier_init.release.cluster (already seen in the basic GEMM kernel setup) so the new mbarrier state is visible to peer CTAs before they read it.

(2) Pre-built per-group tensormaps in GroupDesc. The in-kernel tensormap.replace + cp_fenceproxy chain is gone. Tensormaps are encoded on the host once per group and shipped with the GroupDescs struct:

struct GroupDesc {
    uint64_t A_addr, B_addr, C_addr, sfa_addr, sfb_addr;
    int M;
    CUtensorMap tmap_a, tmap_b, tmap_c;
};

// host side
for (int i = 0; i < G; i++) {
    tma_3d_map_ab<M_TILE_SIZE, K_TILE_SIZE, ...>::init(
        ..., &gd.groups[i].tmap_a, A_ptrs[i], M[i], K);
    tma_3d_map_ab<N_TILE_SIZE, K_TILE_SIZE, ...>::init(
        ..., &gd.groups[i].tmap_b, B_ptrs[i], N,    K);
    tma_2d_map_c_init<M_TILE_SIZE, OUT_N_CHUNK>(
        ..., &gd.groups[i].tmap_c, C_ptrs[i], M[i], N);
}

Because the whole GroupDescs struct is a __grid_constant__ argument (from V14), shipping the tensormaps adds no separate H2D transfer to the launch. On the device side, group transitions are just &groups[cur_group].tmap_a: no replace, no fenceproxy, no GMEM round-trip. This is exactly the experiment that failed as a standalone change in submission_v5_precomp_tma.py. It works here because V14 already removed the larger H2D transfer that hurt that attempt, and, alongside the cluster + multicast gains, the remaining host-side encode cost is outweighed by the kernel-side savings from dropping tensormap.replace.
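The grid-constant payload stays small even with three tensormaps per group. This host-only sketch replaces CUtensorMap with a stand-in of the same size and alignment (128 opaque bytes, 64-byte aligned) so it compiles without cuda.h. FakeTensorMap is hypothetical; the GroupDesc fields follow the struct above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Stand-in with CUtensorMap's size and alignment: 16 x 64-bit words, alignas(64).
struct alignas(64) FakeTensorMap { std::uint64_t opaque[16]; };

constexpr int MAX_G = 8;

struct GroupDesc {
    std::uint64_t A_addr, B_addr, C_addr, sfa_addr, sfb_addr;
    int M;                                  // padded up to the 64-byte tmap alignment
    FakeTensorMap tmap_a, tmap_b, tmap_c;
};

struct GroupDescs { GroupDesc groups[MAX_G]; };

// Total kernel-parameter limit with CUDA 12.1+.
constexpr std::size_t PARAM_LIMIT = 32764;

static_assert(sizeof(FakeTensorMap) == 128, "CUtensorMap is 128 bytes");
static_assert(sizeof(GroupDescs) <= PARAM_LIMIT, "must fit in the kernel parameters");
```

Each GroupDesc comes to 448 bytes (44 bytes of addresses and M, padded to 64, plus 3 x 128), so eight groups take 3,584 bytes. That is well under the limit and still under the old 4KB one.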

(3) A dedicated tile-descriptor warp. With NUM_WARPS=7 the new warp precomputes (group, m_off, n_off, sfa_addr, sfb_addr, M) for each upcoming tile and publishes them through five SMEM slots gated by an mbar_addr_tile_ready mbarrier. This kills the redundant binary-search-for-group work that every other warp used to do, and overlaps it with the previous tile's MMA.
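The per-tile work that the descriptor warp takes over is essentially a binary search over a prefix sum of per-group tile counts, followed by splitting the local index into M and N offsets. This is a host-side sketch under stated assumptions: 128 x 128 tiles, N-fastest tile order within a group, and hypothetical names (tile_prefix, resolve, TileDesc).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative tile sizes (assumptions).
constexpr int M_TILE = 128;
constexpr int N_TILE = 128;

struct TileDesc { int group, m_off, n_off, M; };

// prefix[g] = number of tiles in groups [0, g); computed once per launch.
std::vector<int> tile_prefix(const std::vector<int>& m, int n) {
    std::vector<int> prefix(m.size() + 1, 0);
    for (std::size_t g = 0; g < m.size(); ++g)
        prefix[g + 1] = prefix[g] + ((m[g] + M_TILE - 1) / M_TILE) * (n / N_TILE);
    return prefix;
}

// Find the owning group by binary search, then split the group-local index
// into (m tile, n tile), with n varying fastest.
TileDesc resolve(int tile, const std::vector<int>& prefix,
                 const std::vector<int>& m, int n) {
    assert(tile >= 0 && tile < prefix.back());
    int g = int(std::upper_bound(prefix.begin(), prefix.end(), tile) - prefix.begin()) - 1;
    int local = tile - prefix[g];
    int n_tiles = n / N_TILE;
    return { g, (local / n_tiles) * M_TILE, (local % n_tiles) * N_TILE, m[g] };
}
```

For shape D (m = {128, 384}, n = 4096) the prefix is {0, 32, 128}, so tile 65 resolves to group 1 at m_off = 128, n_off = 128. Doing this once in a dedicated warp and publishing the result saves every other warp from repeating it.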

Benchmark Results:

shape A -> 38.3us shape B -> 24.9us shape C -> 12.8us shape D -> 9.1us

- Failed Experiment: sub_v6_2_descoverlap.py -

Tries the same dedicated tile-descriptor producer warp from V17 but layered on top of V16, which still has the tensormap.replace machinery. The tile-descriptor warp now does three tmap_update + three tmap_fence_proxy per tile, writing to per-stage GMEM tmap slots g_A_tmap + tile_stage. The intent was to pipeline descriptor preparation across tiles. But the GPU-scope tensormap.cp_fenceproxy ... gpu is a heavy serializer at L2/GPU scope, and doing three of them per tile in the steady state turns out to be much worse than doing three per *group transition* — which is what V16 did. The benchmark explodes to 59.3 / 42.7 / 21.5 / 16.7us. V17 only avoided this landmine because it had simultaneously eliminated tensormap.replace entirely.

- Failed Experiment: sub_v6_2_multicast.py (hang) -

A standalone multicast attempt without the supporting mbar arrival-count fixes and without the cluster-scope commit. With multicast TMA delivering to both CTAs but the mbarriers still initialized for single-CTA arrival counts, the consumer's mbar_wait never sees enough arrivals and the kernel hangs forever. This is the cautionary lesson that the multicast / arrive-count / commit-scope changes need to land together.

- Failed Experiment: sub_v6_3.py (correctness errors) -

Adds a transposed-output path. TMA loads stay the same, but the MMA swaps A and B, so TMEM effectively holds C^T = B @ A^T. The epilogue then stores into a transposed C, using tensormap.replace.tile.global_dim and tensormap.replace.tile.global_stride to patch the shape and per-dim byte stride (the full set of replaceable fields, including global_dim, global_address, global_stride and element_stride, was listed in V6 / submission_v4_2.py). A transposed C needs both a different shape *and* a different stride, which is why both fields get patched here. Two mismatches remain and cause the errors. First, the SF TMEM layout doesn't get the matching swap: tmem_sfa_ptrs still uses TMEM_SF_COLS = TD_MMA_N for its stride math while the MMA passes sfa/sfb swapped. Second, the epilogue's tcgen05_ld<16, 256, ...> pattern assumes results are laid out along the original M-axis.

- Failed Experiment: sub_v6_splitk.py (hang) -

Split-K experiment with SPLIT_K=4, K_TILE_SIZE=64, PIPE_STAGES=15, and a 4-way TMEM circular buffer. The epilogue is replaced with an atomicAdd into a pre-allocated reduction buffer. The hang comes from the producer's glob_k_off continuing absolutely across shards while the consumer expects per-shard phase resets, so the mbar_wait calls for subsequent shards end up waiting on phases that have already flipped.
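The split-K disagreement can be reproduced with the stage/phase rule alone. Here the producer keeps one absolute step counter across shards, while the consumer restarts at zero for each shard. With shape A's K = 7168, SPLIT_K = 4 and K_TILE_SIZE = 64, each shard is 28 steps. This is an illustrative model; the function names are hypothetical.

```cpp
#include <cassert>

constexpr int PIPE_STAGES = 15;  // from the split-K experiment

struct Slot { int stage; int phase; };

// Stage and full-barrier phase for K step `it`; the phase flips every time
// the ring wraps around PIPE_STAGES stages.
Slot slot_for(int it) { return { it % PIPE_STAGES, (it / PIPE_STAGES) & 1 }; }

// Producer: one absolute counter that keeps running across shards.
Slot producer(int shard, int step, int steps_per_shard) {
    return slot_for(shard * steps_per_shard + step);
}

// Consumer: restarts its counter at the beginning of every shard.
Slot consumer(int /*shard*/, int step) { return slot_for(step); }
```

At the first step of shard 1 the producer fills stage 13 (phase 1) while the consumer waits on stage 0 (phase 0), and neither side ever catches up. The two schemes only agree when the steps per shard are a multiple of 2 * PIPE_STAGES = 30.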

- Failed Experiment: sub_v6_tmem_circ.py (hang) -

A 4-deep TMEM ring without split-K. mbarrier phase bits only track a single toggle, so the producer-consumer phase agreement for a 4-cycle ring is fragile. The init_tmem_fill > 3 warmup gate introduces an off-by-one in when the wait actually starts, and the producer and consumer end up disagreeing about which phase they're in.

- V18 (sub_v6_3_test.py) -

A small follow-up to v6_3 that fixes the transposed-output epilogue. It switches to a different tcgen05.ld shape (the full instruction with all .shape / .num options is shown in the basic GEMM section). The earlier epilogues used .16x256b with .x8/.x16/.x32 (16 TMEM lanes x 256 bits per "load atom"). The transposed-output epilogue uses .32x32b with .x32: 32 TMEM lanes x 32 bits per atom, repeated 32 times along the columns, so each thread receives 32 consecutive 32-bit values from its own TMEM lane. Because the accumulator is transposed, those 32 values form a tall column of C, which is exactly what column-major SMEM staging wants. The two shapes correspond to very different physical TMEM access patterns. Choosing the wrong one forces a transpose in registers (extra shuffles or an extra SMEM round-trip), which is the cost v6_3 was paying before this fix.

Each thread now accumulates 32 register values that form a column of OUT_CHUNK halves, then stores them column-major into SMEM at c_smem + i * TD_SMEM_N + threadIdx.x. SF tmem indexing is also fixed (int sfa_tmem = tmem_addr_sfa + 4 * sub_k_iter + (m_off % 128) / 32). Bumps CTAS_PER_SM=2 (296 total CTAs).
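The store pattern can be checked on the CPU. The sketch assumes the transposed accumulator maps TMEM lane t to output column n = t, with a thread's 32 registers holding 32 consecutive m values. Writing c_smem[i * TD_SMEM_N + t] then places register i of thread t at row i, column t of a row-major C tile. TD_SMEM_N = 128 and OUT_CHUNK = 32 are assumed sizes; float stands in for the half values.

```cpp
#include <cassert>
#include <vector>

constexpr int TD_SMEM_N = 128;  // assumed: epilogue threads / TMEM lanes across N
constexpr int OUT_CHUNK = 32;   // registers per thread from one .32x32b.x32 load

// regs[t][i] = register i of thread t = C[m0 + i][t] (transposed accumulator).
// Thread t writes c_smem[i * TD_SMEM_N + t], i.e. it fills column t of a
// row-major OUT_CHUNK x TD_SMEM_N tile of C.
std::vector<float> stage_to_smem(const std::vector<std::vector<float>>& regs) {
    std::vector<float> c_smem(OUT_CHUNK * TD_SMEM_N);
    for (int t = 0; t < TD_SMEM_N; ++t)          // one iteration per thread
        for (int i = 0; i < OUT_CHUNK; ++i)
            c_smem[i * TD_SMEM_N + t] = regs[t][i];
    return c_smem;
}
```

For a fixed register i, the 32 lanes of a warp write 32 consecutive elements. Each store instruction is therefore contiguous, and no register-level transpose is needed.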

V18 passes correctness, but it's not on the cluster-multicast path (it's __cluster_dims__(1, 1, 1)) and it still pays the runtime tensormap.replace cost on group boundaries, so the numbers are worse than V17. It's kept as a working baseline for the transpose layout experiments.

Benchmark Results:

shape A -> 46.3us shape B -> 35.8us shape C -> 14.9us shape D -> 12.5us

- V19 (sub_v7.py) -

V19 is a flexibility checkpoint, not a perf win. The intent is to put the runtime tensormap.replace machinery back so that the kernel can support per-group varying N and K (V17's pre-encoded tensormap path bakes N and K into the host-side encode call). To make the comparison clean V19 strips the multicast path and runs with __cluster_dims__(1, 1, 1). It also adds a TRANSPOSE template parameter that flips A and B in the MMA instruction.

The descriptor patching is per-CTA, per-stage, with the heavyweight GPU-scope fence:

"tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;"
"tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;"

published into a per-CTA GMEM tensormap cache:

CUtensorMap* d_tmaps;
cudaMalloc(&d_tmaps, 3 * TILE_DESC_PIPE_STAGES * NUM_CTAS * sizeof(CUtensorMap));
// per-CTA: g_A_tmap = d_tmaps + 3 * TILE_DESC_PIPE_STAGES * blockIdx.x

The regression vs. V17 (roughly 2-3us on every shape) is the cost of re-introducing in-kernel descriptor patching plus the loss of TMA multicast on A. Compare this with the descoverlap regression. V19 patches descriptors once per group transition and loses 2-3us, part of which is the missing multicast. descoverlap did three patches per tile and lost 6-18us relative to its V16 base.

Benchmark Results:

shape A -> 41.3us shape B -> 27.8us shape C -> 14.8us shape D -> 12.4us

- Failed Experiments: sub_v8.py, sub_v8_1.py -

sub_v8.py (fails correctness). Re-adds the 2-CTA cluster on top of V19's transposed tensormap.replace machinery, plus a 2SM tcgen05 MMA (tcgen05.mma.sp.async.cta_group::2.kind::mxf4nvf4). This is a fundamentally different multicast mode from V17. cta_group::2 (covered in detail in the Dual GEMM V2 section) means the MMA itself is 2-CTA: each CTA in the cluster owns half of the TMEM result, and the tensor cores collaborate across the SM boundary via the cluster interconnect. The effective N of the instruction descriptor doubles (make_instr_desc<TD_MMA_N * CTA_GROUP, TD_MMA_M>()), and the TMEM result is striped across both CTAs: CTA0 holds the left half along N, CTA1 holds the right half. Producer-consumer synchronization across the cluster uses the same barrier.cluster.arrive.release.aligned / barrier.cluster.wait.acquire.aligned pair shown in Dual GEMM V2.

The arrive-count on TMA producer mbars is bumped to CLUSTER_SIZE to account for the multicast arrivals, and only the leader CTA issues MMAs. There are two bugs: (1) a precedence issue in mbar_addr_tma + stage * 8 & 0xFEFFFFFF (the & binds wider than intended and corrupts the address), and (2) the cta_rank indexing in the epilogue's TMEM load reads from CTA1's TMEM half even though the result is in CTA0's TMEM. Either alone would cause incorrect output.
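The precedence trap in bug (1) is plain C/C++ operator precedence: + and * bind tighter than &, so the mask applies to the whole sum. The values below are arbitrary and chosen so that the sum carries into bit 24. They are not the kernel's real addresses, and the intended expression is not spelled out in the text, so the two explicit groupings are only shown for contrast. Compilers typically warn about the unparenthesized form (-Wparentheses).

```cpp
#include <cassert>
#include <cstdint>

constexpr std::uint32_t MASK = 0xFEFFFFFFu;  // clears bit 24

// As written in sub_v8.py: parses as (base + stage * 8) & MASK.
constexpr std::uint32_t as_written(std::uint32_t base, std::uint32_t stage) {
    return base + stage * 8 & MASK;
}

// Explicit groupings, for comparison.
constexpr std::uint32_t mask_whole_sum(std::uint32_t base, std::uint32_t stage) {
    return (base + stage * 8) & MASK;
}
constexpr std::uint32_t mask_base_only(std::uint32_t base, std::uint32_t stage) {
    return (base & MASK) + stage * 8;
}
```

With base = 0x00FFFFF8 and stage = 1, the expression as written yields 0, while masking only the base yields 0x01000000. An innocent-looking address expression can therefore silently point at the wrong barrier.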

sub_v8_1.py (fails correctness). Switches the scale-factor TMA from 1D bulk to a true 3D tensor descriptor:

"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
"complete_tx::bytes.cta_group::%0.L2::cache_hint "
"[%1], [%2, {%3, %4, %5}], [%6], %7;"

with a corresponding tma_3d_map_sf encoder using rank-3 shape {256, 2 * (k_dim_gmem_sf / 4), mn_dim_gmem / 128}. The kernel is checked in with a debug stub still in place:

for (int chunk = 0; chunk < 0; chunk++)  // !!!
    /* TMEM->SMEM->TMA store loop */

so the entire epilogue is short-circuited and no results ever get written. Once that stub is fixed there are likely additional issues in the new 3D SF descriptor's stride/box layout, but the proximate failure is just the dead loop.

- V20 (v0.py) -

V20 is a recovery: it drops the failed 2-CTA / SF-3D-TMA experiments and keeps V19's tensormap.replace machinery, but pairs it with TMA multicast on A:

tcgen05_3dtma_g2s_ab_multicast<1, CLUSTER_MASK>(
    a_smem_stage_ptr, g_A_tmap + tile_stage,
    m_off, k_off, mbar_addr, CLUSTER_MASK);
// ... B is non-multicast
tcgen05_commit_multicast<CLUSTER_MASK>(...);

It also adds two specialization knobs:

  NK_VAR=true        -> generalize to per-group varying N and K
                        by calling tmap_update_dim<1>,
                        tmap_update_dim<2>, tmap_update_stride<0>
                        on each group transition.
  SINGLE_WAVE=true   -> when total_tiles < 148, elide the
                        back-half of the epilogue mbar wait
                        because there is no wave-2.

Four kernel template instantiations are compiled, and the host wrapper picks one based on (NK_VAR, SINGLE_WAVE). This is the state of the kernel right before the next big simplification.
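Mapping two runtime flags onto four compile-time instantiations is a standard host-side pattern. In this sketch launch_variant stands in for the real kernel launch and returns an id so that the dispatch can be checked. The function names, and the single-wave test written as the "total_tiles < 148" rule above, are illustrative rather than the author's wrapper.

```cpp
#include <cassert>

// Stand-in for the four kernel instantiations (kernel body elided);
// returns an id identifying which instantiation was chosen.
template <bool NK_VAR, bool SINGLE_WAVE>
int launch_variant() { return (NK_VAR ? 2 : 0) + (SINGLE_WAVE ? 1 : 0); }

// Runtime flags -> compile-time template arguments.
int dispatch(bool nk_var, bool single_wave) {
    if (nk_var)
        return single_wave ? launch_variant<true, true>() : launch_variant<true, false>();
    return single_wave ? launch_variant<false, true>() : launch_variant<false, false>();
}

// SINGLE_WAVE criterion as described above: fewer tiles than SMs.
bool is_single_wave(int total_tiles, int num_sms = 148) { return total_tiles < num_sms; }
```

Because each flag is a template parameter, the code paths it disables (the NK_VAR tmap updates, the second-wave epilogue wait) are compiled out of the other instantiations rather than branched around at runtime.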

Benchmark Results:

shape A -> 40.6us shape B -> 26.4us shape C -> 14.1us shape D -> 10.6us

- V21 (v1.py) -

V21 throws out the tensormap.replace machinery entirely (again) and goes back to per-group pre-encoded tensormaps embedded in GroupDesc, like V17, but this time preserving the NK_VAR capability — the per-group encode happens on the host with each group's actual M, N, K, so variable N/K is supported naturally without needing in-kernel descriptor patching:

// host side
for (int i = 0; i < G; i++) {
    tma_3d_map_ab<M_TILE_SIZE, K_TILE_SIZE, ...>::init(
        ..., &gd.groups[i].tmap_a, A_ptrs[i], M[i], K_sizes[i]);
    tma_3d_map_ab<N_TILE_SIZE, K_TILE_SIZE, ...>::init(
        ..., &gd.groups[i].tmap_b, B_ptrs[i], N_sizes[i], K_sizes[i]);
    tma_2d_map_c_init<...>(..., &gd.groups[i].tmap_c,
                          C_ptrs[i], M[i], N_sizes[i]);
}

// device side (no SMEM tmap, no fenceproxy)
tcgen05_3dtma_g2s_ab_multicast<1, CLUSTER_MASK>(
    a_smem_stage_ptr, &groups[cur_group].tmap_a, ...);
tcgen05_3dtma_g2s_ab<1>(
    b_smem_stage_ptr, &groups[cur_group].tmap_b, ...);
tcgen05_2dtma_s2g_c(
    ..., &groups[cur_group].tmap_c, ...);

This removes the local_A_tmap / local_B_tmap / local_C_tmap SMEM copies, the tmap_fence_proxy, the d_tmaps GMEM cache, and the entire NK_VAR tmap_update_dim<...> block. The producer warp issues TMA loads directly against the constant-memory-resident descriptor.

V21 matches V17's runtime (38.2us vs. 38.3us on shape A) while preserving NK_VAR, and it is the base that V22 builds on.

Benchmark Results: (same as v2.py below)

- V22 (v2.py) -

Adds a third TMEM result buffer, deepening the MMA<-epilogue pipeline from 2-buffered to 3-buffered:

tcgen05_alloc_tmem<1>(tmem_addr_ptr, 512);   // was TD_MMA_N * 2 * 2

const int tmem_result_ptrs[3] = {
    tmem_addr_base,
    tmem_addr_base + TD_MMA_N,
    tmem_addr_base + 2 * TD_MMA_N
};
const int tmem_sfa_ptrs[3] = {
    tmem_result_ptrs[2] + TD_MMA_N,
    tmem_result_ptrs[2] + TD_MMA_N + 16,
    tmem_result_ptrs[2] + TD_MMA_N + 32
};
const int tmem_sfb_ptrs[3] = {
    tmem_sfa_ptrs[2] + 16,
    tmem_sfa_ptrs[2] + 32,
    tmem_sfa_ptrs[2] + 48
};

// rotation: tmem_buf = (tmem_buf + 1) % 3, was tmem_buf ^= 1
// epi_phase[3], epi_done_phase[3]

// primed flag replaces tile_idx != blockIdx.x as the guard for
// waiting on mbar_addr_epi_done

With three TMEM result slots, the MMA warp can keep issuing tcgen05.mma into a fresh slot while two prior tiles are still being consumed by the epilogue. The aim is to hide the TMEM->reg->half->SMEM->TMA store chain on wave-bound shapes; at the reported precision, the results are the same as V21's.

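The TMEM allocation stays at 512 columns (TD_MMA_N * 2 * 2 = 512 before, now hard-coded), but the columns are repacked: 3 result buffers of TD_MMA_N = 128 columns plus three SFA and three SFB slots of 16 columns each, for 480 columns in use. The TMEM budget is therefore at its limit, and further deepening does not fit at TD_MMA_N = 128: a fourth buffer would need 4 x 128 + 8 x 16 = 640 columns.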
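The column offsets in the snippet above can be checked at compile time. This sketch reproduces them under the assumption, inferred from the +16 spacing, that each SFA or SFB slot is 16 columns wide. It also shows why the allocation is 512: tcgen05.alloc takes a power-of-two column count between 32 and 512. The helper names are illustrative.

```cpp
#include <cassert>

constexpr int TD_MMA_N = 128;
constexpr int SF_COLS  = 16;  // inferred from the +16 spacing of the SF pointers

constexpr int result_col(int buf) { return buf * TD_MMA_N; }
constexpr int sfa_col(int buf)    { return 3 * TD_MMA_N + buf * SF_COLS; }
constexpr int sfb_col(int buf)    { return sfa_col(2) + (buf + 1) * SF_COLS; }
constexpr int used_cols()         { return sfb_col(2) + SF_COLS; }

// Round a column count up to the power of two tcgen05.alloc accepts (min 32).
constexpr int alloc_cols(int used) {
    int c = 32;
    while (c < used) c *= 2;
    return c;
}

static_assert(result_col(2) + TD_MMA_N == sfa_col(0), "SF region follows the results");
static_assert(used_cols() == 480, "3 x 128 result cols + 6 x 16 SF cols");
static_assert(alloc_cols(used_cols()) == 512, "rounds up to the full TMEM width");
```

The computed offsets match the pointers in the V22 snippet: SFA at 384/400/416 and SFB at 432/448/464.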

Benchmark Results:

shape A -> 38.2us shape B -> 25.2us shape C -> 12.8us shape D -> 9.1us

- V23 (v3.py) - Final -

A host-side launch-overhead micro-optimization. The previous versions extract A/B/C raw pointers from the per-group std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> on the C++ side, which costs a pybind11/ATen ABI hop and a refcount check per call. V23 moves those data_ptr() calls into Python, gathering each operand's pointers with one list comprehension, and passes three uint64 CPU tensors:

# Python wrapper
A_ptrs = torch.tensor([a.data_ptr() for (a, b, c) in abc_tensors],
                      dtype=torch.uint64, device='cpu')
B_ptrs = torch.tensor([b.data_ptr() for (a, b, c) in abc_tensors],
                      dtype=torch.uint64, device='cpu')
C_ptrs = torch.tensor([c.data_ptr() for (a, b, c) in abc_tensors],
                      dtype=torch.uint64, device='cpu')
nvfp4_group_gemm(A_ptrs, B_ptrs, C_ptrs, sf_tensors,
                 prob_sizes, N, K, G)
// C++ entry
void nvfp4_group_gemm(torch::Tensor A_ptrs, torch::Tensor B_ptrs,
                      torch::Tensor C_ptrs, ...) {
    uint64_t* A_ptrs_data = A_ptrs.data_ptr<uint64_t>();
    // ... encode tmaps from raw pointers
    tma_3d_map_ab<...>::init(..., &gd.groups[i].tmap_a,
        reinterpret_cast<void*>(A_ptrs_data[i]), M[i], K);
}

Note that V5 (submission_v4_1.py) had measured this kind of change as net-zero. The difference here is that by V23 the kernel itself is fast enough (~9us on shape D) that sub-microsecond host overhead is a meaningful fraction of the total. The kernel body, PTX, TMEM layout, mbarrier pipeline, and template instantiations are identical to V22, so any change comes entirely from launch-overhead reduction. At the 0.1us resolution of the results below, the numbers are the same as V22's.

Benchmark Results:

shape A -> 38.2us shape B -> 25.2us shape C -> 12.8us shape D -> 9.1us

Group GEMM Summary

Optimization Summary (shapes A/B/C/D defined above the V9 entry):

  ┌─────────┬─────────────────────────────────────────────────────┬────────────────────────────────────────────────────────────────┐
  │ Version │                     Technique                       │                          Primary Gain                          │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v0      │ PyTorch baseline (per-GEMM torch._scaled_mm loop)   │ Correctness baseline (multi-ms runtime)                        │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v1      │ Single PTX kernel covering all groups               │ Eliminates G separate kernel launches (~200x vs baseline)      │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v2      │ Static persistent kernel (work-tile loop per CTA)   │ Foundation for cross-tile pipelining (small initial change)    │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v3      │ Persistent host/device buffers across calls         │ Removes allocator from the hot path (~50us)                    │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v4      │ Fixed-N/K specialization for benchmark shapes       │ Less metadata to track and transfer per call                   │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v5      │ std::vector<tuple> instead of torch.Tensor for ptrs │ Equivalent perf (pybind11 unpacks anyway)                      │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v6      │ tensormap.replace in-kernel + per-CTA GMEM cache    │ Tensormap setup no longer scales linearly in G (~80us)         │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v7      │ TMA store of C (cp.async.bulk.tensor SMEM->GMEM)    │ Biggest single-step win (-40us) — epilogue store accelerated   │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v8      │ Cluster Launch Control (CLC) dynamic work stealing  │ REGRESSION at these shapes: mechanism overhead > load balance  │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v9      │ Chunked double-buffered epilogue + pinned host mem  │ Overlap TMEM->SMEM conversion with prior TMA store (~5us)      │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v10     │ Defer C tensormap patch to epilogue path            │ Hide tensormap fence latency behind A/B loads + MMA            │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v11     │ Cross-tile TMA pipelining (monotonic glob_k_off)    │ Slight regression standalone — pays off when combined later    │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v12     │ Drop CLC, larger N_TILE (256), fewer pipe stages    │ Mixed: helps MMA-bound shapes, hurts small ones                │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v13     │ TMEM ping-pong (2 result halves)                    │ MMA(tile N+1) overlaps with epilogue(tile N)                   │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v14     │ GroupDescs as __grid_constant__ kernel arg          │ Eliminates cudaMemcpyAsync + host mallocs for descriptors      │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v15     │ Cross-tile TMA pipelining (on top of v14)           │ Modest win + uncovers SMEM/TMA proxy race (RAW via TMA store)  │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v16     │ TMEM ping-pong + cross-tile TMA pipelining          │ All three phases (TMA, MMA, epi) overlap across tile bounds    │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v17     │ 2-CTA cluster + TMA multicast of A + pre-encoded    │ Largest stack-on win in v6 family (~-3us on a 31us baseline)   │
  │         │ per-group tensormaps + dedicated tile-desc warp     │                                                                │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v18     │ Transposed-output epilogue (.32x32b tcgen05.ld)     │ Working baseline for transpose layout — not on multicast path  │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v19     │ Re-add tensormap.replace for NK_VAR flexibility     │ Slight regression (~3us slower than v17): cost of generality   │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v20     │ Re-add multicast TMA of A on top of v19             │ Recovers most of v17's gain while keeping NK_VAR               │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v21     │ Pre-encoded per-group tmaps with per-group M, N, K  │ Same speed as v17 *with* NK_VAR support                        │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v22     │ 3-deep TMEM result buffer                           │ Hides full TMEM->reg->half->SMEM->TMA chain on wave-bound work │
  ├─────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v23     │ Pre-extract data_ptr()s in Python                   │ Skip per-tuple pybind11 unpacking on the C++ side (sub-us)     │
  │ (final) │                                                     │                                                                │
  └─────────┴─────────────────────────────────────────────────────┴────────────────────────────────────────────────────────────────┘

Broad Lessons:

The biggest single-step wins were structural (TMA store of C in V7, and the V17 stack of cluster multicast + pre-encoded tensormaps + tile-descriptor warp) and host-side (GroupDescs as a __grid_constant__ argument in V14). The fine-grained pipelining work (TMEM ping-pong, cross-tile TMA pipelining, chunked epilogue) only made progress when stacked together; in isolation, each typically regressed or gained less than a microsecond.

Most correctness traps in this kernel were arrival-count or phase-bit invariants that had to be updated alongside the producer-consumer arity. Adding TMA multicast without bumping the mbarrier arrival counts to CLUSTER_SIZE causes deterministic hangs; adding an extra epilogue warp set without re-tallying named-barrier participant counts causes races; and deepening the TMEM ping-pong from 2 to 3 buffers requires 3-state phase bookkeeping that an mbarrier's single phase bit cannot represent on its own. The kernel's complexity peaks from V17 onward, and the broad lesson is that every change to "how many producers/consumers" requires re-deriving every related sync count and phase: there is no PTX-level safety net.
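These invariants are easier to see in a toy model. The following is a host-side Python sketch, not PTX: `ToyMBarrier` and `stage_and_parity` are hypothetical names, the model tracks only arrival counts and the single phase bit (no transaction-byte accounting), and the ring uses the common "buffer index + parity" convention, which is not necessarily the kernel's exact bookkeeping. It shows one way an undercounted barrier can hang (the parity flips twice and aliases back to its starting value) and why a 3-deep ring needs a buffer index in addition to the parity bit.

```python
from dataclasses import dataclass


@dataclass
class ToyMBarrier:
    """Toy host-side model of an mbarrier (arrival counting only, no tx bytes)."""
    expected: int      # arrival count fixed at init, e.g. CLUSTER_SIZE
    pending: int = 0   # arrivals still outstanding in the current phase
    phase: int = 0     # single parity bit; flips each time a phase completes

    def __post_init__(self) -> None:
        self.pending = self.expected

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1               # phase completes: flip parity
            self.pending = self.expected  # re-arm for the next phase

    def try_wait(self, parity: int) -> bool:
        # Mirrors mbarrier.try_wait.parity: only the parity bit is visible,
        # so completing two phases looks the same as completing none.
        return self.phase != parity


def stage_and_parity(it: int, depth: int) -> tuple[int, int]:
    """Map a monotonic iteration counter to (buffer index, parity to wait on).

    Each buffer's barrier completes once per trip around the ring, so the
    parity for a given buffer flips every `depth` iterations.
    """
    return it % depth, (it // depth) & 1


CLUSTER_SIZE = 2

# Count matches the number of arriving CTAs: phase 0 completes once.
ok = ToyMBarrier(expected=CLUSTER_SIZE)
for _ in range(CLUSTER_SIZE):
    ok.arrive()
print(ok.try_wait(0))  # True

# Count left at 1: each arrival completes a whole phase, parity flips twice,
# and in this model a waiter polling parity 0 never sees completion.
under = ToyMBarrier(expected=1)
for _ in range(CLUSTER_SIZE):
    under.arrive()
print(under.try_wait(0))  # False

# 3-deep TMEM ring: (buffer, parity) for the first seven tiles.
print([stage_and_parity(i, 3) for i in range(7)])
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0)]
```

Note that buffer 0 is waited on with parity 0, then 1, then 0 again: the parity alone cannot say which buffer is current, which is why the index has to be tracked separately.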

A second class of lessons concerns memory proxies: the epilogue writes SMEM through the generic proxy, while the subsequent TMA store reads that SMEM through the async proxy and may see stale data. Debugging this in V15 took a long time because the PTX docs don't clearly flag it as a hazard (see 07_PTX_lessons.txt). Once you know to look for it, the fix is a single fence.proxy.async.shared::cta before each TMA store; until then, it shows up as a non-deterministic correctness error that survives many failed reproduction attempts.

The host-side path matters more than expected for small problem shapes. By V23 the kernel runs at ~9us on shape D, so even sub-microsecond host overhead (pybind11 tuple unpacking, descriptor encoding, malloc/free) is measurable. Some optimizations that measured as net-zero earlier in the lineage (e.g. V5: std::vector vs torch.Tensor for pointer passing) would have shown wins if measured at V23's runtime; they were simply tried too early, while larger costs still dominated.
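A minimal sketch of V23-style pre-extraction, plus the arithmetic behind "measurable". It assumes the `abc_tensors` list of `(a, b, c)` tuples from the V0 baseline and a C++ entry point that accepts a flat list of integer addresses; that signature, `flatten_ptrs`, `host_share`, and `_StandIn` are hypothetical. `torch.Tensor.data_ptr()` is a real method that returns the address as an int; the stand-in class only mimics it so the snippet runs without PyTorch. The 0.5us figure is illustrative, not a measurement.

```python
from typing import Protocol, Sequence


class HasDataPtr(Protocol):
    def data_ptr(self) -> int: ...


def flatten_ptrs(abc_tensors: Sequence[tuple[HasDataPtr, ...]]) -> list[int]:
    """Pre-extract raw addresses so the C++ side receives plain ints.

    Order is a0, b0, c0, a1, b1, c1, ... (one (a, b, c) tuple per group).
    """
    return [t.data_ptr() for abc in abc_tensors for t in abc]


def host_share(host_us: float, kernel_us: float) -> float:
    """Fraction of measured time spent on host work that runs serially."""
    return host_us / (host_us + kernel_us)


class _StandIn:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, addr: int) -> None:
        self._addr = addr

    def data_ptr(self) -> int:
        return self._addr


groups = [
    (_StandIn(0x1000), _StandIn(0x2000), _StandIn(0x3000)),
    (_StandIn(0x4000), _StandIn(0x5000), _StandIn(0x6000)),
]
print(flatten_ptrs(groups))  # [4096, 8192, 12288, 16384, 20480, 24576]

# A hypothetical 0.5us of serial host work on a ~9us kernel:
print(round(host_share(0.5, 9.0), 3))  # 0.053
```

At that scale, half a microsecond of serial host work is roughly 5% of the measured time, well above typical run-to-run noise for a single benchmark.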

The pre-encoded-per-group-tensormap approach (v5_precomp_tma, then V17, then V21) is a good illustration of the same idea first failing and then succeeding: in isolation, the host-side encoding cost dominated, but once the kernel itself was fast enough and the host overhead was amortized by other improvements, the same change became a net win.

Stripped of the dead branches, the final winning kernel stacks the following improvements. Each was developed in isolation, but several only paid off in combination:

  1. Persistent buffers across calls (V3): no allocator on the
     hot path.
  2. Single kernel launch covering all GEMMs (V1), with
     fixed-N/K specialization for benchmark shapes (V4).
  3. Static persistent scheduling with 148 CTAs (V12, V14).
  4. Per-CTA TMA descriptor patching via `tensormap.replace`
     scaled poorly; pre-encoded per-group tensormaps embedded
     in a `__grid_constant__` GroupDescs struct scale to any G
     for free (V14 + V17/V21).
  5. TMA store of C via `cp.async.bulk.tensor` (V7) with
     OUT_N_CHUNK-chunked double-buffered SMEM staging using
     `cp.async.bulk.wait_group.read N` (V9) to overlap C
     conversion and TMA store.
  6. Cross-tile TMA pipelining via a monotonic glob_k_off
     counter and a first-tile prefetch (V11/V15).
  7. Multi-buffered TMEM results (2-deep in V13/V16, 3-deep in
     V22) so MMA on tile N+k overlaps epilogue on tile N.
  8. Two-CTA cluster with TMA multicast on A and
     cluster-scope `tcgen05.commit` (V17).
  9. Dedicated tile-descriptor producer warp that resolves
     (group, m_off, n_off, sfa, sfb, M) once per tile and
     publishes via SMEM (V17).
  10. Host-side launch path: cudaMallocHost + pinned
      `cudaMemcpyAsync` (V9), then eliminated entirely once
      GroupDescs became a `__grid_constant__` argument (V14).
  11. Pre-extract data_ptr()s in Python to skip the per-tuple
      pybind11 unpacking on the C++ side (V23).
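Items 3 and 9 together boil down to mapping a flat tile index onto a (group, m_off, n_off) triple and handing tiles to a fixed set of persistent CTAs. The Python sketch below shows that mapping on the first benchmark shape. Assumptions: M_TILE = 128 (not given in this section), N_TILE = 256 (V12), N-fastest tile order within each group, and plain round-robin assignment that ignores the 2-CTA cluster pairing of V17. The function names are hypothetical, and the real tile-descriptor warp also resolves sfa/sfb and M, which are omitted here.

```python
import bisect

NUM_CTAS = 148  # persistent grid size used in the article (one CTA per SM)
M_TILE = 128    # assumed; not stated in this section
N_TILE = 256    # the larger N_TILE adopted in V12


def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def build_prefix(ms: list[int], ns: list[int]) -> list[int]:
    """Exclusive prefix sum of output-tile counts, one entry per group."""
    prefix = [0]
    for m, n in zip(ms, ns):
        prefix.append(prefix[-1] + cdiv(m, M_TILE) * cdiv(n, N_TILE))
    return prefix


def decode_tile(t: int, prefix: list[int], ns: list[int]) -> tuple[int, int, int]:
    """Resolve a flat tile index into (group, m_off, n_off), N-fastest order."""
    g = bisect.bisect_right(prefix, t) - 1
    local = t - prefix[g]
    n_tiles = cdiv(ns[g], N_TILE)
    return g, (local // n_tiles) * M_TILE, (local % n_tiles) * N_TILE


def tiles_for_cta(cta: int, total: int) -> range:
    """Static round-robin: CTA c owns tiles c, c + NUM_CTAS, c + 2*NUM_CTAS, ..."""
    return range(cta, total, NUM_CTAS)


# First benchmark shape: g = 8, n = 4096 for every group.
ms = [80, 176, 128, 72, 64, 248, 96, 160]
ns = [4096] * 8
prefix = build_prefix(ms, ns)
total = prefix[-1]
print(total)  # 176
print([decode_tile(t, prefix, ns) for t in tiles_for_cta(0, total)])
# [(0, 0, 0), (7, 0, 1024)]
print(sum(len(tiles_for_cta(c, total)) == 2 for c in range(NUM_CTAS)))  # 28
```

Under these assumptions, 176 tiles land on 148 CTAs, so 28 CTAs process a second tile while the other 120 finish after one.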

The biggest single-step wins, in order: TMA store (V7, -40us), tensormap.replace removal + multicast + tile-descriptor warp (V17, -3us on a 31us baseline), and TMEM ping-pong (V16, -2us). The biggest correctness traps were all in the multi-CTA / multi-buffer space: arrive-count and phase-bit invariants must be updated together with every change to the producer-consumer arity.

Reach me at naregmegan@gmail.com