
Dual GEMM

Dual GEMM Operation

The Dual GEMM kernel must implement the following operation:

C = silu(A @ B1) * (A @ B2)

silu is a scalar-valued function that we apply to every element of its operand matrix. It is defined as follows:

silu(x) = x * sigmoid(x)

sigmoid(x) = 1/(1 + e^(-x))

In the dual GEMM operation, "*" denotes elementwise multiplication (the Hadamard product) and "@" denotes a standard GEMM. Thus, the dual GEMM operation takes the matrix product A @ B1, applies silu to each element, and multiplies that result elementwise by the matrix product A @ B2. Below are the benchmark problem shapes:

"m": 256, "n": 4096, "k": 7168, "l": 1
"m": 512, "n": 4096, "k": 7168, "l": 1
"m": 256, "n": 3072, "k": 4096, "l": 1
"m": 512, "n": 3072, "k": 7168, "l": 1

In this kernel all of the batch sizes are 1 (l = 1), so we don't need to worry about handling multiple batches. The A matrix is MxK and both B matrices are KxN.
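As a quick sanity check of the definition above, here is a minimal pure-Python sketch of the operation on tiny matrices. It ignores NVFP4 quantization and block scaling entirely, and the helper names (matmul, dual_gemm) are made up for illustration; they are not part of any of the submissions.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def silu(x: float) -> float:
    return x * sigmoid(x)

def matmul(a, b):
    """Plain (M x K) @ (K x N) product on nested lists."""
    k_dim, n_dim = len(b), len(b[0])
    return [[sum(row[k] * b[k][n] for k in range(k_dim)) for n in range(n_dim)]
            for row in a]

def dual_gemm(a, b1, b2):
    """C = silu(A @ B1) * (A @ B2), where * is elementwise."""
    p1, p2 = matmul(a, b1), matmul(a, b2)
    return [[silu(x) * y for x, y in zip(r1, r2)] for r1, r2 in zip(p1, p2)]

# Tiny example: B1 is the identity, B2 is all ones.
A = [[1.0, -2.0], [0.0, 3.0]]
B1 = [[1.0, 0.0], [0.0, 1.0]]
B2 = [[1.0, 1.0], [1.0, 1.0]]
C = dual_gemm(A, B1, B2)
print(C)
```

Because silu(0) = 0, any zero in A @ B1 zeroes the matching output element no matter what A @ B2 contains, which is the gating behaviour the fused epilogue has to reproduce.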

Kernel Iteration and Code Walkthroughs

Now let's iterate through the different kernels I developed and examine the failures and successes. Afterwards we will look at an algorithm I tried to implement but couldn't get to work; my suspicion is that the hardware doesn't allow what I wanted to do algorithmically.

- V0 (submission_ref.py) -

First let's look at the naive approach using PyTorch.

    # Call torch._scaled_mm twice to compute the two GEMM results
    ref1 = torch.empty(
        (l, m, n),
        dtype=torch.float32,
        device="cuda",
    ).permute(1, 2, 0)
    ref2 = torch.empty(
        (l, m, n),
        dtype=torch.float32,
        device="cuda",
    ).permute(1, 2, 0)
    for l_idx in range(l):
        # Convert the scale factor tensor to blocked format
        scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx])
        scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx])
        scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx])
        # (m, k) @ (n, k).T -> (m, n)
        res1 = torch._scaled_mm(
            a_ref[:, :, l_idx],
            b1_ref[:, :, l_idx].transpose(0, 1),
            scale_a.cuda(),
            scale_b1.cuda(),
            bias=None,
            out_dtype=torch.float32,
        )
        ref1[:, :, l_idx] = res1

        res2 = torch._scaled_mm(
            a_ref[:, :, l_idx],
            b2_ref[:, :, l_idx].transpose(0, 1),
            scale_a.cuda(),
            scale_b2.cuda(),
            bias=None,
            out_dtype=torch.float32,
        )
        ref2[:, :, l_idx] = res2
    # Do silu on the first GEMM result and multiply with the second GEMM result
    c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16)
    return c_ref

This solution uses ordinary block-scaled GEMM to compute the dual GEMM: two separate torch._scaled_mm calls, followed by a third step that combines the results. Although this code is very simple, it comes with a number of disadvantages we can optimize away. First, separating the operation into three calls into torch internals essentially issues three kernels, which creates extra kernel launch overhead. Second, we can't reuse any of the A matrix data already loaded onto the SMs, so A is fetched once per GEMM, resulting in a potentially large amount of unnecessary memory traffic. Third, combining intermediate results stored in GMEM means extra traffic in and out of GMEM as well as a larger GMEM footprint for the operation. This kernel takes the first of two potential approaches to implementing the dual GEMM operation. We discuss these two approaches in detail in V1.
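To get a feel for the extra memory traffic described above, here is a rough back-of-the-envelope sketch. It assumes the intermediates are float32 (as in the reference code's out_dtype), that NVFP4 packs two values per byte, and that nothing is served from L2, so treat the numbers as an upper bound. The function name is made up for illustration.

```python
def unfused_extra_bytes(m: int, n: int, k: int) -> dict:
    """Extra GMEM bytes the three-step reference moves compared to a fused kernel."""
    fp32_bytes = 4
    intermediate = m * n * fp32_bytes   # one (m, n) float32 GEMM result
    a_bytes = m * k // 2                # NVFP4: two 4-bit values per byte
    return {
        "intermediate_writes": 2 * intermediate,  # ref1 and ref2 stored
        "intermediate_reads": 2 * intermediate,   # ref1 and ref2 loaded again
        "second_read_of_a": a_bytes,              # A fetched once per GEMM
    }

shapes = [(256, 4096, 7168), (512, 4096, 7168), (256, 3072, 4096), (512, 3072, 7168)]
totals = {}
for m, n, k in shapes:
    totals[(m, n, k)] = sum(unfused_extra_bytes(m, n, k).values())
    print((m, n, k), f"{totals[(m, n, k)] / 2**20:.1f} MiB of avoidable traffic")
```

Scale factor traffic is left out of this estimate; including it would only increase the totals.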

The benchmark data for this baseline kernel is shown below:

submission_ref.py

{'m': 256, 'n': 4096, 'k': 7168} -> Mean: 91.2us
{'m': 512, 'n': 4096, 'k': 7168} -> Mean: 95.3us
{'m': 256, 'n': 3072, 'k': 4096} -> Mean: 73.8us
{'m': 512, 'n': 3072, 'k': 7168} -> Mean: 87.5us

- V1 (submission_ptx.py) -

In this version we build our own kernel in PTX, similar to the later GEMV kernels and the PTX GEMM kernel from the last section. If engineering simplicity and maintainability were a concern, we would want to consider other options first, such as PyTorch-native optimization options like torch.compile or Helion. However, since the goal is to reduce latency at all costs, we jump straight to the paradigm offering the most granular control and thus the most potential for optimization.

There are two key differences that dual GEMM introduces on top of normal GEMM. First, we are multiplying the A matrix by two different B matrices. Second, we apply a function (or "activation" in ML terms) to the results of one of the matrix multiplications. When deciding how to design our initial kernel we have two broad algorithmic options: keep the same paradigm that we had in GEMM, where each CTA is responsible for a single block of output from one GEMM, or have each CTA compute a block of the final result by combining the results of the two GEMM operations it has computed.

The first approach would allow us to use our previous GEMM kernel pretty much directly; however, we would need to "reduce" the results from the individual GEMMs (like we did back in the GEMV kernel). This reduction introduces a number of disadvantages. There is kernel launch overhead for the reduction kernel. The results of the two GEMM operations make two additional trips through GMEM: one for the GEMM kernel to store the intermediate results and one for the reduction kernel to load that data back (L2 helps reduce some of this cost, but not all of it, and L2 is still very costly relative to registers and SMEM). Finally, the A matrix would be loaded from GMEM twice, because every block of the final output is computed from the results of two CTAs, one computing A @ B1 and one computing A @ B2. Again, L2 will help reduce the cost of repeatedly loading data, but this is still costly and repetitive.

The second approach eliminates all of the aforementioned disadvantages in exchange for a few limitations we will discuss in a moment. In this approach each CTA is responsible for a block of the final output (as opposed to a block of output for a single GEMM). This means each CTA needs to perform two GEMM computations simultaneously and combine those two results at the end of the kernel. The trade-offs now become apparent: each CTA will take longer to run than the CTAs in the previous GEMM kernels because there is simply more work being done per CTA. Additionally, each CTA will compute smaller blocks of output in order to fit the results of two simultaneous MMAs in TMEM. These trade-offs are well worth it: both designs perform the same arithmetic, so the increase in per-CTA runtime should not exceed the cost of the sequential path of running the individual GEMMs and reducing the results in a separate kernel. Additionally, if we keep the hardware well utilized we shouldn't have to worry about computing smaller blocks of output, since we are still utilizing the same amount (or more) of the hardware per unit time and the structure of the waves launched on the GPU will be the same or very similar (i.e. the number of CTAs in the kernel will remain constant; this will become clear as we dive deeper into the design of this kernel).
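The TMEM constraint mentioned above can be sketched with a little arithmetic. This is only an illustration under stated assumptions: that Tensor Memory provides 512 32-bit columns per SM, that a 128-row FP32 accumulator occupies one column per output column, and that the scale factors need some additional columns (the sf_columns value below is a made-up placeholder, not a number taken from the kernel).

```python
TMEM_COLUMNS = 512  # assumption: Tensor Memory columns available per SM

def accumulator_columns(mma_n: int, num_accumulators: int) -> int:
    # Assumes each 128 x mma_n FP32 accumulator uses mma_n TMEM columns.
    return mma_n * num_accumulators

def fits_in_tmem(mma_n: int, num_accumulators: int, sf_columns: int) -> bool:
    return accumulator_columns(mma_n, num_accumulators) + sf_columns <= TMEM_COLUMNS

sf_columns = 16  # hypothetical space reserved for SFA/SFB1/SFB2
for mma_n in (64, 128, 256):
    single = fits_in_tmem(mma_n, 1, sf_columns)
    dual = fits_in_tmem(mma_n, 2, sf_columns)
    print(f"N={mma_n}: single GEMM fits={single}, dual GEMM fits={dual}")
```

Under these assumptions a 128x256 tile fits for a single GEMM but not for two accumulators, which is the reason the dual GEMM kernel works on narrower output blocks.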

Luckily, implementing this second approach only requires a few simple modifications on top of the GEMM kernel from the last section. First, we need to duplicate everything we do with the B matrix in GEMM. In other words, we load two B tiles (B1 and B2) and two scale factor tiles (SFB1 and SFB2), and we perform two NVFP4 MMA operations instead of one (one each for A @ B1 and A @ B2). The second modification is applying silu and the Hadamard (elementwise) multiplication after the TMA/MMA loop has completed and before storing back to GMEM.

Taking a look at these modifications in code:

Observe the duplicate load operations for B1 and B2.

    tcgen05_3dtma_g2s_ab<1>(a_smem_stage_ptr, &tmap_a, m_off, k_off_coremat, mbar_addr_tma_stage, CacheHintSm100::EVICT_LAST);
    tcgen05_3dtma_g2s_ab<1>(b1_smem_stage_ptr, &tmap_b1, n_off, k_off_coremat, mbar_addr_tma_stage, CacheHintSm100::EVICT_FIRST);
    tcgen05_3dtma_g2s_ab<1>(b2_smem_stage_ptr, &tmap_b2, n_off, k_off_coremat, mbar_addr_tma_stage, CacheHintSm100::EVICT_FIRST);
    /*
        Scale factors are stored in global memory in 4x4x32 chunks, i.e. 512B chunks where each chunk represents a
        128x4 chunk of the SF matrix (in M or N xK)
        So we calculate the offset in each dimension in terms of these 512B chunks:
        k_off / 64 represents the number of 128x4 (512B) chunks along the K dimension which are contiguous (4 * SF_BLOCKS_SIZE = 64)
        m/n_off / 128 represents the number of 512B chunks along the M dimension which are strided by K / 64 512B chunks
    */
    const uint8_t* sfa_g_ptr = sfa_gmem_base + ((k_off / 64) + (m_off / 128) * (K / 64)) * 512; // ISSUE: These could be just simple bit shifts, adjust if compiler doesn't
    const uint8_t* sfb1_g_ptr = sfb1_gmem_base + ((k_off / 64) + (n_off / 128) * (K / 64)) * 512;
    const uint8_t* sfb2_g_ptr = sfb2_gmem_base + ((k_off / 64) + (n_off / 128) * (K / 64)) * 512;
    tcgen05_1dtma_g2s_sf(sfa_smem_stage_ptr, sfa_g_ptr, SFA_SMEM_TILESZ, mbar_addr_tma_stage, CacheHintSm100::EVICT_LAST);
    tcgen05_1dtma_g2s_sf(sfb1_smem_stage_ptr, sfb1_g_ptr, SFB1_SMEM_TILESZ, mbar_addr_tma_stage, CacheHintSm100::EVICT_FIRST);
    tcgen05_1dtma_g2s_sf(sfb2_smem_stage_ptr, sfb2_g_ptr, SFB2_SMEM_TILESZ, mbar_addr_tma_stage, CacheHintSm100::EVICT_FIRST);
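The scale factor offset arithmetic in the comment above can be checked on the host with a small Python sketch. It mirrors the expression from the kernel and also shows the bit-shift form that the ISSUE note alludes to; the function names are illustrative only, and the shift version assumes non-negative offsets.

```python
SF_CHUNK_BYTES = 512   # one 128x4 chunk of scale factors
ROWS_PER_CHUNK = 128   # rows (M or N) covered by one chunk
K_PER_CHUNK = 64       # 4 scale factors x 16 elements per scale factor

def sf_chunk_offset(row_off: int, k_off: int, K: int) -> int:
    """Byte offset of the 512B chunk holding (row_off, k_off)."""
    chunks_along_k = K // K_PER_CHUNK
    chunk_index = (k_off // K_PER_CHUNK) + (row_off // ROWS_PER_CHUNK) * chunks_along_k
    return chunk_index * SF_CHUNK_BYTES

def sf_chunk_offset_shifts(row_off: int, k_off: int, K: int) -> int:
    """Same computation written with shifts (64 = 2^6, 128 = 2^7, 512 = 2^9)."""
    return ((k_off >> 6) + (row_off >> 7) * (K >> 6)) << 9

example = sf_chunk_offset(row_off=256, k_off=128, K=7168)
print(example)
```

For K = 7168 there are 112 chunks along k, so moving down by 128 rows advances the pointer by 112 * 512 bytes.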

We need to copy two sets of scale factors from SMEM to TMEM.

    // Load scale factors SMEM -> TMEM
    for (int sub_k_iter = 0; sub_k_iter < TD_SMEM_K / TD_MMA_K; sub_k_iter++) {
        uint64_t sfa_desc = make_smem_desc<0, false>(sfa_smem_stage_ptr + (sub_k_iter * 512)); // ISSUE: verify this should input 0 here
        uint64_t sfb1_desc = make_smem_desc<0, false>(sfb1_smem_stage_ptr + (sub_k_iter * 512));
        uint64_t sfb2_desc = make_smem_desc<0, false>(sfb2_smem_stage_ptr + (sub_k_iter * 512));
        tcgen05_cp<1>(tmem_addr_sfa + 4 * sub_k_iter, sfa_desc);
        tcgen05_cp<1>(tmem_addr_sfb1 + 4 * sub_k_iter, sfb1_desc);
        tcgen05_cp<1>(tmem_addr_sfb2 + 4 * sub_k_iter, sfb2_desc);
    }

Two MMA operations.

    tcgen05_mma_nvfp4<1>(tmem_addr_result_1, a_desc, b1_desc, make_instr_desc<TD_MMA_M, TD_MMA_N>(), sfa_tmem, sfb1_tmem, k_off + sub_k_iter);
    tcgen05_mma_nvfp4<1>(tmem_addr_result_2, a_desc, b2_desc, make_instr_desc<TD_MMA_M, TD_MMA_N>(), sfa_tmem, sfb2_tmem, k_off + sub_k_iter);

Lastly, we look at how to apply the silu and Hadamard product to the two computed MMA tiles.

    mbar_wait(mbar_addr_epi, 0);
    asm volatile("tcgen05.fence::after_thread_sync;");

    // Load MMA into regs (TMEM -> Regs)
    // Each warp loads 16 rows per tcgen05_ld and we have 128 rows, with 4 warps each one is responsible for 32 rows
    // Each thread loads TD_MMA_N/2 values per result, meaning we need TD_MMA_N regs in total to hold both results
    float results_1[TD_MMA_N/2];
    float results_2[TD_MMA_N/2];
    int rows_per_warp = TD_MMA_M / (NUM_WARPS - 2);
    for (int sub_m = 0; sub_m <  rows_per_warp / 16; sub_m++) {
        if constexpr (TD_MMA_N == 128) {
            tcgen05_ld<16, 256, 16>(results_1, tmem_addr_result_1 + (((warp_id * rows_per_warp) + sub_m * 16) << 16));
            tcgen05_ld<16, 256, 16>(results_2, tmem_addr_result_2 + (((warp_id * rows_per_warp) + sub_m * 16) << 16));
        }
        else if constexpr (TD_MMA_N == 64) {
            tcgen05_ld<16, 256, 8>(results_1, tmem_addr_result_1 + (((warp_id * rows_per_warp) + sub_m * 16) << 16));
            tcgen05_ld<16, 256, 8>(results_2, tmem_addr_result_2 + (((warp_id * rows_per_warp) + sub_m * 16) << 16));
        }
        asm volatile("tcgen05.wait::ld.sync.aligned;");

        // Post-process in registers and store directly to GMEM (Regs -> GMEM)
        // (8 comes from 256/32 -> 256b per ld block from above)
        float result[4];
        for (int i = 0; i < TD_MMA_N / 8; i++) {
            const int m_offset = m_off + warp_id * rows_per_warp + sub_m * 16 + lane_id / 4;
            const int n_offset = n_off + i * 8 + (lane_id % 4) * 2;

            result[0] = silu(results_1[i * 4]) * results_2[i * 4];
            result[1] = silu(results_1[i * 4 + 1]) * results_2[i * 4 + 1];
            result[2] = silu(results_1[i * 4 + 2]) * results_2[i * 4 + 2];
            result[3] = silu(results_1[i * 4 + 3]) * results_2[i * 4 + 3];

            reinterpret_cast<half2 *>(c_ref + (m_offset)*N + n_offset)[0] = __float22half2_rn({result[0], result[1]});
            reinterpret_cast<half2 *>(c_ref + (m_offset + 8)*N + n_offset)[0] = __float22half2_rn({result[2], result[3]});
        }
    }

We declare two register buffers to hold the two separate MMA results: results_1 and results_2. We then iteratively load each MMA result in chunks as we did in the GEMM kernel. Once a chunk has been loaded into its buffer we loop through the register buffers and do the following: apply silu to the value in results_1, multiply it by the value at the same position in results_2, and store the product to GMEM as an FP16 value.
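The index arithmetic in the epilogue is easy to get wrong, so here is a small Python sketch that replays the m_offset / n_offset computation for one 16-row chunk (with m_off, n_off, warp_id and sub_m all taken as zero) and checks that the 32 lanes of a warp together write every element exactly once. The helper names are made up for illustration.

```python
def fragment_coords(lane_id: int, i: int):
    """(row, col) pairs written by one lane for loop index i, as in the epilogue."""
    row = lane_id // 4
    col = i * 8 + (lane_id % 4) * 2
    # Registers [4i], [4i+1] go to row `row`; [4i+2], [4i+3] go to row `row + 8`.
    return [(row, col), (row, col + 1), (row + 8, col), (row + 8, col + 1)]

def coverage(td_mma_n: int) -> dict:
    """Count how many times each element of the 16 x td_mma_n chunk is written."""
    seen = {}
    for lane_id in range(32):
        for i in range(td_mma_n // 8):
            for rc in fragment_coords(lane_id, i):
                seen[rc] = seen.get(rc, 0) + 1
    return seen

cov = coverage(128)
print(len(cov), "distinct elements written")
```

Each lane ends up with pairs of adjacent columns, which is why the kernel can pack two results into one half2 store.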

Below are the benchmark results of this new PTX kernel. We see a roughly 5x speed-up relative to the reference kernel!

submission_ptx.py

{'m': 256, 'n': 4096, 'k': 7168} -> Mean: 18.6us
{'m': 512, 'n': 4096, 'k': 7168} -> Mean: 18.6us
{'m': 256, 'n': 3072, 'k': 4096} -> Mean: 14.4us
{'m': 512, 'n': 3072, 'k': 7168} -> Mean: 18.5us

- V2 (sub_ptx_v2.py) -

One hardware feature that we have yet to take advantage of is "2SM" or "CTA_GROUP = 2". There isn't really a concise "name" for this hardware feature. The idea is to have 2 CTAs work together on a single tcgen05.mma instruction [see CTA_GROUP Details]. Ultimately what this does is allow CTAs to compute MMA operations without having all of the necessary data resident in their local SMEM. This reduces the amount of required SMEM per stage of the software pipeline, which in turn allows us to increase the depth of our software pipeline and/or the occupancy of each SM.

Up until now we have hardcoded CTA_GROUP to be 1. The changes to take advantage of CTA_GROUP = 2 are rather simple:

In the kernel definition we need to include a __cluster_dims__(CTA_GROUP, 1, 1) attribute so the GPU knows we are launching this kernel with the CTA grid broken down into (2, 1, 1) clusters [see Cluster Details]. Two CTAs can only share their data with each other for the 2SM MMA if they are in the same cluster.

template<int TD_CTA_M, int TD_CTA_N,
         int TD_SMEM_M, int TD_SMEM_N, int TD_SMEM_K, 
         int TD_MMA_M, int TD_MMA_N, int TD_MMA_K, bool SWIZZLE, int PIPE_STAGES, int NUM_WARPS, int CTA_GROUP>
__global__ void __cluster_dims__(CTA_GROUP, 1, 1) nvfp4_dual_gemm_kernel(__half* __restrict__ c_ref ...

Next, if CTA_GROUP is 2, each CTA in the cluster only needs to load half of the data for the B1 and B2 matrices, which is why we divide by CTA_GROUP.

    constexpr int B1_SMEM_TILESZ = (TD_SMEM_N / CTA_GROUP) * (TD_SMEM_K / 2); // 4KB
    constexpr int B2_SMEM_TILESZ = (TD_SMEM_N / CTA_GROUP) * (TD_SMEM_K / 2); // 4KB
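To see how halving the B tiles translates into pipeline depth, here is a rough Python sketch. The tile sizes are hypothetical (chosen so that the B tile comes out to the 4KB noted in the comment above), the A tile size formula is an assumption by analogy with the B tiles, scale factor tiles are ignored, and the SMEM budget is a placeholder rather than a measured limit.

```python
def stage_bytes(smem_m: int, smem_n: int, smem_k: int, cta_group: int) -> int:
    """Bytes of A + B1 + B2 tile data per pipeline stage, per CTA."""
    a_tile = smem_m * (smem_k // 2)                 # assumed: NVFP4, two values per byte
    b_tile = (smem_n // cta_group) * (smem_k // 2)  # mirrors B1_SMEM_TILESZ / B2_SMEM_TILESZ
    return a_tile + 2 * b_tile

def max_stages(budget_bytes: int, per_stage: int) -> int:
    return budget_bytes // per_stage

budget = 200 * 1024  # hypothetical SMEM budget left for tile data
for cta_group in (1, 2):
    per_stage = stage_bytes(128, 128, 128, cta_group)
    print(f"CTA_GROUP={cta_group}: {per_stage} B/stage, "
          f"up to {max_stages(budget, per_stage)} stages")
```

With these placeholder numbers, going from CTA_GROUP = 1 to CTA_GROUP = 2 cuts the per-stage footprint by a third and leaves room for a deeper pipeline.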

We need to make sure all threads in the cluster are synchronized, so instead of just __syncthreads() we need to use a barrier whose scope is the entire cluster.

    else if constexpr (CTA_GROUP == 2) {
        asm volatile("barrier.cluster.arrive.release.aligned;");
        asm volatile("barrier.cluster.wait.acquire.aligned;"); // All threads in cluster wait until all have executed arrive instr.
    }

All that remains is to pass CTA_GROUP (now 2) to all of the relevant PTX instructions we've previously discussed. In the case of TMA instructions, it allows them to properly account for data arriving at both CTAs in the cluster. In the case of MMA instructions, it tells the hardware where to find the other half of the data for B1 and B2. The Modular article I've mentioned has a great explanation of 2SM/CTA_GROUP = 2 as well [Modular article link].

The previous kernel after implementing CTA_GROUP = 2 was the best I was able to develop in time to submit to the competition. However, after reviewing some of the submissions that performed better than mine I realized a key detail in the tcgen05.mma instruction that I had overlooked: the collector_usage field. The collector_usage field allows the user to specify whether or not data in the collector buffer is stale or can be used for future tcgen05.mma operations as well as what can be done with the data in the collector buffer. The collector buffer is essentially a data cache specific to the tcgen05.mma hardware.

There are four possible values for collector_usage with respect to the A matrix (there are also similar options for the B matrix, but we don't use those):

1) .collector::a::fill -> Specifies that the A matrix read from the memory should be filled in collector buffer.
2) .collector::a::use -> Specifies that the A matrix can be read from the collector buffer. This requires a previous fill to the collector buffer to be still valid.
3) .collector::a::lastuse -> Specifies that the A matrix can be read from the collector buffer and the contents of the collector buffer can be discarded. This requires a previous fill to the collector buffer to be valid till the collector buffer is read.
4) .collector::a::discard -> Specifies that the contents of the collector buffer for A can be discarded.

The default option for tcgen05.mma instructions is .collector::a::discard, meaning the contents of the collector buffer are immediately invalidated. However, if two tcgen05.mma operations re-use the same A matrix data we can keep that data resident in the A matrix collector buffer, allowing a decent speed-up by not having to load the data into the collector buffer from SMEM. The following code demonstrates this:

    tcgen05_mma_nvfp4<CTA_GROUP, COLLECTOR_USAGE::A_FILL>(tmem_addr_result_1, a_desc, b1_desc, make_instr_desc<TD_MMA_M*CTA_GROUP, TD_MMA_N>(), sfa_tmem, sfb1_tmem, k_off + sub_k_iter);
    tcgen05_mma_nvfp4<CTA_GROUP, COLLECTOR_USAGE::A_LASTUSE>(tmem_addr_result_2, a_desc, b2_desc, make_instr_desc<TD_MMA_M*CTA_GROUP, TD_MMA_N>(), sfa_tmem, sfb2_tmem, k_off + sub_k_iter);

The first tcgen05_mma_nvfp4 uses .collector::a::fill and the second uses .collector::a::lastuse. This combination means the first MMA fills the collector buffer while the second just reuses that data. It's worth noting that usage of the collector buffer is ultimately decided by hardware. Software can only provide what amount to "strong hints" as to how sequenced tcgen05.mma instructions can be optimized. With this optimization we see decent speed-ups on the first and third benchmarks, with the second and fourth staying about the same. The key difference seems to be the size of the m-dimension, but I haven't figured out why that dictates whether or not the problem shape benefits from the collector usage speed-up [Open Problem].
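The intent behind the fill / lastuse pairing can be illustrated with a toy model. This is not how the hardware is implemented (as noted above, the hardware makes the final decision); it only encodes the four qualifier descriptions listed earlier and counts how often A would have to be read from SMEM. The function name is made up for illustration.

```python
def count_a_smem_reads(usages) -> int:
    """usages: one of 'fill', 'use', 'lastuse', 'discard' per MMA, in issue order."""
    reads = 0
    valid = False  # does the collector currently hold a usable copy of A?
    for usage in usages:
        if usage in ("use", "lastuse"):
            if not valid:
                raise ValueError(f"'{usage}' requires a previous fill that is still valid")
            # A is taken from the collector buffer: no SMEM read.
        else:
            reads += 1  # 'fill' and 'discard' both read A from SMEM
        valid = usage in ("fill", "use")  # 'lastuse' and 'discard' drop the contents
    return reads

k_steps = 4
default_reads = count_a_smem_reads(["discard", "discard"] * k_steps)
paired_reads = count_a_smem_reads(["fill", "lastuse"] * k_steps)
print(default_reads, paired_reads)
```

In this model the paired form halves the number of A reads per k-step, which is the saving the optimization is after.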

sub_ptx_v2.py

{'m': 256, 'n': 4096, 'k': 7168} -> Mean: 14.4us
{'m': 512, 'n': 4096, 'k': 7168} -> Mean: 18.5us
{'m': 256, 'n': 3072, 'k': 4096} -> Mean: 10.3us
{'m': 512, 'n': 3072, 'k': 7168} -> Mean: 18.3us
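For reference, the speed-ups implied by the three benchmark tables so far can be computed directly from the reported means. The snippet below only restates the numbers already listed; the variable names are arbitrary.

```python
shapes = ["256x4096x7168", "512x4096x7168", "256x3072x4096", "512x3072x7168"]
v0_us = [91.2, 95.3, 73.8, 87.5]   # submission_ref.py
v1_us = [18.6, 18.6, 14.4, 18.5]   # submission_ptx.py
v2_us = [14.4, 18.5, 10.3, 18.3]   # sub_ptx_v2.py

v0_to_v1 = [a / b for a, b in zip(v0_us, v1_us)]
v1_to_v2 = [a / b for a, b in zip(v1_us, v2_us)]

for name, s1, s2 in zip(shapes, v0_to_v1, v1_to_v2):
    print(f"{name}: v0->v1 {s1:.2f}x, v1->v2 {s2:.2f}x")
```

The v1 to v2 column makes the m-dimension split visible: the two m = 256 shapes gain noticeably while the two m = 512 shapes barely move.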

The following kernel iterations were experiments that didn't end up improving performance.

- V3, Regression (sub_ptx_splitk_slow) -

The k-dimension of these benchmark shapes is quite large relative to the m-dimension. A natural algorithm to consider for these kinds of problem shapes is split-k.

We implemented split-k in the GEMV kernel; however, in that case the work along the k-dimension was split across threads in the same CTA, meaning the reduction of partial results could be done out of a workspace in SMEM before the final result is stored to GMEM. In this dual GEMM kernel the work would have to be split across separate CTAs (unless we figure out a way to use DSMEM, which will be discussed in V5) and thus the reduction of partial results needs to be done out of GMEM, and the reduction would require a separate kernel launch. Additionally, we need to allocate separate partial result storage buffers in GMEM, increasing the space in GMEM this kernel demands.

As we've discussed, kernel launch overhead (and overhead in general) is more costly for smaller problem sizes because it makes up a larger portion of the overall runtime. Since these problem sizes are relatively small, I suspected a split-k algorithm working out of GMEM and requiring a separate reduction kernel would likely result in a performance loss rather than a gain; however, I decided to implement it just to be sure.

    const int k_per_split = K / SPLIT_K;

    CUtensorMap tmap_a, tmap_b1, tmap_b2, tmap_sfa, tmap_sfb1, tmap_sfb2;
    constexpr CUtensorMapSwizzle SWIZZLE_TYPE = SWIZZLE ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
    tma_3d_map_ab<M_TILE_SIZE, K_TILE_SIZE, SWIZZLE_TYPE>::init(cuTensorMapEncodeTiled, &tmap_a, a_ref.data_ptr(), M, K);
    tma_3d_map_ab<N_TILE_SIZE/CTA_GROUP, K_TILE_SIZE, SWIZZLE_TYPE>::init(cuTensorMapEncodeTiled, &tmap_b1, b1_ref.data_ptr(), N, K);
    tma_3d_map_ab<N_TILE_SIZE/CTA_GROUP, K_TILE_SIZE, SWIZZLE_TYPE>::init(cuTensorMapEncodeTiled, &tmap_b2, b2_ref.data_ptr(), N, K);

    tma_3d_map_sf<K_TILE_SIZE>(cuTensorMapEncodeTiled, &tmap_sfa, sfa_ref.data_ptr(), M, K);
    tma_3d_map_sf<K_TILE_SIZE>(cuTensorMapEncodeTiled, &tmap_sfb1, sfb1_ref.data_ptr(), N, K);
    tma_3d_map_sf<K_TILE_SIZE>(cuTensorMapEncodeTiled, &tmap_sfb2, sfb2_ref.data_ptr(), N, K);

    auto kernel_inst = nvfp4_dual_gemm_kernel<M_TILE_SIZE, N_TILE_SIZE, M_TILE_SIZE, N_TILE_SIZE, K_TILE_SIZE, M_TILE_SIZE, N_TILE_SIZE, K_MMA_SIZE, SWIZZLE, PIPE_STAGES, NUM_WARPS, CTA_GROUP>;

    cudaFuncSetAttribute(kernel_inst, cudaFuncAttributePreferredSharedMemoryCarveout, cudaSharedmemCarveoutMaxShared);

    constexpr int threads = WARP_SIZE * NUM_WARPS;
    dim3 grid_dim(M / M_TILE_SIZE, N / N_TILE_SIZE, SPLIT_K);

    kernel_inst<<<grid_dim, threads>>>(
        reinterpret_cast<float*>(workspace_1.data_ptr()),
        reinterpret_cast<float*>(workspace_2.data_ptr()),
        M, N, K, k_per_split,
        tmap_a, tmap_b1, tmap_b2, tmap_sfa, tmap_sfb1, tmap_sfb2);

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }

    // Reduction kernel: sum partials across splits and apply epilogue fusion
    dim3 reduce_block(16, 16);
    dim3 reduce_grid((M + 15) / 16, (N + 15) / 16);

    splitk_reduce_kernel<<<reduce_grid, reduce_block>>>(
        reinterpret_cast<float*>(workspace_1.data_ptr()),
        reinterpret_cast<float*>(workspace_2.data_ptr()),
        reinterpret_cast<__half*>(c_ref.data_ptr()),
        M, N, SPLIT_K);

We define a divisor SPLIT_K to divide up the k-dimension. Each CTA then handles a k_per_split sized chunk along k, storing its partial results in a temporary workspace: workspace_1 for A @ B1 and workspace_2 for A @ B2. Once the partial results have been computed and stored to the GMEM workspaces, a second kernel, splitk_reduce_kernel, is called to perform the silu, the Hadamard product, and the final GMEM store.

Nothing changes about the dual GEMM kernel except the k-loop only iterates through k_per_split elements instead of all of K, and the epilogue warp just stores the GEMM results to the respective workspaces instead of doing the silu/hadamard, which is done by the below reduction kernel:

// Split-K reduction kernel: sum partials and apply epilogue fusion
__global__ void splitk_reduce_kernel(
    const float* __restrict__ workspace_1,  // [SPLIT_K][M][N]
    const float* __restrict__ workspace_2,  // [SPLIT_K][M][N]
    __half* __restrict__ output,            // [M][N]
    const int M, const int N, const int SPLIT_K
) {
    const int m = blockIdx.x * blockDim.x + threadIdx.x;
    const int n = blockIdx.y * blockDim.y + threadIdx.y;
    if (m >= M || n >= N) return;

    const int mn_stride = M * N;
    const int base_idx = m * N + n;

    // Sum partials across splits
    float sum_1 = 0.0f, sum_2 = 0.0f;
    for (int s = 0; s < SPLIT_K; s++) {
        const int idx = s * mn_stride + base_idx;
        sum_1 += workspace_1[idx];
        sum_2 += workspace_2[idx];
    }

    // Apply epilogue fusion: silu(sum_1) * sum_2
    float result = silu(sum_1) * sum_2;
    output[base_idx] = __float2half(result);
}

This kernel is very straightforward. Each thread is responsible for reducing a single final result, iterating across the split-k results, summing them, and performing the silu and multiplication on the results. Then it stores the result to the correct place in GMEM as an FP16.
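A host-side reference for the reduction is handy when validating the split-k path. The sketch below follows the same [SPLIT_K][M][N] workspace layout and indexing as the kernel above, using flat Python lists; it is an illustrative reference, not code from the submission. Note that silu has to be applied after the partials are summed, since silu(a + b) is not silu(a) + silu(b).

```python
import math

def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))

def splitk_reduce(ws1, ws2, M: int, N: int, split_k: int):
    """Sum partials over the split dimension, then apply silu(sum_1) * sum_2."""
    mn_stride = M * N
    out = [0.0] * mn_stride
    for m in range(M):
        for n in range(N):
            base_idx = m * N + n
            sum_1 = sum(ws1[s * mn_stride + base_idx] for s in range(split_k))
            sum_2 = sum(ws2[s * mn_stride + base_idx] for s in range(split_k))
            out[base_idx] = silu(sum_1) * sum_2
    return out

# M = 1, N = 2, SPLIT_K = 2: split 0 holds [1, 2], split 1 holds [3, -2].
ws1 = [1.0, 2.0, 3.0, -2.0]
ws2 = [0.5, 1.0, 0.5, 1.0]
out = splitk_reduce(ws1, ws2, M=1, N=2, split_k=2)
print(out)
```

The second output element is exactly zero because its A @ B1 partials cancel; applying silu to each partial before summing would have produced a non-zero value instead.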

Below are the results for this split-k kernel using SPLIT_K = 4. As expected performance regresses significantly. We have added a large amount of GMEM traffic and kernel launch overhead, for a minor gain in parallelizing across the k-dimension. The performance gain by parallelizing split-k is small because we don't actually increase hardware utilization all that much for these problem shapes.

sub_ptx_splitk_slow.py

{'m': 256, 'n': 4096, 'k': 7168} -> Mean: 45.2us
{'m': 512, 'n': 4096, 'k': 7168} -> Mean: 83.5us
{'m': 256, 'n': 3072, 'k': 4096} -> Mean: 41.0us
{'m': 512, 'n': 3072, 'k': 7168} -> Mean: 68.9us
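To put a number on the added GMEM footprint, the workspace size follows directly from the [SPLIT_K][M][N] float layout of the two workspaces. The sketch below compares it with the size of the final FP16 output; the function names are made up for illustration.

```python
def splitk_workspace_bytes(m: int, n: int, split_k: int) -> int:
    """Two float32 workspaces, each laid out as [SPLIT_K][M][N]."""
    return 2 * split_k * m * n * 4

def output_bytes(m: int, n: int) -> int:
    """Final FP16 result, [M][N]."""
    return m * n * 2

SPLIT_K = 4
ratios = {}
for m, n in [(256, 4096), (512, 4096), (256, 3072), (512, 3072)]:
    ws = splitk_workspace_bytes(m, n, SPLIT_K)
    ratios[(m, n)] = ws // output_bytes(m, n)
    print(f"m={m}, n={n}: workspace {ws / 2**20:.0f} MiB, "
          f"{ratios[(m, n)]}x the final output")
```

Every workspace byte is written once by the GEMM kernel and read once by the reduction kernel, so the traffic is twice the footprint before counting any of the actual GEMM inputs.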

- V4 (sub_ptx_v2_cudalaunch) -

In this kernel we just modify how the CUDA kernel actually gets launched. We use the cudaLaunchKernelEx API instead of the triple angle bracket notation. The kernel launch looks like the below:

    cudaLaunchConfig_t config = {0};
    config.gridDim.x = M/M_TILE_SIZE;
    config.gridDim.y = N/N_TILE_SIZE;
    config.gridDim.z = 1;
    config.blockDim.x = WARP_SIZE * NUM_WARPS;
    config.blockDim.y = 1;
    config.blockDim.z = 1;

    // new attribute 
    cudaLaunchAttribute attribute[1]; // only one attribute in this case
    attribute[0].id = cudaLaunchAttributeClusterDimension; // specify attribute type
    attribute[0].val.clusterDim.x = CTA_GROUP;
    attribute[0].val.clusterDim.y = 1;
    attribute[0].val.clusterDim.z = 1;

    // add our attribute to the config
    config.numAttrs = 1;
    config.attrs = attribute;

    cudaLaunchKernelEx(&config, kernel_inst, reinterpret_cast<__half*>(c_ref.data_ptr()), M, N, K, tmap_a, tmap_b1, tmap_b2, tmap_sfa, tmap_sfb1, tmap_sfb2); // launch our kernel with the config

The benchmark results match those of V2, our fastest kernel, as one might expect. We didn't change anything about the kernel itself, just the launch mechanism.

sub_ptx_v2_cudalaunch.py

{'m': 256, 'n': 4096, 'k': 7168} -> Mean: 14.4us
{'m': 512, 'n': 4096, 'k': 7168} -> Mean: 18.5us
{'m': 256, 'n': 3072, 'k': 4096} -> Mean: 10.3us
{'m': 512, 'n': 3072, 'k': 7168} -> Mean: 18.3us

- V5, Regression (sub_ptx_v2_splitk, sub_ptx_v3) -

I mentioned in the split-k implementation how one of the main slowdowns was the introduction of much more GMEM traffic and the overhead of an extra kernel launch. The Blackwell architecture does offer a potential way to minimize both of these slowdowns via DSMEM [DSMEM Details]. Instead of distributing the work along the k-dimension across independent CTAs, we distribute the work along CTAs in a single cluster and store the temporary results in DSMEM rather than GMEM. Once all partial results have been computed they get reduced cluster-locally in the same kernel and stored to GMEM. This avoids extra GMEM traffic and kernel launch overhead while still reaping the benefits of k-dimension parallelism.

sub_ptx_v2_splitk.py / sub_ptx_v3.py are my attempts at getting this algorithm to work in combination with 2SM; however, it's unclear to me if using both of these hardware features in this way is even allowed or intended by the hardware [Open Problem].

The idea is to launch the grid as clusters of shape (CTA_GROUP, SPLIT_K, 1). When CTA_GROUP = 2 the cluster pairs work together to compute a 256xN chunk of the output for a K / SPLIT_K chunk along the k-dimension. The key would be that these partial results aren't stored exclusively in the SMEM local to the CTA pair. Instead, the results are split up into SPLIT_K chunks and each chunk is added to its associated buffer in the DSMEM of the cluster. So if we split the N dimension amongst the CTA pairs in the cluster for N = 256 and SPLIT_K = 4, each CTA pair would store a block of 256x64 in its local SMEM. Once all CTAs in the cluster have finished computing and accumulating their results into the DSMEM buffers of all CTAs in the cluster, each CTA in the cluster can reduce its local chunk of results and store it to GMEM.

Intended cluster-local split-K layout (cluster of CTA_GROUP × SPLIT_K = 2 × 4 CTAs):

  Cluster grid (rows = split-K shards, cols = 2SM CTA-pair partners):

                       N-partner 0         N-partner 1
                    ┌────────────────┐  ┌────────────────┐
        K-shard 0   │ CTA(0,0)       │  │ CTA(1,0)       │  K[0   .. K/4)
                    │  → 256 × 64    │  │  → 256 × 64    │
                    │  partial chunk │  │  partial chunk │
                    └────────────────┘  └────────────────┘
        K-shard 1   │ CTA(0,1)       │  │ CTA(1,1)       │  K[K/4 .. K/2)
                    └────────────────┘  └────────────────┘
        K-shard 2   │ CTA(0,2)       │  │ CTA(1,2)       │  K[K/2 .. 3K/4)
                    └────────────────┘  └────────────────┘
        K-shard 3   │ CTA(0,3)       │  │ CTA(1,3)       │  K[3K/4 .. K)
                    └────────────────┘  └────────────────┘

  After local MMA, each CTA owns one 256×64 partial along its K-shard.
  The reduction phase would scatter partials into DSMEM peers within
  the cluster, sum across shards, then store the final 256×64 to GMEM:

         DSMEM exchange within cluster (no GMEM workspace, no extra kernel)

                 ┌─────┐    ┌─────┐    ┌─────┐    ┌─────┐
                 │CTA  │←──►│CTA  │←──►│CTA  │←──►│CTA  │
                 │(*,0)│    │(*,1)│    │(*,2)│    │(*,3)│
                 └─────┘    └─────┘    └─────┘    └─────┘
                  ↘            ↓            ↓           ↙
                              reduce locally
                              and store to GMEM
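The data movement in the diagram amounts to a reduce-scatter across the K-shards. Since I never got the real kernel working, the sketch below only models the intended arithmetic on the host in plain Python: each shard's partial tile is cut into SPLIT_K column chunks and chunk j is accumulated into the buffer owned by shard j. It shows a single accumulator; the dual GEMM would do this for both A @ B1 and A @ B2 and then apply silu and the Hadamard product to the owned chunks. All names are illustrative.

```python
def cluster_reduce_scatter(partials, split_k: int):
    """partials[s] is shard s's partial tile (rows x n, nested lists).

    Returns owned, where owned[j] is the fully summed column chunk j
    that shard j would hold in its local buffer.
    """
    rows, n = len(partials[0]), len(partials[0][0])
    chunk = n // split_k
    owned = [[[0.0] * chunk for _ in range(rows)] for _ in range(split_k)]
    for s in range(split_k):          # producing shard
        for j in range(split_k):      # owning shard for column chunk j
            for r in range(rows):
                for c in range(chunk):
                    owned[j][r][c] += partials[s][r][j * chunk + c]
    return owned

# Small stand-in for the 256 x 256 tile: 2 rows, 8 columns, 4 shards.
SPLIT_K, ROWS, N = 4, 2, 8
partials = [[[float((s + 1) * c) for c in range(N)] for _ in range(ROWS)]
            for s in range(SPLIT_K)]
owned = cluster_reduce_scatter(partials, SPLIT_K)
print(owned[3])
```

Each owner ends up with N / SPLIT_K columns of the final sum, which corresponds to the 256x64 block per CTA pair in the N = 256, SPLIT_K = 4 example.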

Dual GEMM Summary

Optimization Summary:

  ┌──────────────────┬───────────────────────────────────────────┬────────────────────────────────────────────────────────────────┐
  │     Version      │                 Technique                 │                          Primary Gain                          │
  ├──────────────────┼───────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v0 (ref)         │ PyTorch baseline (3 kernel calls)         │ Correctness baseline                                           │
  ├──────────────────┼───────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v1 (ptx)         │ Single PTX kernel: fused A@B1 + A@B2 +    │ Eliminates two kernel launches, A is loaded once per CTA,      │
  │                  │ silu + Hadamard per CTA                   │ no GMEM round-trip for intermediates (~5x vs baseline)         │
  ├──────────────────┼───────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v2 (ptx_v2)      │ 2SM / CTA_GROUP=2 cluster MMA             │ Halves per-CTA SMEM for B1/B2; enables deeper pipelining       │
  ├──────────────────┼───────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v2_collector     │ collector::a::fill / lastuse on MMA pair  │ Avoids re-loading A into the MMA collector buffer between the  │
  │ (post-comp)      │ that share A                              │ A@B1 and A@B2 MMAs (~4us on m=256 shapes)                      │
  ├──────────────────┼───────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v3 (splitk_slow) │ Split-K via separate reduction kernel     │ REGRESSION: GMEM workspace + extra launch dominates the gain   │
  │                  │                                           │ from k-dim parallelism at these problem sizes                  │
  ├──────────────────┼───────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v4 (cudalaunch)  │ cudaLaunchKernelEx instead of <<<...>>>   │ Equivalent perf — proves launch API doesn't matter here        │
  ├──────────────────┼───────────────────────────────────────────┼────────────────────────────────────────────────────────────────┤
  │ v5 (splitk DSMEM)│ Cluster-local split-K with DSMEM-resident │ DID NOT WORK: combining 2SM and cluster-local split-K appears  │
  │                  │ partial reductions                        │ to exceed what the hardware allows [Open Problem]              │
  └──────────────────┴───────────────────────────────────────────┴────────────────────────────────────────────────────────────────┘

Broad Lessons:

When fusing multiple operations that share an input (here A is shared by both GEMMs and the silu/Hadamard combines the two outputs), the right axis to "fuse along" is the one that maximizes data reuse from the cache/registers/SMEM of a single CTA. Reducing in GMEM with a separate kernel (the v0 baseline) loads A twice and pays for two extra kernel launches; reducing in registers/TMEM inside one CTA (v1) keeps A SMEM-resident, performs the silu/Hadamard in registers, and writes only the final FP16 result back to GMEM. The trade-off is per-CTA work grows and each CTA needs more TMEM (two result accumulators instead of one), but this is well-amortized as long as the GPU stays well-utilized.

The collector_usage field is a good example of a hardware feature whose effect is unintuitive from the docs alone. Marking the two MMAs in a pair as fill then lastuse lets the second MMA reuse A from the on-chip collector buffer instead of re-reading from SMEM. The size of the resulting speedup depended unpredictably on M — small-M shapes benefited and large-M shapes didn't — which is one of the [Open Problems] left in this writeup.

For small problem shapes the introduction of an extra kernel launch (e.g. the split-K reduction kernel in v3) often costs more than the parallelism gain. The cost-benefit calculation flips as problem sizes grow, but at the sub-100us scale of these benchmarks any algorithm requiring a separate launch + GMEM workspace round-trip is structurally disadvantaged.

Reach me at naregmegan@gmail.com