Batched GEMV
The first problem was Batched GEMV (batches of GEneral Matrix-Vector multiplies). The kernel takes as input a list of L matrices of size M×K and a list of L vectors of length K. The output is L vectors of length M, where output i is matrices[i] * vectors[i]. All storage backing the inputs is already on the device, so we don't need to worry about host-to-device data movement or device memory allocation. The kernel is tested for correctness on a variety of shapes, but the competition benchmark was based on the problem shapes below:
"k": 16384, "m": 7168, "l": 1
"k": 7168, "m": 4096, "l": 8
"k": 2048, "m": 7168, "l": 4
A and SFA refer to the input matrices and their associated scale factor matrices; B and SFB refer to the vectors and their associated scale factor vectors. C refers to the output vectors, which are expected to be half precision floats (FP16). As the kernel code below shows, the A and B elements are 4-bit floats (FP4, E2M1) and the scale factors are 8-bit floats (FP8, E4M3), with one scale factor per 16 consecutive elements along K.
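To pin down exactly what every kernel below computes, here is a minimal pure-Python reference sketch. It assumes the FP4 and FP8 values have already been decoded to floats and that each scale factor covers 16 consecutive K elements. The function name `batched_gemv_ref` and the `[l][m][k]` nesting order are illustrative choices; the real tensors keep the batch index in the last dimension.

```python
from typing import List

BLOCK = 16  # assumed: one FP8 scale factor per 16 consecutive K elements


def batched_gemv_ref(
    a: List[List[List[float]]],    # a[l][m][k], FP4 values already decoded to float
    sfa: List[List[List[float]]],  # sfa[l][m][k // BLOCK]
    b: List[List[float]],          # b[l][k]
    sfb: List[List[float]],        # sfb[l][k // BLOCK]
) -> List[List[float]]:
    """Returns c[l][m] = sum_k (a * sfa) * (b * sfb); the kernel rounds this to FP16."""
    c = []
    for l in range(len(a)):
        row_results = []
        for m in range(len(a[l])):
            acc = 0.0
            for k in range(len(b[l])):
                blk = k // BLOCK  # which scale factor applies to element k
                acc += (a[l][m][k] * sfa[l][m][blk]) * (b[l][k] * sfb[l][blk])
            row_results.append(acc)
        c.append(row_results)
    return c
```

Everything that follows is an increasingly hardware-aware way of evaluating this triple loop.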
V0
We start with the naive approach: iterate through each of the L matrix-vector pairs, perform the matrix-vector multiply, and store the result in the C vectors.
# Call torch._scaled_mm to compute the GEMV result
for l_idx in range(l):
    # Convert the scale factor tensor to blocked format
    scale_a = to_blocked(sfa_ref[:, :, l_idx])
    scale_b = to_blocked(sfb_ref[:, :, l_idx])
    # (m, k) @ (n, k).T -> (m, n)
    res = torch._scaled_mm(
        a_ref[:, :, l_idx],
        b_ref[:, :, l_idx].transpose(0, 1),
        scale_a,
        scale_b,
        bias=None,
        out_dtype=torch.float16,
    )
    c_ref[:, 0, l_idx] = res[:, 0]
See the full kernel in nvfp4_gemv/v0.py.
Here are the benchmark numbers using this kernel:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 82.8us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 327.6us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 131.5us
Benchmark aside:
This kernel prepares the scale factor data using the to_blocked function, which reformats the scale factor matrices and vectors into a layout that the PyTorch internal function _scaled_mm can ingest. In the GEMM kernel I discuss the formatting of scale factors in much greater detail, but for the remaining kernels in this section it is an unnecessary complication that we will omit.
The main issue with this kernel is that it prepares the data for each batch individually and then launches a new kernel per batch. Kernel launch overhead becomes significant when the kernel's work completes on the order of tens or hundreds of microseconds; in these cases the launch itself can make up a non-trivial portion of the total runtime.
Kernel launches are expensive at these timescales because a launch can take on the order of a few microseconds, depending on the size of the kernel parameters [1]. When you launch a kernel, the CUDA runtime API calls into the CUDA driver API, which prepares and transmits the kernel parameters over PCIe, checks that the launch parameters are feasible and that there are enough hardware resources to serve what the kernel requests, and lastly places the kernel on a hardware-consumed queue holding all kernels currently scheduled to run.
The solution to this issue is kernel fusion [Operator/Kernel Fusion Details]. Instead of launching a separate kernel for each matrix-vector multiplication, we launch a single kernel that processes all batches at once. This lets us use the hardware more effectively and avoids the cost of repeated kernel invocations. v1.py implements the first version of this type of kernel.
V1
In this kernel we specify a number of threads per block [Threads/CTAs Details] and then launch enough blocks so there is a single thread computing each of the M*L results.
const int blocks = (m*l + threads - 1) / threads;
Each thread then loops across the k-dimension loading the A/B values and scale factor A (SFA) and scale factor B (SFB) from GMEM, converting them to half precision floats, multiplying them together, and accumulating them into a result value. After the k-loop that result is converted from FP32 to FP16 and stored to the result location in GMEM.
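A small Python emulation makes the V1 work split concrete: the launch math matches the line above, and each "thread" owns one output and walks the entire K dimension serially. Scale factors are left out (inputs are assumed pre-scaled floats), and the `tid -> (row, batch)` mapping plus the helper names are hypothetical, chosen only for illustration.

```python
from typing import List, Optional, Tuple


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def v1_thread(tid: int, a, b, m: int, l: int) -> Optional[Tuple[int, int, float]]:
    """Work of one V1 thread: a full dot product along K for a single output."""
    if tid >= m * l:                  # spare threads in the last block do nothing
        return None
    row, batch = tid % m, tid // m    # hypothetical tid -> (row, batch) mapping
    acc = 0.0
    for x, y in zip(a[batch][row], b[batch]):  # the whole K loop runs in one thread
        acc += x * y
    return batch, row, acc


def v1_launch(a, b, m: int, l: int, threads: int = 128):
    blocks = ceil_div(m * l, threads)  # same formula as the CUDA launch code
    c: List[List[float]] = [[0.0] * m for _ in range(l)]
    for block in range(blocks):
        for t in range(threads):
            out = v1_thread(block * threads + t, a, b, m, l)
            if out is not None:
                batch, row, acc = out
                c[batch][row] = acc
    return blocks, c
```

The serial inner loop over K is exactly the part that becomes expensive when K = 16384.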
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 831.2us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 575.0us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 174.2us
Our kernel's performance degraded across the board, including the cases where we expected the most improvement. As in all of science and engineering, kernel design and optimization rarely progresses monotonically in the desired direction. As the development of the kernels below also shows, you have to prototype different ideas and see how they interact with the hardware to uncover a deeper understanding of what is really happening, after which you can engineer and optimize more accurately. In this specific case the reason for the major slowdown wasn't clear to me until the end of the optimization process, when I uncovered what the biggest bottleneck in this kernel had been all along.
My first thought was to look at the problem shapes and how the hardware is utilized for those shapes [Hardware Utilization Details]. In the first problem shape K is very large relative to the number of threads launched. With 128 threads per CTA, the 7168 outputs result in 7168/128 = 56 CTAs. Even if occupancy [Occupancy Details] limited us to one CTA per SM (streaming multiprocessor [NVIDIA GPU Details]), only 56 of the 148 available SMs would be used. The other problem shapes present a second problem: the tail effect [Tail Effect Details]. One approach that can help alleviate both issues is splitting the k-dimension into sub-blocks that are computed in parallel. This technique, often referred to as "split-k" [Split-K Details], will also be used in other kernels later on. Parallelizing the k-dimension means we incur the overhead of reducing the partial results into the final result; however, it also increases the amount of independent work that can run in parallel and reduces the tail effect (the tail effect matters more for problem shapes that produce a few large GPU waves than for shapes that produce many waves each doing less work).
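The utilization argument can be checked with a little arithmetic. The sketch below counts V1's CTAs per benchmark and how full the last wave is, under the same simplifying assumption as the text (at most one resident CTA per SM, 148 SMs); `waves` is a hypothetical helper, not part of the kernel.

```python
import math

NUM_SMS = 148  # SM count quoted in the text


def waves(num_ctas: int, ctas_per_sm: int = 1, num_sms: int = NUM_SMS):
    """Returns (number of waves, fraction of CTA slots busy in the last wave)."""
    slots = num_sms * ctas_per_sm        # CTAs that can be resident at once
    full, rem = divmod(num_ctas, slots)
    n_waves = full + (1 if rem else 0)
    last_util = rem / slots if rem else 1.0  # 1.0 means no partially filled tail wave
    return n_waves, last_util


for name, m, l in [("bench1", 7168, 1), ("bench2", 4096, 8), ("bench3", 7168, 4)]:
    ctas = math.ceil(m * l / 128)  # V1: one thread per output, 128 threads per CTA
    n, util = waves(ctas)
    print(f"{name}: {ctas} CTAs, {n} wave(s), last wave {util:.0%} full")
# bench1: 56 CTAs, 1 wave(s), last wave 38% full
# bench2: 256 CTAs, 2 wave(s), last wave 73% full
# bench3: 224 CTAs, 2 wave(s), last wave 51% full
```

Benchmark 1 never fills the GPU, while benchmarks 2 and 3 pay for a partially empty second wave.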
V2
To split up the k-dimension, we have each CTA handle a single result in the output, while each thread within that CTA computes a sub-block of that result along the k-dimension. Once all threads have computed their partial results, those partial results are reduced into one final result that the CTA stores back to GMEM. We use a buffer in shared memory (SMEM) [Shared Memory Details] in which each thread owns one 32-bit float element. The reduction iteratively halves the number of working threads; at each iteration a thread adds the element at its index plus the number of remaining working threads into its own element. This makes the reduction take a logarithmic number of steps (instead of one master thread serially adding all the other partial results in the CTA to its own).
// Once all partial sums are computed, reduce along the k-dimension
// Tree reduction in shared memory (assumes blockDim.x is a power of two)
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (l_tid < stride) {
        partial_sums[l_tid] += partial_sums[l_tid + stride];
    }
    __syncthreads();
}
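The stride-halving loop can be emulated in Python to confirm both its result and its step count. Within one step, threads with `tid < stride` only read elements at `tid + stride`, which nobody writes in that step, so a sequential emulation matches the parallel version. `tree_reduce` is a hypothetical name for this sketch.

```python
from typing import List, Tuple


def tree_reduce(partial_sums: List[float]) -> Tuple[float, int]:
    """Emulates the SMEM reduction loop; returns (total, number of halving steps)."""
    buf = list(partial_sums)
    n = len(buf)
    assert n & (n - 1) == 0, "stride halving assumes a power-of-two block size"
    steps = 0
    stride = n // 2
    while stride > 0:
        # every thread with tid < stride does one add; then __syncthreads()
        for tid in range(stride):
            buf[tid] += buf[tid + stride]
        stride >>= 1
        steps += 1
    return buf[0], steps
```

For a 512-thread CTA this takes 9 steps (and 9 CTA-wide barriers) instead of 511 serial adds.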
This kernel introduces another parameter which can be fine-tuned to achieve the best results: the size of the chunk of the k-dimension each thread computes. For this benchmark we chose that k-block size to be 32.
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 163.1us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 316.6us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 82.8us
As we expected, the hardware is utilized far more effectively for the first benchmark's problem shape, reducing its runtime by about 80%. We also see significant improvements on the other two benchmarks. The middle benchmark improves the least, which we might expect: that problem has the largest number of outputs, so it was already utilizing the hardware best under the V1 kernel and gains the least from the split-K change, which aimed to better utilize the hardware by parallelizing across the k-dimension. Our kernel now performs better than the PyTorch baseline for benchmark 3, about the same for benchmark 2, and still worse for benchmark 1.
At the moment, if our kernel launches multiple warps [Warp Details] in a single CTA, the threads within those warps still perform the entire reduction by accessing the partial results in SMEM; however, there is a faster way for threads to exchange data when they reside within the same warp: warp shuffling [Warp Shuffling Details]. A hybrid approach, in which each warp first reduces its constituent threads' values with warp shuffles and SMEM is then used only to reduce across warps, should give us a decent performance boost, since it reduces the number of SMEM accesses and CTA-wide synchronizations needed before the result is stored to GMEM. We implement this in V3 below.
V3
Warp shuffling
// Shuffle thread_total down within each warp; afterwards lane 0 holds the warp's sum
int warp_l_tid = l_tid % 32;
for (int offset = 16; offset > 0; offset >>= 1) {
    thread_total += __shfl_down_sync(0xffffffff, thread_total, offset);
}
// Store each warp's partial sum to SMEM
if (warp_l_tid == 0) {
    partial_sums[l_wid] = thread_total;
}
__syncthreads();
// Once all warp partial sums are in SMEM, reduce them along the k-dimension
// (assumes the number of warps per CTA is a power of two)
int warps_in_cta = blockDim.x / 32;
for (int stride = warps_in_cta / 2; stride > 0; stride >>= 1) {
    // Only one lane per warp updates, so the 32 lanes don't race on the same element
    if (warp_l_tid == 0 && l_wid < stride) {
        partial_sums[l_wid] += partial_sums[l_wid + stride];
    }
    __syncthreads();
}
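Here is a Python emulation of the shuffle-down reduction and of how many CTA-wide barriers each scheme needs. It models `__shfl_down_sync` with its documented behavior (a lane whose source lane falls outside the warp gets its own value back); the helper names and the barrier-counting model are illustrative assumptions, not measured behavior.

```python
import math
from typing import List

WARP = 32


def shfl_down(vals: List[float], offset: int) -> List[float]:
    # lane i receives lane i+offset's value; out-of-range sources return lane i's own value
    return [vals[i + offset] if i + offset < WARP else vals[i] for i in range(WARP)]


def warp_reduce(vals: List[float]) -> float:
    vals = list(vals)
    offset = 16
    while offset > 0:
        received = shfl_down(vals, offset)
        vals = [v + r for v, r in zip(vals, received)]
        offset >>= 1
    return vals[0]  # only lane 0 is guaranteed to hold the full warp sum


def barrier_count(threads_per_cta: int, hybrid: bool) -> int:
    # V2: one __syncthreads() per halving step over all threads
    if not hybrid:
        return int(math.log2(threads_per_cta))
    # V3: one barrier after writing warp sums, plus one per halving step over warps
    return 1 + int(math.log2(threads_per_cta // WARP))
```

With 512 threads per CTA the hybrid scheme needs 5 barriers instead of 9, and only one SMEM element per warp instead of per thread.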
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 158.6us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 306.2us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 80.9us
We see a performance boost of between 2us and 10us depending on the benchmark. The boost depends both on how many results are being computed and on the size of the k-dimension. A larger number of results potentially means more waves launched by the kernel, and therefore more serialized split-k reductions (assuming we are the only kernel running on the GPU and each wave fills the GPU as much as possible, the waves run serially, so the reduction steps at the end of each wave are also serial). A larger k-dimension means the reduction step takes longer (more SMEM accesses, more CTA-wide syncs). Thus we expect the most improvement for problem shapes where more results are computed and K is larger. This explains the results: benchmark 1 sees ~5us of improvement, likely because K is very large. Benchmark 2 benefits the most, with ~10us of improvement, due to the combination of a large number of results and a decently large K. Benchmark 3 sees the least improvement, which makes sense because it computes fewer values than benchmark 2 and K is significantly smaller. Lastly, we can sanity check that the warp shuffle optimization affects all of these problem shapes in the same way, because the proportional runtime improvement is roughly constant at about 3% (between 2.3% and 3.3%).
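The proportional claim can be checked directly from the reported V2 and V3 means; this sketch only does arithmetic on the numbers above.

```python
# Mean runtimes (us) reported above for V2 and V3
v2 = {"bench1": 163.1, "bench2": 316.6, "bench3": 82.8}
v3 = {"bench1": 158.6, "bench2": 306.2, "bench3": 80.9}


def improvement(before: float, after: float):
    saved = before - after
    return saved, saved / before


for name in v2:
    saved, frac = improvement(v2[name], v3[name])
    print(f"{name}: {saved:.1f} us saved ({frac:.1%})")
# bench1: 4.5 us saved (2.8%)
# bench2: 10.4 us saved (3.3%)
# bench3: 1.9 us saved (2.3%)
```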
A 3% improvement would be great if we were already beating the baseline by a lot, but we are still lagging in the case of benchmark 1, which means there are other optimizations out there we are missing.
Optimizations in parallel algorithms are often revealed by looking for repeated operations (either memory or compute ops). Looking at our kernels up to this point, every single CTA loads and type-converts the same values of B and SFB. L1/L2 caching keeps many of these repeated loads from going all the way to expensive global memory [General NVIDIA GPU Memory Architecture (Blackwell) Details], but it is still disorganized, redundant memory traffic that could be reduced or better organized to promote more effective caching and broadcasting [Memory Broadcasting Details]. Additionally, every CTA repeats the conversion of B and SFB into higher precision floats. Sometimes repeated computation is a necessary sacrifice to achieve better parallelism on HPC architectures (i.e. the cost of repetition is worth paying for the gains of parallelism); however, in kernel V4 we explore an algorithmic change that retains most of the parallelism we've established so far while also reducing these repeated operations on B/SFB.
V4
Instead of each CTA being responsible for a single output and each thread handling a chunk along the k-dimension, we have each CTA compute a chunk of the results, and each warp within the CTA handles a sub-chunk of those results. This introduces two new parameters: RESULTS_PER_WARP and NUM_WARPS_PER_CTA, whose product defines how many results are computed per CTA. Each warp iterates along the k-dimension in blocks, loading and converting chunks of B/SFB and multiplying those same scaled B values by several rows of the A/SFA matrices. This reduces the repeated loads/conversions of B/SFB by a factor of RESULTS_PER_WARP, since those values get reused for RESULTS_PER_WARP different output results. We also preserve some k-dimension parallelism because each thread in a warp computes a separate sub-chunk of each K chunk the warp handles.
Experimenting with different values of RESULTS_PER_WARP and NUM_WARPS_PER_CTA I found 4 and 8 respectively to be the fastest combination for our benchmarks. This is because increasing those numbers too much reduces our parallelism and ability to fully utilize the GPU because we launch fewer CTAs (each CTA handles more results => fewer launched CTAs). Additionally, increasing RESULTS_PER_WARP increases register pressure causing reduced occupancy and potential register spilling [Register Spilling Details].
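To see the parallelism trade-off in numbers, the sketch below counts CTAs per benchmark for V3 (one CTA per output, as described for V2/V3) versus V4 with the chosen RESULTS_PER_WARP = 4 and NUM_WARPS_PER_CTA = 8. It assumes each CTA's rows stay within one batch; `v4_ctas` is a hypothetical helper.

```python
RESULTS_PER_WARP = 4      # values chosen in the text
NUM_WARPS_PER_CTA = 8
ROWS_PER_CTA = RESULTS_PER_WARP * NUM_WARPS_PER_CTA  # 32 outputs per CTA


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def v4_ctas(m: int, l: int) -> int:
    # assumption: a CTA never spans two batches, so round M up per batch
    return ceil_div(m, ROWS_PER_CTA) * l


for name, m, l in [("bench1", 7168, 1), ("bench2", 4096, 8), ("bench3", 7168, 4)]:
    print(f"{name}: V3 {m * l} CTAs -> V4 {v4_ctas(m, l)} CTAs")
# bench1: V3 7168 CTAs -> V4 224 CTAs
# bench2: V3 32768 CTAs -> V4 1024 CTAs
# bench3: V3 28672 CTAs -> V4 896 CTAs
```

Benchmark 1 drops to 224 CTAs, barely more than one per SM, which is why pushing RESULTS_PER_WARP or NUM_WARPS_PER_CTA higher quickly starves the GPU of parallel work for that shape.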
This code snippet shows how we load and convert B/SFB before entering into a sub-loop to compute and accumulate for multiple end results:
for (int kb = 0; kb < k_blocks; kb++) {
    // Check if this thread's k-range is valid
    // bool valid = (kb * k_block_size + l_wtid * 16) < k;
    // Load 16 values from b into thread RF using 1 uint64_t (8B) load
    uint64_t b_vals_packed = reinterpret_cast<uint64_t const*>(b_ref)[(batch*(k/2)*128 + kb*(k_block_size/2) + l_wtid*8)/8];
    __nv_fp4x2_storage_t const* b_vals = reinterpret_cast<__nv_fp4x2_storage_t const*>(&b_vals_packed);
    // Load sfb value used for this thread
    __nv_fp8_storage_t sfb = sfb_ref[batch*(k/16)*128 + kb*(32) + l_wtid];
    __half_raw sfb_raw_fp16 = __nv_cvt_fp8_to_halfraw(sfb, __NV_E4M3);
    __half sfb_fp16 = *reinterpret_cast<__half*>(&sfb_raw_fp16);
    float sfb_fp32 = __half2float(sfb_fp16);
    #pragma unroll
    for (int result = 0; result < RESULTS_PER_WARP; result++) {...
Pseudo-code:
for kb in k_blocks:
    thread loads 16 FP4s from B (uint64_t load)
    thread loads sfb scalar
    for result in RESULTS_PER_WARP:   // 4 output rows
        thread loads 16 FP4s from A row
        thread loads sfa scalar
        computes dot product in FP16 with __hfma2
        warp shuffle-reduce the dot product sum
        lane 0 accumulates into results[result]
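The pseudo-code can be emulated in Python to show where the savings come from: the B chunk is fetched once per k-block and reused for every row the warp owns. This is a sketch that collapses the 32 lanes and the shuffle-reduce into a single `sum`; `v4_warp` and its `k_block` parameter (32 lanes × 16 values = 512 in the kernel) are hypothetical names.

```python
from typing import List, Tuple


def v4_warp(a_rows: List[List[float]], b: List[float], k_block: int = 512) -> Tuple[List[float], int]:
    """Emulates one V4 warp computing len(a_rows) outputs; also counts B chunk loads."""
    results = [0.0] * len(a_rows)
    b_chunk_loads = 0
    for kb in range(0, len(b), k_block):
        b_chunk = b[kb:kb + k_block]        # load + convert B once per k-block
        b_chunk_loads += 1
        for r, row in enumerate(a_rows):    # reuse the same B chunk for every row
            a_chunk = row[kb:kb + k_block]
            # kernel: per-lane __hfma2 dot products, then a warp shuffle-reduce
            results[r] += sum(x * y for x, y in zip(a_chunk, b_chunk))
    return results, b_chunk_loads
```

Running one row per warp instead (as in V3, where each CTA owned a single output) would fetch the same B chunks RESULTS_PER_WARP times as often.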
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 151.9us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 196.3us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 60.0us
Benchmarks 2 and 3 saw significant speed-ups of about 36% and 26% respectively, while benchmark 1 saw a minor ~4% speed-up. We should expect this because benchmarks 2 and 3 have a large number of results relative to the size of K compared to benchmark 1; the ratio M * L / K for each benchmark is:
Benchmark 1: 7168 × 1 / 16384 ≈ 0.44
Benchmark 2: 4096 × 8 / 7168 ≈ 4.57
Benchmark 3: 7168 × 4 / 2048 = 14
This means the trade-off of a slight reduction in k-dimension parallelism for a reduction in repeated B/SFB loads/conversions is far more beneficial for the latter two benchmarks. More results allows us to maintain a high level of GPU wide parallelism/utilization by still having a large number of CTAs while also benefiting from that reduced repetition. When K gets too large relative to the number of results being computed those benefits get drowned out by the loss in k-dimension parallelism. That problem shape also benefits less because the number of repeated loads of B/SFB depends on the number of results being computed (i.e. size of M and L), so a low value of M and L means there aren't that many repeated operations to reduce in the first place.
Looking back, it should have been clear at this point that I was missing something major about what was bottlenecking our kernel; however, for the next dozen or so kernel iterations I made smaller optimizations on top of this kernel in an experimental fashion. In the final kernel I identify what the major bottleneck was at this point in development (especially for benchmark 1), although the observation that motivated that discovery first appears in kernel V6_2_4.
V5
In the previous kernel warp l_wid wrote its 4 output rows at positions:
cta_row_start + result * NUM_WARPS_PER_CTA + l_wid
These rows are NUM_WARPS_PER_CTA = 8 rows apart, so each warp's A rows were non-contiguous in memory, causing L1 cache misses. In this kernel we index by
cta_row_start + RESULTS_PER_WARP * l_wid + result
Now each warp's 4 rows are contiguous in memory. Similarly, A loads now use an l_wid * RESULTS_PER_WARP * k_elems stride, accessing a contiguous block of A. We also pre-compute a_base, b_base, sfa_base, and sfb_base outside the k-loop in case the compiler wasn't already hoisting these computations.
This minor optimization yields an improvement of a few microseconds.
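The two indexing schemes are easy to compare side by side. This sketch lists the rows (relative to cta_row_start) that warp 1 owns under each formula, using the RESULTS_PER_WARP = 4 and NUM_WARPS_PER_CTA = 8 values from the text; the function names are illustrative.

```python
RESULTS_PER_WARP, NUM_WARPS_PER_CTA = 4, 8


def rows_v4(wid: int):
    # previous layout: cta_row_start + result * NUM_WARPS_PER_CTA + l_wid
    return [result * NUM_WARPS_PER_CTA + wid for result in range(RESULTS_PER_WARP)]


def rows_v5(wid: int):
    # new layout: cta_row_start + RESULTS_PER_WARP * l_wid + result
    return [RESULTS_PER_WARP * wid + result for result in range(RESULTS_PER_WARP)]


print(rows_v4(1))  # [1, 9, 17, 25]
print(rows_v5(1))  # [4, 5, 6, 7]
```

Both layouts cover the CTA's 32 rows exactly once; only V5 gives each warp a contiguous run of A rows.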
V6
Each thread in our kernel currently loads all of its values directly from GMEM. We structured the kernel so that loads from GMEM into registers should be largely coalesced; however, shared memory gives us more granular control over memory transactions from GMEM. If we replace individual threads loading from GMEM with warp- or CTA-wide cooperative loads from GMEM into SMEM buffers, we can better guarantee coalesced transactions all the way from GMEM into the register file, and we can ensure all of our "hot" data (i.e. data we use repeatedly in inner loops before needing new data) stays resident (i.e. it doesn't get evicted, as it might if we simply relied on the L1 cache's replacement policy). In V6 we make this change by declaring SMEM buffers for A, B, SFA, and SFB as below:
__shared__ uint64_t a_smem[NUM_WARPS_PER_CTA*RESULTS_PER_WARP][K_BLOCK_SIZE/16];
__shared__ __nv_fp8_storage_t sfa_smem[NUM_WARPS_PER_CTA*RESULTS_PER_WARP][K_BLOCK_SIZE/16];
__shared__ uint64_t b_smem[K_BLOCK_SIZE/16];
__shared__ __nv_fp8_storage_t sfb_smem[K_BLOCK_SIZE/16];
K_BLOCK_SIZE determines how much data a warp loads and computes on per iteration along the k-dimension, and it is another parameter we can tune. A larger K_BLOCK_SIZE means larger SMEM buffers and longer load and compute stages, but fewer iterations along the k-dimension. A smaller K_BLOCK_SIZE means faster memory/compute stages but more iterations overall. In a later kernel we explore how to overlap these memory and compute stages, which becomes much more important in the later kernels.
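To see how K_BLOCK_SIZE and the row count translate into SMEM, here is a footprint calculator for the four buffers declared above. It assumes K_BLOCK_SIZE = 512, inferred from the loads where each of a warp's 32 lanes indexes one uint64_t (so K_BLOCK_SIZE/16 = 32); `smem_bytes` and its `stages` parameter (for the multi-buffered kernels later) are hypothetical.

```python
def smem_bytes(k_block_size: int, rows: int, stages: int = 1) -> int:
    words = k_block_size // 16   # one uint64_t holds 16 packed FP4 values
    a = rows * words * 8         # a_smem:   uint64_t[rows][K_BLOCK_SIZE/16]
    sfa = rows * words * 1       # sfa_smem: fp8      [rows][K_BLOCK_SIZE/16]
    b = words * 8                # b_smem:   uint64_t[K_BLOCK_SIZE/16]
    sfb = words * 1              # sfb_smem: fp8      [K_BLOCK_SIZE/16]
    return stages * (a + sfa + b + sfb)


ROWS = 4 * 8  # RESULTS_PER_WARP * NUM_WARPS_PER_CTA
print(smem_bytes(512, ROWS))            # 9504
print(smem_bytes(512, ROWS, stages=5))  # 47520
```

Statically declared `__shared__` arrays are limited to 48 KB per block, so under these assumptions a 5-stage version of these buffers would just fit.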
The major change in the kernel is all threads participate in a cooperative load to get all the needed data from GMEM -> SMEM prior to the computation loop:
// First all threads in the CTA cooperatively load the necessary chunks of data GMEM->SMEM for this k-block
// Load b_ref and sfb_ref for k-block; only the first warp issues these since all warps in a CTA share this block
if (l_wid == 0) {
    b_smem[l_wtid] = reinterpret_cast<uint64_t const*>(b_ref)[(b_base + kb*(K_BLOCK_SIZE/2))/8 + l_wtid];
    sfb_smem[l_wtid] = sfb_ref[sfb_base + kb*(32) + l_wtid];
}
// Load block from a_ref (size: NUM_WARPS_PER_CTA*RESULTS_PER_WARP rows, each with K_BLOCK_SIZE NVFP4 elements)
for (int result = 0; result < RESULTS_PER_WARP; result++) {
    a_smem[l_wid*RESULTS_PER_WARP + result][l_wtid] = reinterpret_cast<uint64_t const*>(a_ref)[(a_base + l_wid*RESULTS_PER_WARP*k_elems + result*k_elems + kb*(K_BLOCK_SIZE/2))/8 + l_wtid];
    sfa_smem[l_wid*RESULTS_PER_WARP + result][l_wtid] = sfa_ref[sfa_base + l_wid*RESULTS_PER_WARP*k_scalars + result*k_scalars + kb*32 + l_wtid];
}
// Ensure all data for this k-block made it from GMEM to SMEM
__syncthreads();
Algorithmically our kernel otherwise remains the same. Unfortunately, this change to using SMEM buffers leads to a regression in kernel performance of about 20%.
V6_2
The reason for the slowdown comes from the overhead associated with the SMEM buffers. I took the following steps in this kernel variation to try and reduce and compensate for the overhead:
Pre-convert B to FP16 in registers before the inner result loop:
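__half2 b_vals_cvt[8];
for (int i = 0; i < 8; i++) {
    // Convert two packed FP4 (E2M1) values into one __half2; construct from the
    // returned __half2_raw instead of taking the address of a temporary
    b_vals_cvt[i] = __half2(__nv_cvt_fp4x2_to_halfraw2(b_vals[i], __NV_E2M1));
}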
Now B conversion runs once per thread per k-block instead of once per result per k-block (4× less conversion work for B).
Pointer arithmetic instead of offset recomputation: advances a_ref, b_ref, sfa_ref, sfb_ref by a stride each k-block iteration, eliminating the multiply-add offset computations inside the loop.
Remove __syncwarp() before the warp shuffle, since __shfl_down_sync already synchronizes the participating threads, making it redundant.
Inner loop unrolled by 2 (i += 2) to allow the compiler to better interleave A conversion with the FMA instruction pipeline.
These minor adjustments give our kernel an improvement of a few microseconds over the V5 kernel:
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 138.5us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 189.3us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 58.4us
V6_2_1
Instead of converting B after the data has been loaded into SMEM, we first convert the data and then store it in the SMEM buffer. This reduces the number of SMEM accesses as well as the number of warps that need to perform conversions. This minor change buys us an 8-11% improvement.
// First all threads in the CTA cooperatively load the necessary chunks of data GMEM->SMEM for this k-block
// Load b_ref and sfb_ref for k-block; only the first warp issues these since all warps in a CTA share this block
if (l_wid == 0) {
    b_smem[l_wtid] = reinterpret_cast<uint64_t const*>(b_ref)[l_wtid];
    __nv_fp4x2_storage_t const* b_vals = reinterpret_cast<__nv_fp4x2_storage_t const*>(&b_smem[l_wtid]);
    sfb_smem[l_wtid] = sfb_ref[l_wtid];
    // Convert this lane's 16 FP4 values into 8 __half2 pairs once, for all warps to reuse
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        b_smem_fp16[l_wtid*8 + i] = __half2(__nv_cvt_fp4x2_to_halfraw2(b_vals[i], __NV_E2M1));
    }
}
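For readers who haven't worked with FP4 before, here is a Python sketch of what the E2M1 conversion produces: 1 sign bit, 2 exponent bits, 1 mantissa bit, giving the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}. It assumes the low nibble of each byte holds the first element (the `.x` half) and that the 64-bit word is little-endian; the function names are hypothetical.

```python
from typing import List, Tuple

# All 8 magnitudes representable in FP4 E2M1, indexed by the low 3 bits
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    mag = E2M1_MAGNITUDES[nibble & 0x7]
    return -mag if nibble & 0x8 else mag  # bit 3 is the sign


def decode_fp4x2(byte: int) -> Tuple[float, float]:
    # assumption: low nibble is the first element (.x), high nibble the second (.y)
    return decode_e2m1(byte & 0xF), decode_e2m1(byte >> 4)


def decode_u64(word: int) -> List[float]:
    # 16 FP4 values packed in one 64-bit word, as in the kernel's uint64_t loads
    out: List[float] = []
    for i in range(8):
        out.extend(decode_fp4x2((word >> (8 * i)) & 0xFF))
    return out
```

Each lane in the snippet above does the equivalent of `decode_u64` on one word and stores the results as eight FP16 pairs.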
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 127.9us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 169.1us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 52.3us
V6_2_2
In addition to converting B values to half precision during the load stage we also convert the scale factors. At this step I also introduce new compiler flags:
'-gencode=arch=compute_100a,code=sm_100a', '-Xptxas', '--allow-expensive-optimizations=true'
These compiler flags [NVCC Compiler Flags Details], in combination with the small algorithmic change, resulted in a massive performance boost. Most of this gain appears to come from the compiler flags, which condensed much of the generated code and allowed the compiler to make Blackwell-specific optimizations (see the linked compiler details section for more on what happened here). I wasn't used to working in this Python-embedded environment, so compiler flag options didn't occur to me until this point in the development process.
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 58.7us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 64.4us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 21.5us
V6_2_3
In this kernel we move the data conversion of SFA into the data load section. This resulted in negligible timing changes to the kernel. This is likely because the SFA data isn't shared, and thus we aren't really introducing any optimization via more parallelization or a reduction of work, we've simply moved where in the kernel that work gets performed.
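The scale factors use FP8 E4M3 (1 sign, 4 exponent, 3 mantissa bits, bias 7). The sketch below decodes one byte the way `__NV_E4M3` interprets it, assuming the "FN" variant that CUDA uses (no infinities, and S.1111.111 is NaN); `decode_e4m3` is a hypothetical helper.

```python
import math


def decode_e4m3(byte: int) -> float:
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan                                  # only NaN encoding; no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6            # subnormal range
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)   # normal range, bias 7
```

The largest finite value is 0x7E = 448, and every E4M3 value is exactly representable in FP16, which is why converting scales to half precision loses nothing.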
V6_2_4
Reviewing the analysis from Nsight Compute [Nsight Compute Details], I saw that about 80% of my kernel's instructions were non-floating-point arithmetic instructions. This seemed odd, since the actual computation in the kernel is floating point math. My first thought was that this came from excessive pointer/address arithmetic in the inner loops that wasn't being properly optimized by the compiler. I took the following steps in this kernel to try to reduce the number of LOP3/PRMT/IMAD instructions involved in address resolution:
Pre-converting A matrix to FP16 in shared memory (eliminates conversions in compute loop)
Pre-scaling B matrix with sfb (eliminates scale multiplication in compute loop)
Pre-computing combined scales as FP16 (reduces FP32 operations)
However, this didn't significantly impact the instruction composition of the kernel, and performance stayed about the same. The other source of arithmetic instructions was the conversion of the 4 bit and 8 bit floats to FP16; however, at this point I didn't think there was a way of avoiding this overhead. In the final kernel I discover a very useful PTX instruction that helps resolve this. In the next few kernels I try some tangential optimizations.
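Pre-scaling B with SFB is safe because the scales are constant over each 16-element block, so they can be factored out of each block's partial sum. This Python sketch checks that equivalence in exact float arithmetic; the real kernel accumulates in FP16/FP32, so its rounding differs. The function names and BLOCK = 16 are assumptions for illustration.

```python
from typing import List

BLOCK = 16  # assumed scale block size along K


def dot_scaled(a, sfa, b, sfb) -> float:
    # both scales applied per element inside the inner loop
    return sum(a[k] * sfa[k // BLOCK] * b[k] * sfb[k // BLOCK] for k in range(len(a)))


def prescale_b(b, sfb) -> List[float]:
    # done once per k-block and shared by every output row in the CTA
    return [b[k] * sfb[k // BLOCK] for k in range(len(b))]


def dot_prescaled(a, sfa, b_scaled) -> float:
    total = 0.0
    for blk in range(len(a) // BLOCK):
        lo, hi = blk * BLOCK, (blk + 1) * BLOCK
        partial = sum(a[k] * b_scaled[k] for k in range(lo, hi))
        total += sfa[blk] * partial  # one SFA multiply per 16-element block
    return total
```

The saving is that SFB is applied once per B element instead of once per (row, element) pair.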
V6_3
Another insight from NCU (Nsight Compute) was that register usage was our occupancy limiter. This means we could potentially run more CTAs in parallel if we reduced the number of registers used by each thread, as long as no other resource (SMEM, threads, etc.) then became an equal or tighter limit on occupancy. In this kernel we do this by moving the results array out of registers and into SMEM. We pay the cost of loading from and storing back to SMEM on every accumulation update in each k-dimension iteration, but we gain occupancy. Unfortunately the increase in occupancy doesn't buy enough performance to compensate for the added SMEM cost, so performance degrades slightly across the board.
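A rough model shows why shaving registers can raise occupancy. This sketch assumes 64K 32-bit registers and a 2048-thread limit per SM, ignores register allocation granularity and SMEM, and uses illustrative per-thread register counts rather than values measured for these kernels; `ctas_per_sm` is hypothetical.

```python
REGS_PER_SM = 64 * 1024    # assumed: 64K 32-bit registers per SM
MAX_THREADS_PER_SM = 2048  # assumed per-SM thread limit


def ctas_per_sm(regs_per_thread: int, threads_per_cta: int) -> int:
    """Register- and thread-limited CTAs per SM (ignores allocation granularity and SMEM)."""
    by_regs = REGS_PER_SM // (regs_per_thread * threads_per_cta)
    by_threads = MAX_THREADS_PER_SM // threads_per_cta
    return min(by_regs, by_threads)


THREADS = 8 * 32  # NUM_WARPS_PER_CTA = 8 warps
for regs in (96, 64, 48, 32):
    print(regs, "regs/thread ->", ctas_per_sm(regs, THREADS), "CTAs/SM")
```

Below about 32 registers per thread the thread limit, not the register file, becomes the cap, which is the "other resource" caveat above.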
V7
This is the first kernel where I introduce a form of latency hiding [Latency Hiding Details]. In this version I attempt to overlap the load of the next k-block with the compute of the current k-block. In this case we don't use any form of asynchronous data movement on the processor. Instead we take advantage of the deep pipelines in the GPU architecture which allow us to dispatch loads to one SMEM buffer while executing computation on a separate SMEM buffer. Once the data gets large enough and the number of memory buffer stages increases (deeper software pipeline) using asynchronous mechanisms becomes advantageous, but this is a good first proof of concept that still allows for a decent amount of memory and computational overlap.
Double-buffered overlap diagram:
Time ─────────────────────────────────────────────────────────────────►
         ┌─────────────┬─────────────┬─────────────┬─────────────┐
buf[0]   │ Load k-blk0 │             │ Load k-blk2 │             │ ...
         │ (GMEM→SMEM) │             │             │             │
         └─────────────┴─────────────┴─────────────┴─────────────┘
                       ┌─────────────┬─────────────┬─────────────┐
buf[1]                 │ Load k-blk1 │             │ Load k-blk3 │ ...
                       │             │             │             │
                       └─────────────┴─────────────┴─────────────┘
                       ┌─────────────┬─────────────┬─────────────┐
compute  (idle, prime) │ Compute on  │ Compute on  │ Compute on  │ ...
units                  │   buf[0]    │   buf[1]    │   buf[0]    │
                       └─────────────┴─────────────┴─────────────┘
                       └── overlap ──┘└── overlap ─┘└── overlap ─┘
The load issued into one buffer travels through the GMEM→L2→SMEM path
while the compute units consume the other buffer. Because the two
stages mostly use different hardware, the time per k-block ideally
approaches the longer of the load and compute phases rather than their sum.
To implement this we simply use two buffers and alternate between which one is loading from GMEM and which one is being used for computation, as shown in the code snippets below.
// Double-buffered SMEM: [2 buffers][original layout]
__shared__ uint64_t a_smem[2][NUM_WARPS_PER_CTA*RESULTS_PER_WARP][K_BLOCK_SIZE/16];
__shared__ __nv_fp8_storage_t sfa_smem[2][NUM_WARPS_PER_CTA*RESULTS_PER_WARP][K_BLOCK_SIZE/16];
__shared__ uint64_t b_smem[2][K_BLOCK_SIZE/16];
__shared__ __nv_fp8_storage_t sfb_smem[2][K_BLOCK_SIZE/16];
Pseudo-code:
Prologue: load k-block 0 into buf[0]
Loop:
    read_buf = write_buf; write_buf = 1 - write_buf
    issue loads for k-block kb+1 into buf[write_buf]   // ordinary loads, not cp.async
    compute on buf[read_buf]                           // overlaps with the in-flight loads
    __syncthreads()
int read_buf = write_buf;
write_buf = 1 - write_buf;  // Toggle buffer
// Issue (ordinary, non-cp.async) loads for the next k-block, if not the last iteration
if (kb + 1 < k_blocks) {
    if (l_wid == 0) {
        b_smem[write_buf][l_wtid] = reinterpret_cast<uint64_t const*>(b_ref)[(b_base + (kb+1)*(K_BLOCK_SIZE/2))/8 + l_wtid];
        sfb_smem[write_buf][l_wtid] = sfb_ref[sfb_base + (kb+1)*32 + l_wtid];
    }
    for (int result = 0; result < RESULTS_PER_WARP; result++) {
        a_smem[write_buf][l_wid*RESULTS_PER_WARP + result][l_wtid] =
            reinterpret_cast<uint64_t const*>(a_ref)[(a_base + l_wid*RESULTS_PER_WARP*k_elems + result*k_elems + (kb+1)*(K_BLOCK_SIZE/2))/8 + l_wtid];
        sfa_smem[write_buf][l_wid*RESULTS_PER_WARP + result][l_wtid] =
            sfa_ref[sfa_base + l_wid*RESULTS_PER_WARP*k_scalars + result*k_scalars + (kb+1)*32 + l_wtid];
    }
}
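The buffer toggling can be traced in Python to confirm the two invariants the scheme relies on: k-block kb is always computed from buf[kb % 2], and the buffer being loaded is never the one being computed on. `double_buffer_schedule` is a hypothetical helper that follows the pseudo-code above step by step.

```python
from typing import List, Optional, Tuple

Event = Tuple[object, Optional[int], Optional[int], Optional[int], Optional[int]]


def double_buffer_schedule(k_blocks: int) -> List[Event]:
    """(step, buf_loaded, kb_loaded, buf_computed, kb_computed) for each loop step."""
    events: List[Event] = [("prologue", 0, 0, None, None)]  # load k-block 0 into buf[0]
    write_buf = 0
    for kb in range(k_blocks):
        read_buf = write_buf
        write_buf = 1 - write_buf              # toggle buffer
        if kb + 1 < k_blocks:                  # load the next k-block into the other buffer
            load_buf, load_kb = write_buf, kb + 1
        else:
            load_buf, load_kb = None, None     # last iteration: nothing left to load
        events.append((kb, load_buf, load_kb, read_buf, kb))
        # the kernel's __syncthreads() sits here, before the next toggle
    return events
```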
This results in a significant performance boost, with larger K giving more improvement. Benchmark 1 sees about a 20% improvement over V6_2_2, benchmark 2 about 15%, and benchmark 3 about 10%. Larger K yields more gain because it means more iterations along the k-dimension, and therefore more overlapping memory and compute phases. Put plainly, if each overlap of a memory load phase with a compute phase saves x amount of time (because without overlap they would be serialized), then a larger K multiplies that saving x by more iterations, resulting in a larger percentage performance gain.
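A simple timing model makes the "larger multiplier on x" argument concrete. With n k-blocks, serialized execution costs n × (load + compute), while double buffering costs one un-overlapped load, n − 1 overlapped steps, and one final compute, so the saving is (n − 1) × min(load, compute). The equal load/compute times and the K_BLOCK_SIZE = 512 used to derive n are hypothetical; real overlap is far from perfect, which is why the measured gains are smaller.

```python
def serial_time(n: int, t_load: float, t_compute: float) -> float:
    # no overlap: every k-block pays load + compute
    return n * (t_load + t_compute)


def pipelined_time(n: int, t_load: float, t_compute: float) -> float:
    # prologue load, n-1 steps where the next load overlaps the current compute,
    # then the final compute with nothing left to load
    return t_load + (n - 1) * max(t_load, t_compute) + t_compute


# hypothetical: equal load and compute time per k-block, K_BLOCK_SIZE = 512
for name, k in [("bench1", 16384), ("bench2", 7168), ("bench3", 2048)]:
    n = k // 512
    s, p = serial_time(n, 1.0, 1.0), pipelined_time(n, 1.0, 1.0)
    print(f"{name}: {n} k-blocks, {1 - p / s:.0%} of serial time saved")
```

The fixed prologue and epilogue are amortized over more k-blocks as K grows, so the fractional saving grows with K, matching the ordering of the measured improvements.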
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 46.5us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 54.4us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 19.6us
V7_2
In the previous kernel our software pipelining [Software Pipelining Details] relies on data dependency handling in the hardware execution pipelines. NVIDIA GPUs also offer explicit asynchronous mechanisms of data transfer that software can manage so you can more effectively pipeline larger amounts of data movement and computation and explicitly ensure overlap of the software pipeline stages. This strategy will become more prevalent in later kernels, and will be discussed in greater detail below. In this kernel we implement a very simple pipeline scheme using cuda::pipeline and cuda::memcpy_async from the cuda/pipeline library.
The new load block now looks like the below:
if (kb + 1 < k_blocks) {
    if (l_wid == 0) {
        cuda::memcpy_async(&b_smem[write_buf][l_wtid],
                           &reinterpret_cast<uint64_t const*>(b_ref)[(b_base + (kb+1)*(K_BLOCK_SIZE/2))/8 + l_wtid],
                           sizeof(uint64_t), pipe);
        cuda::memcpy_async(&sfb_smem[write_buf][l_wtid],
                           &sfb_ref[sfb_base + (kb+1)*32 + l_wtid],
                           sizeof(__nv_fp8_storage_t), pipe);
    }
    for (int result = 0; result < RESULTS_PER_WARP; result++) {
        cuda::memcpy_async(&a_smem[write_buf][l_wid*RESULTS_PER_WARP + result][l_wtid],
                           &reinterpret_cast<uint64_t const*>(a_ref)[(a_base + l_wid*RESULTS_PER_WARP*k_elems + result*k_elems + (kb+1)*(K_BLOCK_SIZE/2))/8 + l_wtid],
                           sizeof(uint64_t), pipe);
        cuda::memcpy_async(&sfa_smem[write_buf][l_wid*RESULTS_PER_WARP + result][l_wtid],
                           &sfa_ref[sfa_base + l_wid*RESULTS_PER_WARP*k_scalars + result*k_scalars + (kb+1)*32 + l_wtid],
                           sizeof(__nv_fp8_storage_t), pipe);
    }
    pipe.producer_commit();
}
Under the hood, cuda::memcpy_async uses cp.async PTX instructions where the copy size and alignment allow it (cp.async copies 4, 8, or 16 bytes, so the 1-byte scale factor copies likely fall back to ordinary loads and stores), and cuda::pipeline uses synchronization mechanisms (such as cp.async wait groups or mbarriers, discussed in great detail later on) to ensure data transfers complete before threads use that data for computation. Functionally this kernel is identical to V7, but it uses a different hardware mechanism to decouple memory loading and computation on separate SMEM buffers. It therefore makes sense that this kernel is slightly slower than the previous one: there is a small amount of overhead in setting up these asynchronous structures and in the underlying synchronization, and with only two buffers in our software pipeline we don't gain much additional overlap.
V7_3
Because two buffers weren't enough to take full advantage of the cp.async mechanism, we extend the alternating, or "ping-pong", scheme between two buffers to a ring buffer. A ring buffer (or circular buffer) is a fixed-size array with a pointer to one of its elements. Once you've finished with an element, the pointer advances. When it reaches the end of the array, it wraps back to the beginning and overwrites what was previously there.
Ring buffer (NUM_STAGES = 5):
┌────────┐
│ stg 0 │ ← compute_ptr (current MMA)
├────────┤
│ stg 1 │ ← in flight (cp.async)
├────────┤
│ stg 2 │ ← in flight (cp.async)
├────────┤
│ stg 3 │ ← in flight (cp.async)
├────────┤
│ stg 4 │ ← load_ptr (next async load)
└────────┘
indices mod N = NUM_STAGES (wrap to 0 after 4)
Prologue:     fill stages 0..N-2 with async loads (no compute yet).
Steady state: compute consumes stage k while a new async load is
              issued into stage (k + N - 1) mod N. The deeper the
              ring, the more outstanding loads can be in flight in
              the L2/HBM pipeline simultaneously, hiding more
              memory latency, at the cost of SMEM occupancy.
The elements of our size-N ring buffer are N sets of A/SFA/B/SFB buffers in SMEM. This lets the kernel race ahead and issue loads for the first N-1 = 4 k-blocks before any computation starts. The fifth stage is filled while the first is being consumed. Once those async loads have been issued, the compute stage works through the ring buffer as each stage's data becomes available. Giving the loads a bigger head start allows more overlap of the async loads with the computation.
// Multi-stage buffered SMEM
__shared__ uint64_t a_smem[NUM_STAGES][NUM_WARPS_PER_CTA*RESULTS_PER_WARP][K_BLOCK_SIZE/16];
__shared__ __nv_fp8_storage_t sfa_smem[NUM_STAGES][NUM_WARPS_PER_CTA*RESULTS_PER_WARP][K_BLOCK_SIZE/16];
__shared__ uint64_t b_smem[NUM_STAGES][K_BLOCK_SIZE/16];
__shared__ __nv_fp8_storage_t sfb_smem[NUM_STAGES][K_BLOCK_SIZE/16];
// Prologue: Fill the pipeline with NUM_STAGES-1 blocks
for (int kb = 0; kb < NUM_STAGES - 1 && kb < k_blocks; kb++) {
    int stage = kb % NUM_STAGES;
    if (l_wid == 0) {
        cuda::memcpy_async(&b_smem[stage][l_wtid],
                           &reinterpret_cast<uint64_t const*>(b_ref)[(b_base + kb*(K_BLOCK_SIZE/2))/8 + l_wtid],
                           sizeof(uint64_t), pipe);
        cuda::memcpy_async(&sfb_smem[stage][l_wtid],
                           &sfb_ref[sfb_base + kb*32 + l_wtid],
                           sizeof(__nv_fp8_storage_t), pipe);
    }
    for (int result = 0; result < RESULTS_PER_WARP; result++) {
        cuda::memcpy_async(&a_smem[stage][l_wid*RESULTS_PER_WARP + result][l_wtid],
                           &reinterpret_cast<uint64_t const*>(a_ref)[(a_base + l_wid*RESULTS_PER_WARP*k_elems + result*k_elems + kb*(K_BLOCK_SIZE/2))/8 + l_wtid],
                           sizeof(uint64_t), pipe);
        cuda::memcpy_async(&sfa_smem[stage][l_wid*RESULTS_PER_WARP + result][l_wtid],
                           &sfa_ref[sfa_base + l_wid*RESULTS_PER_WARP*k_scalars + result*k_scalars + kb*32 + l_wtid],
                           sizeof(__nv_fp8_storage_t), pipe);
    }
    pipe.producer_commit();
}
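The prologue above, together with the steady-state loop it feeds, follows the schedule in the ring-buffer diagram. The Python sketch below is a toy model. It tracks only which stage each load and each computation uses, and ignores timing, cp.async and the barriers a real kernel needs. It also checks that no load ever overwrites a stage that still holds unconsumed data. The function names are hypothetical.

```python
# Toy model of the NUM_STAGES ring-buffer schedule: which stage each load
# targets and which stage each computation reads. Timing, cp.async and the
# barriers required on real hardware are deliberately not modeled.


def ring_schedule(k_blocks, num_stages):
    """Return ("load"|"compute", k_block, stage) events in issue order."""
    events = []
    # Prologue: issue loads for the first num_stages - 1 k-blocks.
    for kb in range(min(num_stages - 1, k_blocks)):
        events.append(("load", kb, kb % num_stages))
    # Steady state: prefetch block kb + num_stages - 1, then consume block kb.
    # On hardware, reusing stage (kb - 1) % num_stages also requires every
    # thread to have finished reading it in the previous iteration.
    for kb in range(k_blocks):
        nxt = kb + num_stages - 1
        if nxt < k_blocks:
            events.append(("load", nxt, nxt % num_stages))
        events.append(("compute", kb, kb % num_stages))
    return events


def no_overwrite(events):
    """True if no load lands in a stage whose previous block is unconsumed."""
    pending = {}  # stage -> k-block currently held
    for kind, kb, stage in events:
        if kind == "load":
            if stage in pending:
                return False
            pending[stage] = kb
        else:
            if pending.get(stage) != kb:
                return False
            del pending[stage]
    return True


def max_in_flight(events):
    """Peak number of stages holding loaded-but-unconsumed data."""
    held = peak = 0
    for kind, _, _ in events:
        held += 1 if kind == "load" else -1
        peak = max(peak, held)
    return peak


if __name__ == "__main__":
    events = ring_schedule(k_blocks=8, num_stages=5)
    for event in events[:7]:
        print(event)
    print(no_overwrite(events), max_in_flight(events))
```

Setting num_stages=2 gives the V7/V7_2 double buffer: block kb+1 is loaded into the other buffer while block kb is consumed.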
Unfortunately we don't see much improvement over V7, and we actually see a small performance regression in benchmark 3 because K is so small. We also have to be careful with the depth of our software pipeline. More stages demand more SMEM, which can reduce occupancy and increase runtime. This trade-off is discussed in much more detail in later kernels.
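Summing the four multi-stage arrays declared above makes the SMEM side of this trade-off concrete. The sketch relies on several assumptions:
- K_BLOCK_SIZE = 512 is inferred from the load code.
- RESULTS_PER_WARP = 4 follows the 4 rows/warp tiling from V4.
- NUM_WARPS_PER_CTA = 8 is purely hypothetical.
- 228 KiB of SMEM per SM is an assumed device limit.

The occupancy estimate also ignores registers, thread limits and any per-CTA reserved shared memory.

```python
# SMEM footprint of the multi-stage buffers declared above.
# ASSUMED values: K_BLOCK_SIZE = 512 (inferred), RESULTS_PER_WARP = 4
# (V4 tiling), NUM_WARPS_PER_CTA = 8 (hypothetical), 228 KiB SMEM per SM.
UINT64_BYTES = 8
FP8_BYTES = 1


def smem_bytes(num_stages, num_warps, results_per_warp, k_block_size=512):
    """Bytes used by a_smem, sfa_smem, b_smem and sfb_smem together."""
    rows = num_warps * results_per_warp
    cols = k_block_size // 16
    per_stage = (rows * cols * UINT64_BYTES    # a_smem
                 + rows * cols * FP8_BYTES     # sfa_smem
                 + cols * UINT64_BYTES         # b_smem
                 + cols * FP8_BYTES)           # sfb_smem
    return num_stages * per_stage


def ctas_per_sm_by_smem(num_stages, num_warps, results_per_warp,
                        smem_per_sm=228 * 1024, k_block_size=512):
    """SMEM-only upper bound on resident CTAs per SM (ignores registers,
    thread limits and any per-CTA reserved shared memory)."""
    return smem_per_sm // smem_bytes(num_stages, num_warps,
                                     results_per_warp, k_block_size)


if __name__ == "__main__":
    for stages in (2, 3, 4, 5):
        print(stages, smem_bytes(stages, 8, 4), ctas_per_sm_by_smem(stages, 8, 4))
```

With these hypothetical values, each stage costs 9,504 bytes. Going from 2 to 5 stages therefore drops the SMEM-limited bound from 12 to 4 resident CTAs per SM. Statically declared __shared__ arrays are also capped at 48 KB per block, so deeper pipelines would eventually need dynamically allocated shared memory.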
Benchmark results:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 42.2us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 52.8us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 21.5us (regression)
Final Kernel
The first change in this kernel is a minor one from an algorithmic perspective. Instead of the cuda::memcpy_async API, we switch to the cde::cp_async_bulk_tensor_2d_global_to_shared API. It issues the PTX cp.async.bulk.tensor instruction, which performs a two-dimensional asynchronous data transfer with a single instruction. These instructions require a TMA Descriptor [TMA Details]. The switch gives us a minor performance boost of a few microseconds, dampened by the overhead of TMA descriptor setup. The benefit comes from no longer looping over the rows of the A and SFA matrices to issue multiple cp.async loads. Instead, a single bulk tensor load fetches multiple rows of A and SFA, as shown below:
// cde = cuda::device::experimental (namespace alias)
if (kb+1 < k_blocks) {
    // Thread 0 of each CTA issues all the async TMA transfers
    if (threadIdx.x == 0) {
        // Load A block
        cde::cp_async_bulk_tensor_2d_global_to_shared(&a_smem[load_stage], &tensor_map_a, (kb+1)*(K_BLOCK_SIZE/8), cta_row, bar[load_stage]);
        // Load SFA block
        cde::cp_async_bulk_tensor_2d_global_to_shared(&sfa_smem[load_stage], &tensor_map_sfa, (kb+1)*(K_BLOCK_SIZE/32), cta_row, bar[load_stage]);
        // Load B block
        cde::cp_async_bulk_tensor_2d_global_to_shared(&b_smem[load_stage], &tensor_map_b, (kb+1)*(K_BLOCK_SIZE/8), batch, bar[load_stage]);
        // Load SFB block
        cde::cp_async_bulk_tensor_2d_global_to_shared(&sfb_smem[load_stage], &tensor_map_sfb, (kb+1)*(K_BLOCK_SIZE/32), batch, bar[load_stage]);
        // Arrive on the barrier and register how many bytes the copies above will deliver
        token[load_stage] = cuda::device::barrier_arrive_tx(bar[load_stage], 1, total_smem_size);
    } else {
        // All other threads simply arrive on the barrier
        token[load_stage] = bar[load_stage].arrive();
    }
}
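The barrier bookkeeping in this block is easier to follow with a toy model. The Python sketch below is a simplified, hypothetical model of one mbarrier phase, not the CUDA API. A phase completes only when every expected thread has arrived and the pending transaction byte count has dropped back to zero. Thread 0's barrier_arrive_tx both arrives and raises the expected byte count, and the TMA unit's writes count those bytes down. The CTA size and byte count in the usage lines are made up.

```python
# Toy model of one mbarrier phase with transaction (byte) counting.
# Simplified and hypothetical: not the CUDA API, no phases, no waiting.


class ToyTxBarrier:
    def __init__(self, expected_arrivals):
        self.expected_arrivals = expected_arrivals
        self.arrivals = 0
        self.pending_tx = 0  # bytes announced but not yet written to SMEM

    def arrive(self):
        # What every non-issuing thread does: bar.arrive()
        self.arrivals += 1

    def arrive_tx(self, tx_bytes):
        # Like barrier_arrive_tx(bar, 1, tx_bytes): one arrival plus an
        # announcement that tx_bytes more bytes must land in SMEM.
        self.pending_tx += tx_bytes
        self.arrivals += 1

    def complete_tx(self, tx_bytes):
        # Stands in for the TMA unit reporting bytes it has written.
        self.pending_tx -= tx_bytes

    def phase_complete(self):
        # The phase ends only when all arrivals AND all bytes are in.
        return self.arrivals == self.expected_arrivals and self.pending_tx == 0


if __name__ == "__main__":
    bar = ToyTxBarrier(expected_arrivals=128)   # hypothetical CTA size
    bar.arrive_tx(4096)                         # thread 0, hypothetical byte count
    for _ in range(127):
        bar.arrive()                            # threads 1..127
    print(bar.phase_complete())                 # False: copies still in flight
    bar.complete_tx(4096)
    print(bar.phase_complete())                 # True
```

This is why only thread 0 needs to know the byte count. The other threads just arrive, and anyone waiting on the token cannot proceed until the data has actually landed.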
The second and most impactful change in this kernel is a rather simple one. All previous kernels relied on CUDA intrinsics to convert low-bit data types to FP16 for further computation. As discussed earlier, these intrinsics compiled down to millions of executed arithmetic instructions, drowning our kernel in arithmetic work. Luckily, PTX offers a way around this through the following instructions:
cvt.rn.f16x2.e2m1x2 (FP4→FP16), cvt.rn.f16x2.e4m3x2 (FP8→FP16)
Replacing the CUDA conversion intrinsics with these instructions drastically reduces the number of arithmetic instructions, and therefore cycles, needed for each conversion. The change is visible both in the raw PTX/SASS and in the NCU instruction breakdown of the kernel. The result is a significant 30-50% latency improvement!
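The Python sketch below makes concrete what these conversions compute. It decodes E2M1 (FP4) and E4M3 (FP8) bit patterns from their field layouts. It is a reference model for checking values, not a description of how the hardware or the kernel performs the conversion, and the function names are hypothetical.

```python
import math

# Reference decoders for the formats converted by the cvt instructions above.
# E2M1 (FP4): 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit, no inf/NaN.
# E4M3 (FP8): 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits,
#             no infinities; S.1111.111 is NaN.


def decode_e2m1(nibble):
    sign = (nibble >> 3) & 0x1
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:
        mag = man * 0.5                           # subnormal: 0 or 0.5
    else:
        mag = (1 + man / 2) * 2.0 ** (exp - 1)    # 1, 1.5, 2, 3, 4, 6
    return -mag if sign else mag


def decode_e4m3(byte):
    sign = (byte >> 7) & 0x1
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:
        mag = (man / 8) * 2.0 ** -6               # subnormal
    else:
        mag = (1 + man / 8) * 2.0 ** (exp - 7)
    return -mag if sign else mag


def decode_e2m1x2(byte):
    """Decode both FP4 values packed in one byte as (low nibble, high nibble).
    Which nibble maps to which FP16 half of cvt's result is set by the PTX ISA."""
    return decode_e2m1(byte & 0xF), decode_e2m1((byte >> 4) & 0xF)


if __name__ == "__main__":
    print([decode_e2m1(n) for n in range(8)])   # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
    print(decode_e4m3(0x38), decode_e4m3(0x7E))  # 1.0 448.0
    # In this scheme a dequantized value is the FP4 value times its block's FP8 scale:
    print(decode_e2m1x2(0x71), decode_e2m1(0x7) * decode_e4m3(0x38))
```

Written out like this, each conversion is only a little bit manipulation. A single cvt instruction performs two of them at once, whereas the intrinsic path expands into many general-purpose instructions.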
We also use inline PTX for the following instructions, which operate on pairs of FP16 numbers in parallel:
fma.rn.f16x2, mul.rn.f16x2, add.rn.f16x2
The compiler was already emitting these instructions to some extent before we switched to inline PTX, so the gain here isn't very significant. Writing them explicitly mainly ensures that we keep that benefit now that the inner loop is hand-written in PTX. The new raw PTX loop is too large to include here, but can be viewed in the source code.
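The packed layout these instructions rely on is easy to emulate. The Python sketch below stores two FP16 lanes in one 32-bit word and applies add and mul lane by lane, with a single round-to-nearest-even step per lane. That is the behavior add.rn.f16x2 and mul.rn.f16x2 specify. It is an illustration only, and the helper names are hypothetical. fma.rn.f16x2 is left out because emulating its single rounding exactly takes more care than a few lines allow.

```python
import math
import struct

# Emulation of packed-half ("f16x2") add/mul in Python: two FP16 lanes
# share one 32-bit word and each lane is computed and rounded independently.
# Exactness note: the sum or product of two FP16 values is exact in a
# Python float (binary64), so rounding once to FP16 matches ".rn".


def round_f16(x):
    """Round to the nearest FP16 value, ties to even; overflow -> +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def pack_f16x2(lo, hi):
    """Pack two FP16 values into one 32-bit word (lo in bits 0..15)."""
    return struct.unpack("<I", struct.pack("<ee", lo, hi))[0]


def unpack_f16x2(word):
    return struct.unpack("<ee", struct.pack("<I", word))


def add_f16x2(a, b):
    """Lane-wise a + b, like add.rn.f16x2."""
    (a0, a1), (b0, b1) = unpack_f16x2(a), unpack_f16x2(b)
    return pack_f16x2(round_f16(a0 + b0), round_f16(a1 + b1))


def mul_f16x2(a, b):
    """Lane-wise a * b, like mul.rn.f16x2."""
    (a0, a1), (b0, b1) = unpack_f16x2(a), unpack_f16x2(b)
    return pack_f16x2(round_f16(a0 * b0), round_f16(a1 * b1))


if __name__ == "__main__":
    a = pack_f16x2(1.5, -2.0)
    b = pack_f16x2(0.25, 3.0)
    print(unpack_f16x2(add_f16x2(a, b)))   # (1.75, 1.0)
    print(unpack_f16x2(mul_f16x2(a, b)))   # (0.375, -6.0)
```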
There are also two kernel variants, selected by the size of K. One uses the double-buffering scheme introduced in kernel V7; the deeper software pipelining implemented in V7_3 turned out not to be as fruitful. The other does no software pipelining at all, because the setup, synchronization, and SMEM occupancy costs can outweigh any benefit when K is small enough relative to M*L. At runtime we choose which kernel to run based on K:
if k < 16384 and k % 1024 == 0:
    return kernel_small_k(a_ref, b_ref, sfa_ref, sfb_ref, c_ref, a_ref.shape[0], a_ref.shape[1]*2, a_ref.shape[2])
else:
    return kernel(a_ref, b_ref, sfa_ref, sfb_ref, c_ref, a_ref.shape[0], a_ref.shape[1]*2, a_ref.shape[2])
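The dispatch condition above directly determines which path each benchmark shape takes. The sketch mirrors it in plain Python. choose_kernel is a hypothetical helper, and it assumes, as the name and the preceding paragraph suggest, that kernel_small_k is the variant without software pipelining.

```python
def choose_kernel(k):
    """Mirror of the dispatch condition above (hypothetical helper)."""
    if k < 16384 and k % 1024 == 0:
        return "kernel_small_k"   # assumed: no software pipelining
    return "kernel"               # assumed: double-buffered (V7-style)


if __name__ == "__main__":
    for k in (16384, 7168, 2048):
        print(k, choose_kernel(k))
```

Under that reading, only the k = 16384 shape uses the double-buffered kernel. The k = 7168 and k = 2048 shapes both satisfy k % 1024 == 0 and run the non-pipelined variant.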
Overall, we improved on the naive PyTorch baseline, going from these benchmarks:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 82.8us
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 327.6us
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 131.5us
To our final kernel benchmarks:
{'k': 16384, 'm': 7168, 'l': 1} -> Mean: 23.6us (71.5% improvement)
{'k': 7168, 'm': 4096, 'l': 8} -> Mean: 35.9us (89.1% improvement)
{'k': 2048, 'm': 7168, 'l': 4} -> Mean: 15.4us (88.3% improvement)
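The improvement percentages can be recomputed directly from the mean latencies above. The sketch reports both the latency reduction and the equivalent speedup factor (baseline time divided by final time). summarize is a hypothetical helper.

```python
# Recompute the improvement figures from the mean latencies (microseconds).
BASELINE_US = {"k=16384,m=7168,l=1": 82.8,
               "k=7168,m=4096,l=8": 327.6,
               "k=2048,m=7168,l=4": 131.5}
FINAL_US = {"k=16384,m=7168,l=1": 23.6,
            "k=7168,m=4096,l=8": 35.9,
            "k=2048,m=7168,l=4": 15.4}


def summarize(before_us, after_us):
    """(latency reduction in %, speedup factor), rounded for display."""
    reduction_pct = 100.0 * (before_us - after_us) / before_us
    speedup = before_us / after_us
    return round(reduction_pct, 1), round(speedup, 2)


if __name__ == "__main__":
    for shape, before in BASELINE_US.items():
        print(shape, summarize(before, FINAL_US[shape]))
```

Expressed as speedups, the final kernel is roughly 3.5×, 9.1× and 8.5× faster than the baseline on the three shapes.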
GEMV Summary
Optimization Summary:
┌──────────────┬───────────────────────────────────────────┬──────────────────────────────────────────────────────────┐
│ Version      │ Technique                                 │ Primary Gain                                             │
├──────────────┼───────────────────────────────────────────┼──────────────────────────────────────────────────────────┤
│ v0           │ PyTorch baseline                          │ Correctness baseline                                     │
│ v1           │ CUDA kernel                               │ Single Kernel Launch                                     │
│ v2           │ CTA-level K parallelism + SMEM reduction  │ K-dim parallelism                                        │
│ v3           │ Warp shuffle reduction                    │ Faster intra-warp reduction                              │
│ v4           │ Warp output tiling (4 rows/warp)          │ B memory reuse 4×                                        │
│ v5           │ Contiguous warp row assignment            │ L1 cache locality                                        │
│ v6           │ GMEM→SMEM staging                         │ Coalesced loads                                          │
│ v6_2         │ Pre-convert B, pointer arith, no syncwarp │ Less redundant work                                      │
│ v6_2_x       │ SF conversion hoisting, pre-scaling       │ Fewer conversions in compute loop                        │
│ v7           │ Double-buffered SMEM                      │ Memory latency hiding                                    │
│ v7_2         │ cuda::pipeline async copies               │ Hardware async path                                      │
│ v7_3         │ 5-stage pipeline                          │ Deeper prefetch                                          │
│ submit_final │ TMA + PTX asm + dual dispatch             │ Dedicated HW units, optimal ILP, reduced arithmetic load │
└──────────────┴───────────────────────────────────────────┴──────────────────────────────────────────────────────────┘
Broad Lessons:
Examining the aspect and dimensional ratios of problem shapes can reveal which optimization techniques will be most effective. Here that meant looking at (M * L) / K. A small value of this ratio means each CTA faces a long, serialized walk along K, so optimizations that speed up processing along the k-dimension are the most beneficial. A large value means the number of results being computed dominates memory traffic, compute resources, and ultimately runtime, rather than the serial work along K. In that case, parallelism across separate result computations should be prioritized over k-dimension parallelism.
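For the three benchmark shapes, the (M * L) / K ratio works out as follows. work_ratio is a hypothetical helper.

```python
SHAPES = [
    {"k": 16384, "m": 7168, "l": 1},
    {"k": 7168, "m": 4096, "l": 8},
    {"k": 2048, "m": 7168, "l": 4},
]


def work_ratio(shape):
    """(M * L) / K: independent outputs relative to serial work along K."""
    return shape["m"] * shape["l"] / shape["k"]


if __name__ == "__main__":
    for shape in SHAPES:
        print(shape, round(work_ratio(shape), 2))   # 0.44, 4.57, 14.0
```

The first shape, with a ratio below 1, is the K-heavy one. The third, with a ratio of 14, is dominated by the number of independent outputs.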
Understanding the theoretical performance limiter for a given combination of operation, problem shape, and hardware can reveal the major bottlenecks of a working kernel, especially when the limiter in practice is not what theory predicts. For this matrix-vector operation, the theoretical limiter is memory traffic rather than arithmetic. Matrix-vector multiplication has low arithmetic intensity, meaning few compute operations per unit of memory transferred. Arithmetic intensity rises with the number of computations each loaded input value participates in, and in a GEMV each matrix element takes part in exactly one multiply-add. A low arithmetic intensity means memory operations (which usually take longer than compute operations) should dominate the runtime of a well-optimized kernel. Our kernels were instead weighed down by the arithmetic of the type conversions. Recognizing that mismatch earlier in the development process could have led me to the PTX conversion instructions sooner.
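A back-of-the-envelope calculation makes the low arithmetic intensity concrete. The sketch counts FLOPs (one multiply and one add per multiply-accumulate) against the minimum DRAM traffic. It assumes packed FP4 data at half a byte per element, one 1-byte FP8 scale per 16 elements, FP16 output, and that every byte is moved exactly once. gemv_intensity is a hypothetical helper.

```python
def gemv_intensity(m, k, l, sf_group=16):
    """Return (FLOPs per byte, total bytes) for an NVFP4 batched GEMV.

    ASSUMPTIONS: FP4 data at 0.5 B/element, one 1-byte FP8 scale per
    sf_group elements, FP16 output, and every byte read or written once.
    """
    flops = 2 * m * k * l                              # multiply + add per MAC
    a_bytes = l * (m * k // 2 + m * k // sf_group)     # A plus SFA
    b_bytes = l * (k // 2 + k // sf_group)             # B plus SFB
    c_bytes = l * m * 2                                # FP16 output
    total_bytes = a_bytes + b_bytes + c_bytes
    return flops / total_bytes, total_bytes


if __name__ == "__main__":
    for m, k, l in ((7168, 16384, 1), (4096, 7168, 8), (7168, 2048, 4)):
        intensity, total = gemv_intensity(m, k, l)
        print(f"m={m} k={k} l={l}: {total / 1e6:.1f} MB, {intensity:.2f} FLOP/byte")
```

For the first shape, this gives about 3.55 FLOPs per byte. That is just under the 32/9 limit set by the A matrix and its scale factors alone, so each byte loaded supports only a few arithmetic operations.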
Useful Links:
[1] https://medium.com/@snshyam/cuda-deep-dive-what-happens-when-you-launch-a-kernel-034e23624932