GEMM
In this section I cover the basics of using tcgen05 instructions to accelerate matrix multiplication on the NVIDIA Blackwell architecture. I won't cover iterative improvements, as I couldn't compete in this segment of the competition. However, further optimizations that apply to this basic GEMM (and more) are explored in the following two kernels, both of which are GEMM variants.
GEMM Operation
GEMM (GEneral Matrix Multiply) performs the operation D <- a * A @ B + b * C, where:
D := Resultant MxN matrix
a := Scalar value
A := Input matrix MxK
B := Input matrix KxN
b := Scalar value
C := Input MxN matrix
C and D can refer to the same matrix. The particular variant of matrix multiply we deal with in the following kernels is the simplified operation C <- A @ B, so there are no scalar values involved, and there is no summation onto an existing matrix.
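As a plain reference point, here is a minimal pure-Python sketch of the full GEMM definition above and of the simplified C <- A @ B variant (a = 1, b = 0). The function names `gemm` and `matmul` are illustrative and not taken from the kernels. A reference like this is mainly useful for checking a kernel's output on small inputs.

```python
def gemm(a, A, B, b, C):
    """Reference D = a * A @ B + b * C using plain nested lists.

    A is MxK, B is KxN, C is MxN; returns a new MxN list D.
    """
    M, K, N = len(A), len(B), len(B[0])
    assert all(len(row) == K for row in A), "inner dimensions must match"
    return [
        [a * sum(A[i][k] * B[k][j] for k in range(K)) + b * C[i][j] for j in range(N)]
        for i in range(M)
    ]


def matmul(A, B):
    """The simplified variant used in the following kernels: C <- A @ B (a = 1, b = 0)."""
    zeros = [[0] * len(B[0]) for _ in A]
    return gemm(1, A, B, 0, zeros)


print(matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```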
Pre-Tensor Core GEMM
Prior to the introduction of tensor cores in the Volta architecture, CUDA cores were used to perform matrix multiplication. Many of the broader optimization principles used in the pre-tensor core algorithms carry over to the tensor core algorithms, but many aspects are also distinct as a result of the programming paradigm that tensor cores introduce.
This blog provides great insight into writing optimal GEMM kernels without tensor cores: https://siboehm.com/articles/22/CUDA-MMM
Principles that will carry through to the tensor core versions:
Fully utilizing memory bandwidth
Maximizing temporal and spatial locality (thereby increasing arithmetic intensity)
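To make the locality point concrete, the sketch below computes the arithmetic intensity (FLOPs per byte of A/B traffic) of one output tile. It assumes each tile reads its slabs of A and B exactly once and ignores C and scale-factor traffic; the helper name is hypothetical.

```python
def tile_arithmetic_intensity(mB, nB, K, bytes_per_elem):
    """FLOPs per byte of A/B traffic for one mB x nB output tile of depth K.

    Each output element needs K multiply-adds (2 FLOPs each); the tile reads an
    mB x K slab of A and a K x nB slab of B exactly once. C traffic and scale
    factors are ignored to keep the arithmetic simple.
    """
    flops = 2 * mB * nB * K
    bytes_loaded = (mB * K + K * nB) * bytes_per_elem
    return flops / bytes_loaded


# One output element per thread with no reuse vs. a 128x128 output tile.
print(tile_arithmetic_intensity(1, 1, 4096, 4))        # 0.25 (fp32)
print(tile_arithmetic_intensity(128, 128, 4096, 4))    # 32.0 (fp32)
print(tile_arithmetic_intensity(128, 128, 4096, 0.5))  # 256.0 (nvfp4, scale factors ignored)
```

The ratio simplifies to 2*mB*nB / ((mB + nB) * bytes_per_elem), so it does not depend on K and grows with the tile size. This is why reusing larger tiles staged in fast memory raises arithmetic intensity.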
Tensor Cores
From a software perspective, the primary function of tensor cores is to multiply an MxK chunk of one matrix by a KxN chunk of another and store/accumulate the result into an MxN buffer. For a more detailed description of the evolution of tensor cores see this article by Modular [1]. Tensor cores can source input data from SMEM or TMEM (Tensor Memory), depending on the operand. Tensor Memory is a 256KB storage bank per SM dedicated to tensor core operations [2]. The details of how tensor core operations are managed, and how they interact with threads and warps, have changed across NVIDIA architectures since Volta; here we will be discussing the Blackwell architecture.
Blackwell Tensor Cores (tcgen05) and GEMM
Central to our kernel is the tcgen05.mma instruction, in particular the block-scaled floating-point version using nvfp4 values with e4m3 scale factors. There are two variants we can use; the only difference between them is where the A-block is stored (TMEM or SMEM). For this implementation we choose the SMEM variant:
tcgen05.mma.cta_group.kind.block_scale{.scale_vectorsize}
[d-tmem], a-desc, b-desc, idesc,
[scale-A-tmem], [scale-B-tmem], enable-input-d;
.kind = { .kind::mxf8f6f4, .kind::mxf4, .kind::mxf4nvf4 }
.cta_group = { .cta_group::1, .cta_group::2 }
.scale_vectorsize = { .scale_vec::1X, .scale_vec::2X, .scale_vec::4X, .block16, .block32 }
PTX Docs: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-mma-instructions
This instruction, issued by a single thread in a CTA, takes in four essential data arguments:
MxK nvfp4 data block in SMEM (call this A)
Mx(K // 16) e4m3 scale factor block in TMEM, 16 comes from 16 adjacent elements being scaled by one scale factor (call this SFA)
KxN nvfp4 data block in SMEM (call this B)
(K // 16)xN e4m3 scale factor block in TMEM (call this SFB)
Then, all within the same instruction dispatch, the nvfp4 data blocks (A and B) are scaled into higher-precision values by the scale factor matrices (SFA and SFB), and the resulting scaled matrices are multiplied and stored or accumulated into an MxN buffer in TMEM.
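The semantics of a single block-scaled MMA can be written out as a slow Python reference. This is a sketch under a few assumptions: each FP4 (E2M1) value is given as its own 4-bit code rather than packed two per byte, the E4M3 scale factors are non-negative (sign bit ignored, NaN encodings not handled), and products are summed in Python floats rather than in the hardware's accumulation order. The names are illustrative.

```python
# FP4 (E2M1) magnitudes indexed by the low 3 bits of a nibble; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
NVFP4_BLOCK = 16  # elements along K that share one scale factor


def decode_e2m1(nibble):
    mag = E2M1_MAGNITUDES[nibble & 0x7]
    return -mag if nibble & 0x8 else mag


def decode_e4m3_scale(byte):
    """Decode a non-negative E4M3 scale (4 exponent bits, bias 7, 3 mantissa bits)."""
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:  # subnormal range
        return (man / 8) * 2.0 ** -6
    return (1 + man / 8) * 2.0 ** (exp - 7)


def block_scaled_matmul(A, SFA, B, SFB):
    """C = (A scaled by SFA) @ (B scaled by SFB).

    A: M x K FP4 codes, SFA: M x (K // 16) scale bytes
    B: K x N FP4 codes, SFB: (K // 16) x N scale bytes
    """
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                a = decode_e2m1(A[i][k]) * decode_e4m3_scale(SFA[i][k // NVFP4_BLOCK])
                b = decode_e2m1(B[k][j]) * decode_e4m3_scale(SFB[k // NVFP4_BLOCK][j])
                acc += a * b
            C[i][j] = acc
    return C


# One row of sixteen 1.0 values (scale 1.0) times one column of 1.0 values (scale 2.0).
print(block_scaled_matmul([[0x2] * 16], [[0x38]], [[0x2] for _ in range(16)], [[0x40]]))  # [[32.0]]
```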
As can be seen from the instruction format of tcgen05.mma the inputs I list above don't line up with the arguments to that instruction precisely. Before diving into how we provide the necessary inputs via the arguments of the tcgen05.mma instruction it helps to go over the broader MMA algorithm first.
Let each mma instruction compute an mBxnB segment of C with a depth of kB (so each instruction computes the block-scaled matmul of an mBxkB segment of A with a kBxnB segment of B). The execution of a single one of these mma ops requires us to:
Load mBxkB block of A and kBxnB block of B into SMEM
Load mBx(kB/16) block of SFA and (kB/16)xnB block of SFB into TMEM
Execute mma and accumulate into resultant mBxnB region in TMEM
For a single CTA this is the algorithm flow:
Assume we have allocated SMEM with enough room for the following:
  A-block (mBxkB of nvfp4)
  B-block (kBxnB of nvfp4)
  SFA-block (mBx(kB/16) of e4m3)
  SFB-block ((kB/16)xnB of e4m3)
Allocate TMEM (need room for SFA-block, SFB-block, C-block)
For each k-block (i.e. for every kB sized block along the k-dimension):
  Load SFA-block, SFB-block into SMEM
  Load A-block, B-block into SMEM
  Copy SFA-block, SFB-block into TMEM
  Execute mma
Move C-block into CTA regs
Move C-block from CTA regs to GMEM
Deallocate TMEM
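The loop structure of this single-CTA flow can be simulated on the host. The sketch below is a hypothetical Python model, not kernel code: it skips the scale factors (stage 1) and uses plain numbers, so only the k-loop, the per-k-block loads, and the enable-input-d behaviour of the accumulator remain visible.

```python
def cta_tile_gemm(A, B, m0, n0, mB, nB, kB):
    """Simulate one CTA computing the mB x nB tile of C = A @ B whose corner is (m0, n0).

    Scale factors (stage 1) are left out so only the loop structure remains.
    """
    K = len(B)
    assert K % kB == 0, "K must be a multiple of kB"
    acc = [[0] * nB for _ in range(mB)]  # stands in for the C-block in TMEM
    for kb in range(K // kB):
        k0 = kb * kB
        # Stage 0: copy this k-block's A-block and B-block (GMEM -> SMEM)
        a_blk = [row[k0:k0 + kB] for row in A[m0:m0 + mB]]
        b_blk = [row[n0:n0 + nB] for row in B[k0:k0 + kB]]
        # Stage 2: one "mma"; enable-input-d is false only for the first k-block
        enable_input_d = kb > 0
        for i in range(mB):
            for j in range(nB):
                partial = sum(a_blk[i][k] * b_blk[k][j] for k in range(kB))
                acc[i][j] = acc[i][j] + partial if enable_input_d else partial
    # Stages 3-4: TMEM -> registers -> GMEM; here we simply return the tile
    return acc


# Two k-blocks of depth 2: (1 + 2) + (3 + 4) = 10
print(cta_tile_gemm([[1, 2, 3, 4]], [[1], [1], [1], [1]], 0, 0, 1, 1, 2))  # [[10]]
```

Comparing the returned tile against the corresponding slice of a full product is a quick way to check the tiling arithmetic.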
There are many improvements that can be made to this algorithm that are discussed in detail later on, but this simple outline is enough to highlight the core PTX instructions necessary and the intricacies involved therein.
Rewinding a few paragraphs, I highlighted four primary inputs to the tcgen05.mma instruction: A, B, SFA, SFB. We specify these inputs in the instruction itself as follows:
a-desc specifies A
b-desc specifies B
scale-A-tmem specifies SFA
scale-B-tmem specifies SFB
However, it isn't as simple as loading the data "as is" into SMEM/TMEM buffers and passing the instruction pointers to those buffers. The data has to be structured in particular formats according to how the instruction has been configured (this includes the size of the data blocks (mB, nB, kB), the swizzling mode (Swizzling Details), etc.). In my opinion this is the trickiest part of using the tcgen05.mma instructions. First, it's a seemingly unnecessary complexity: the formatting constraints are likely a hardware design artifact that we aren't shielded from when programming this far down the software stack (CuteDSL/CUTLASS/Triton et al. handle this kind of thing for you, albeit in exchange for less granular control and visibility into how the hardware is being programmed). Second, and more importantly, these formatting intricacies are very poorly explained in the PTX documentation.
This next section explains how this data formatting works for the block-scaled tcgen05.mma instruction.
Data formatting:
First let's see how to format the A and B blocks for a single MMA instruction:
We can see from the mma instruction format that the location and format of the A/B blocks are determined by the a-desc/b-desc arguments. These are "matrix descriptors", which describe how the tile is laid out in SMEM and have the following format:
PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-descriptors
[0-13] : encode( matrix start address )
[16-29] : encode( LBO ) (Leading Byte Offset)
[32-45] : encode( SBO ) (Stride Byte Offset)
[46-48] : 0b001 (fixed constant)
[49-51] : Matrix base offset
[52] : Leading dimension stride mode (0: relative byte offset; 1: absolute byte address, not discussed here)
[53-60] : 0b00000000 (fixed constant)
[61-63] : Swizzle (0. No swizzling 1. 128-Byte with 32B atomic swizzling 2. 128-Byte swizzling 4. 64-Byte swizzling 6. 32-Byte swizzling)
* encode(x) = (x & 0x3FFFF) >> 4 (keep the low 18 bits, then divide by 16, giving a 14-bit value)
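The field table above translates directly into bit packing. Below is a minimal Python sketch (the function and constant names are my own and hypothetical) that builds the 64-bit descriptor from byte quantities. In a real kernel `start_addr` would be a shared-state-space address, for example obtained with `__cvta_generic_to_shared`.

```python
SWIZZLE_NONE, SWIZZLE_128B_BASE32B, SWIZZLE_128B, SWIZZLE_64B, SWIZZLE_32B = 0, 1, 2, 4, 6


def encode(x):
    """Keep the low 18 bits of a byte quantity, then divide by 16 (14-bit result)."""
    return (x & 0x3FFFF) >> 4


def make_smem_desc(start_addr, lbo, sbo, base_offset=0, lbo_absolute=False,
                   swizzle=SWIZZLE_NONE):
    desc = encode(start_addr)            # bits 0-13: matrix start address
    desc |= encode(lbo) << 16            # bits 16-29: leading byte offset
    desc |= encode(sbo) << 32            # bits 32-45: stride byte offset
    desc |= 0b001 << 46                  # bits 46-48: fixed constant
    desc |= (base_offset & 0x7) << 49    # bits 49-51: matrix base offset
    desc |= int(lbo_absolute) << 52      # bit 52: stride mode
    # bits 53-60 stay zero (fixed constant)
    desc |= (swizzle & 0x7) << 61        # bits 61-63: swizzle mode
    return desc


print(hex(make_smem_desc(0x400, lbo=2048, sbo=128)))  # 0x400800800040
```

The LBO/SBO values in the usage line match the no-swizzle K-major 128 × 64 NVFP4 example worked through below.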
In depth explanation for "LBO" and "SBO":
Before diving into what each field means we have to discuss the concept of a "core matrix". As far as I can tell, this isn't explained anywhere in the official NVIDIA docs at the time of writing, and (likely for that reason) LLMs hallucinate incorrect explanations of it. It's something one has to infer from how WGMMA instructions worked on Hopper and from the requirements of the tcgen05 MMA instructions.
A core matrix is an 8xN byte tile, where N is the number of bytes in the swizzle mode (Swizzling Details); if the swizzle mode is "None", N is 16. The size-8 dimension is taken along the M/N-dimension, and the N-byte chunks are taken along the K-dimension. So if our matrix is stored in K-major format in GMEM and we want to load it into SMEM for the "None" swizzle mode (meaning no shuffling of data to avoid bank conflicts), each core matrix is an 8x16B block. For K-major we store columns of core matrices contiguously. So if our MMA instruction ingests a 128x64 NVFP4 tile of A (64 NVFP4 values = 32 bytes, so the tile breaks down into 16 rows x 2 columns of core matrices), in shared memory we store the 16 core matrices of the first column contiguously, followed by the 16 of the second column (with the 16B rows within each core matrix also being contiguous).
Core matrix diagram (no-swizzle case, N = 16 bytes):
A single core matrix is 8 rows × 16 bytes laid out contiguously:
       <----------- 16 bytes ----------->
       ┌────┬────┬────┬─ ... ─┬────┬────┐  ┐
row 0  │ b0 │ b1 │ b2 │       │b14 │b15 │  │
       ├────┼────┼────┼─ ... ─┼────┼────┤  │
row 1  │b16 │b17 │b18 │       │b30 │b31 │  │
       ├────┼────┼────┼─ ... ─┼────┼────┤  │
row 2  │b32 │             ...           │  8 rows
       ├────────────────────────────────┤  │
 ...   │              ...               │  │
       ├────────────────────────────────┤  │
row 7  │b112│         ...          │b127│  │
       └────┴────┴────┴─ ... ─┴────┴────┘  ┘
For a K-major 128 × 64 NVFP4 MMA tile (32 bytes wide in K = 2 core matrix columns; 128 rows = 16 core matrix rows), the SMEM layout is:
  column 0 of CMs (K=[0,32))       column 1 of CMs (K=[32,64))
  ┌──────────┐                     ┌──────────┐
  │ CM[0,0]  │ ← address X         │ CM[0,1]  │ ← address X + LBO
  ├──────────┤                     ├──────────┤
  │ CM[1,0]  │ ← X + SBO           │ CM[1,1]  │
  ├──────────┤                     ├──────────┤
  │ CM[2,0]  │ ← X + 2*SBO         │ CM[2,1]  │
  ├──────────┤                     ├──────────┤
  │   ...    │                     │   ...    │
  ├──────────┤                     ├──────────┤
  │ CM[15,0] │                     │ CM[15,1] │
  └──────────┘                     └──────────┘
  contiguous in SMEM               then the next column
SBO = 8 × 16 = 128 B (stride to the next CM in the same column)
LBO = MN_DIM × 16 = 128 × 16 = 2048 B (stride to the next column of CMs)
The matrix descriptor's "matrix start address" points at CM[0,0], and LBO/SBO describe how to walk the rest of the tile.
This blog by Modular has an explanation that is pretty good, and where I sourced a large part of my understanding from: https://www.modular.com/blog/matrix-multiplication-on-nvidias-blackwell-part-2-using-hardware-features-to-optimize-matmul
This table (38) in the PTX docs describes the structure of a core matrix based on leading dimension and swizzle mode: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-swizzle-lead-dim This table then describes canonical layouts (essentially saying the same thing), but it's a bit convoluted by the CuTe layouts: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-canonical-layouts
Now LBO and SBO are much easier to understand. LBO is the number of bytes needed to reach the next column of core matrices (in our example above that would be 128*16B), hence the name "Leading", which typically refers to the stride needed to reach the start of another dimension. SBO is the number of bytes needed to reach the next core matrix within a column, which would be 8*N bytes, where N is the number of bytes in the swizzle mode (again, "None" corresponds to 16).
"matrix start address" - The address of the start of the tile we are using for the MMA
"LBO" - Explained above
"SBO" - Explained above; together, LBO and SBO describe the structure of the tile in SMEM in a way the hardware can ingest (i.e. in terms of core matrices)
"Matrix base offset" - 0 unless using a swizzling pattern that doesn't start on a "canonical" boundary (I'm not exactly sure of the details of how/why to use this)
"Swizzle" - Which swizzle mode to use
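The core-matrix layout can also be expressed as an offset function. This sketch assumes the K-major, no-swizzle case described above (8 × 16 B core matrices, SBO = 128 B, LBO = MN_DIM × 16 B). The helper name is hypothetical, and swizzled modes (which use wider core matrices and reorder data) are not modeled.

```python
CM_ROWS, CM_BYTES = 8, 16  # no-swizzle core matrix: 8 rows x 16 bytes


def kmajor_noswizzle_offset(row, kbyte, mn_dim):
    """Byte offset of (row, kbyte) inside a K-major, no-swizzle SMEM tile.

    row   : index along M (or N for B), 0 <= row < mn_dim
    kbyte : byte index along K (NVFP4 element k lives in byte k // 2)
    """
    sbo = CM_ROWS * CM_BYTES   # 128 B: next core matrix down the same column
    lbo = mn_dim * CM_BYTES    # next column of core matrices
    cm_row, r = divmod(row, CM_ROWS)
    cm_col, b = divmod(kbyte, CM_BYTES)
    return cm_col * lbo + cm_row * sbo + r * CM_BYTES + b


print(kmajor_noswizzle_offset(8, 0, 128))     # 128  (one SBO)
print(kmajor_noswizzle_offset(0, 16, 128))    # 2048 (one LBO)
print(kmajor_noswizzle_offset(127, 31, 128))  # 4095 (last byte of the 4 KiB tile)
```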
Now let's look at how to store the tiles of SFA and SFB, which are used to scale the tiles of A and B. The scale factor tiles must reside in TMEM to be used by the mma instruction. TMEM is structured as 128 rows (called lanes) and 512 columns, where each lane/column element is a 32-bit (4B) unit (TMEM Details). The layout the mma instruction expects depends on the dimensions of the mma operation, but generally for 1B scale factors the m/n dimension is split into chunks of 32 rows that are placed next to each other. For example, if the SFA tile is 128x4 (corresponding to a 128x64 NVFP4 A tile), there would be four 32x4 chunks laid out next to each other in TMEM, occupying 32 lanes and 4 columns (each 32-bit column holds the 4 scale factors of one row). In the language of CuTe layouts this can be viewed as (32x4)x4. This 32-lane block is then repeated 4 times to occupy all 128 lanes. This will become clearer further down when we discuss copying the scale factor data from SMEM to TMEM. I suspect the reason for the repetition is that the hardware needs access to the same data from all four 32-lane partitions of a column. Similar to core matrices, this is another hardware complexity that isn't abstracted away at this level of the software stack.
The diagrams in the PTX docs at https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-block-scaling (specifically Figure 233) also demonstrate what I describe here.
SFA layout in TMEM for a 128 × 64 NVFP4 A tile (M = 128, 4 SFs per row):
               col 0     col 1     col 2     col 3
            ┌─────────┬─────────┬─────────┬─────────┐
lane 0      │ SF(0)   │ SF(32)  │ SF(64)  │ SF(96)  │
lane 1      │ SF(1)   │ SF(33)  │ SF(65)  │ SF(97)  │
  ...       │   ...   │   ...   │   ...   │   ...   │
lane 31     │ SF(31)  │ SF(63)  │ SF(95)  │ SF(127) │
            ├─────────┴─────────┴─────────┴─────────┤
lanes 32-63 │ copy of lanes 0-31                    │
lanes 64-95 │ copy of lanes 0-31                    │
lanes 96-127│ copy of lanes 0-31                    │
            └───────────────────────────────────────┘
SF(r) = the 4 e4m3 scale factors of row r of A (one per byte of the 32-bit slot).
CuTe-style: (32 × 4) × 4, i.e. one 32-lane × 4-column block holding all 128 × 4
scale factors (column j holds rows 32j to 32j+31 of A), replicated across the
four 32-lane quadrants of TMEM. The replication is the multicast performed by
`tcgen05.cp ... .warpx4` on the SMEM→TMEM copy.
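Assuming the layout described above, the mapping from a scale factor to its TMEM position, and to its byte in the SMEM source of `tcgen05.cp`, can be written as follows. Two details are assumptions to verify against the PTX figure: that the 512 bytes sit contiguously in SMEM at 16 bytes per lane, and that k-block 0 occupies the lowest byte of each 32-bit column. The helper names are hypothetical.

```python
def sfa_tmem_location(row, kblock):
    """(lane, column, byte) of SFA[row][kblock] for a 128-row tile with 4 SFs per row.

    Lanes 32-63, 64-95 and 96-127 hold copies of lanes 0-31 (the .warpx4 multicast),
    so the lane returned here is the one in the first 32-lane quadrant.
    """
    assert 0 <= row < 128 and 0 <= kblock < 4
    return row % 32, row // 32, kblock


def sfa_smem_offset(row, kblock):
    """Byte offset in the assumed contiguous 512-byte SMEM block (16 B per lane)."""
    lane, column, byte = sfa_tmem_location(row, kblock)
    return lane * 16 + column * 4 + byte


print(sfa_tmem_location(33, 2), sfa_smem_offset(33, 2))  # (1, 1, 2) 22
```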
Next we look at how the result of the mma op is stored in TMEM. There are a number of different formats depending on the configuration of the executed mma instruction; the details of those formats are outlined in this table:
https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-organization
These layouts describe how the output data is stored or accumulated into the TMEM buffer passed to the mma instruction via the d-tmem argument.
Many options in that table likely don't make much sense at the moment; however, the most important options are M and cta_group. M is just the size of M being computed by the mma instruction (which computes an MxN block from the MxK and KxN chunks from A and B in SMEM). Refer to CTA_GROUP Details for more info on this hardware feature which will be used in later kernels. For now we use cta_group == 1 which leaves us with Layout D:
The canonical NVIDIA figure for this layout is at https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d
TMEM Layout D for an M × N output tile (M = 128, N = 64 example):
             col 0     col 1     col 2    ...    col 63
          ┌─────────┬─────────┬─────────┬─────┬──────────┐
lane 0    │ C[0,0]  │ C[0,1]  │ C[0,2]  │ ... │ C[0,63]  │ ┐
lane 1    │ C[1,0]  │ C[1,1]  │ C[1,2]  │ ... │ C[1,63]  │ │ warp 0
  ...     │   ...   │   ...   │   ...   │ ... │   ...    │ │ owns rows 0-31
lane 31   │ C[31,0] │ C[31,1] │ C[31,2] │ ... │ C[31,63] │ ┘
          ├─────────┼─────────┼─────────┼─────┼──────────┤
lane 32   │ C[32,0] │ C[32,1] │ C[32,2] │ ... │ C[32,63] │ ┐ warp 1
  ...     │   ...   │   ...   │   ...   │ ... │   ...    │ │ owns rows 32-63
lane 63   │ C[63,0] │ C[63,1] │ C[63,2] │ ... │ C[63,63] │ ┘
          ├─────────┼─────────┼─────────┼─────┼──────────┤
lane 64   │ C[64,0] │ C[64,1] │ C[64,2] │ ... │ C[64,63] │ ┐ warp 2
  ...     │   ...   │   ...   │   ...   │ ... │   ...    │ │ owns rows 64-95
lane 95   │ C[95,0] │ C[95,1] │ C[95,2] │ ... │ C[95,63] │ ┘
          ├─────────┼─────────┼─────────┼─────┼──────────┤
lane 96   │ C[96,0] │ C[96,1] │ C[96,2] │ ... │ C[96,63] │ ┐ warp 3
  ...     │   ...   │   ...   │   ...   │ ... │   ...    │ │ owns rows 96-127
lane 127  │ C[127,0]│ C[127,1]│ C[127,2]│ ... │ C[127,63]│ ┘
          └─────────┴─────────┴─────────┴─────┴──────────┘
Each TMEM cell is a full 32-bit slot. With FP32 accumulation each
cell holds one C[i, j] result. With smaller output dtypes the high
bits go unused. A single warp can only address 32 contiguous lanes
of TMEM via tcgen05.ld, so the 4 epilogue warps must each issue
loads against their respective 32-lane band to cover M = 128.
In Layout D each lane of the TMEM buffer holds a row of the output and each column holds a column of the output (the data is unpacked, so if the output data type of the MMA is narrower than 32 bits, the unused upper bits of each column are simply ignored). So if our mma instruction computes an MxN block of output called C, TMEM[i][j] = C[i][j]; our row-major MxN output tile is stored in M rows (lanes) across N columns of TMEM. The caveat, demonstrated in the linked diagram, is that a single warp can only access 32 lanes of TMEM, which means you need 4 warps to access an M=128 output tile. This last point is also likely a hardware design artifact we aren't shielded from, as I can't find a logical reason for requiring 4 warps to access all 128 lanes of TMEM.
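A small sketch of how the epilogue warps could locate their slice of the accumulator. It relies on the TMEM address format (lane index in bits 31-16, column index in bits 15-0) and assumes the base address returned by tcgen05.alloc has lane 0. The helper names are hypothetical.

```python
def tmem_addr(lane, column):
    """TMEM addresses put the lane in bits 31-16 and the column in bits 15-0."""
    return (lane << 16) | column


def epilogue_warp_slice(tmem_base, warp_id):
    """Base TMEM address and rows of C a warp may read with tcgen05.ld (Layout D, M = 128).

    A warp can only access the 32-lane quadrant selected by warp_id % 4.
    """
    lane0 = 32 * (warp_id % 4)
    return tmem_base + tmem_addr(lane0, 0), range(lane0, lane0 + 32)


base, rows = epilogue_warp_slice(tmem_addr(0, 64), warp_id=2)
print(hex(base), rows)  # 0x400040 range(64, 96)
```

Thread t of that warp then holds row lane0 + t of C for whichever columns it loads with tcgen05.ld.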
In summary we now understand how the input data (A, B, SFA, SFB) is stored and provided to the mma instruction as well as how the output data is produced in the provided TMEM buffer (and how to properly access that data). The next step is to configure the mma instruction itself so it knows the dimensions of the input and output blocks, the data types, etc...
This is all configured through the instruction descriptor (i-desc). Official PTX docs on this descriptor field can be found here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-instruction-descriptor
We will look at the specific format for NVFP4 (all type formats can be seen at that link). For NVFP4:
The important fields are...
[2] : Sparsity (Dense = 0, Sparse = 1)
[4-5] : Matrix B Scale Factor Data ID
[13] : Negate A Matrix
[14] : Negate B Matrix
[15] : Transpose A Matrix
[16] : Transpose B Matrix
[17-22] : N, Dimension of Matrix B (3 LSBs not included)
[23] : Scale Matrix Type, for both scale_A / scale_B (UE4M3 = 0)
[27-28] : M, Dimension of Matrix A (7 LSBs not included)
[29-30] : Matrix A Scale Factor Data ID
[31] : K Dimension (0 for Dense K=64 / Sparse K=128)
"Sparsity" - MMAs can operate on sparse or dense matrix formats; this is the selector (0 = Dense)
"Matrix B Scale Factor Data ID" - Specifies the byte offset within each TMEM column where the scale factors are located to form the SF matrix. For example, if you are working with a format that uses 1 SF per row, the 32-row chunks are separated across separate TMEM columns (spaced apart by 4 "byte sized" columns).
"Negate A Matrix" - Make A matrix elements negative (only supported for certain configs)
"Negate B Matrix" - Make B matrix elements negative (only supported for certain configs)
"Transpose A Matrix" - Transpose A then multiply (only supported for certain configs)
"Transpose B Matrix" - Transpose B then multiply (only supported for certain configs)
"N Dimension" - N in the KxN tile used for the MMA (only accepts certain values depending on data type and MMA format)
"Scale Matrix Type" - UE8M0 = 1, UE4M3 = 0
"M Dimension" - M in the MxK tile used for the MMA (only accepts certain values depending on data type and MMA format)
"Matrix A Scale Factor Data ID" - Same as for B, but for the A scale factor matrix
"K Dimension" - K in the MxK and KxN tiles for the MMA (the exact value is determined by the other configuration inputs to the MMA)
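The fields listed above can be packed with a helper like the one below. It is a sketch that covers only those fields. The A/B data-type fields that the PTX table also defines are deliberately left out and would have to be OR-ed in from the official table. The function name and defaults are hypothetical.

```python
def make_idesc_nvfp4(m, n, sfa_id=0, sfb_id=0, sparse=False, negate_a=False,
                     negate_b=False, transpose_a=False, transpose_b=False,
                     scale_ue8m0=False, k_bit=0):
    """Pack the instruction-descriptor fields listed above (data-type fields omitted)."""
    assert m % 128 == 0 and (m >> 7) < 4, "M is stored without its 7 LSBs in 2 bits"
    assert n % 8 == 0 and (n >> 3) < 64, "N is stored without its 3 LSBs in 6 bits"
    d = int(sparse) << 2            # bit 2: sparsity
    d |= (sfb_id & 0x3) << 4        # bits 4-5: B scale factor data ID
    d |= int(negate_a) << 13        # bit 13
    d |= int(negate_b) << 14        # bit 14
    d |= int(transpose_a) << 15     # bit 15
    d |= int(transpose_b) << 16     # bit 16
    d |= (n >> 3) << 17             # bits 17-22: N >> 3
    d |= int(scale_ue8m0) << 23     # bit 23: 0 = UE4M3, 1 = UE8M0
    d |= (m >> 7) << 27             # bits 27-28: M >> 7
    d |= (sfa_id & 0x3) << 29       # bits 29-30: A scale factor data ID
    d |= (k_bit & 0x1) << 31        # bit 31: K dimension selector
    return d


print(hex(make_idesc_nvfp4(128, 64)))  # 0x8100000
```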
The last two fields of the mma instruction that we haven't touched on yet are scale_vectorsize and enable-input-d. enable-input-d is a predicate register holding a boolean value: false means the existing contents of the d-tmem region are not used and are overwritten (no accumulation), and true means the result is accumulated onto the existing data in d-tmem. In other words, we want enable-input-d to be false on the first k-iteration and true for all later iterations.
scale_vectorsize determines how many scale factors are used per row/column of the input tiles. The block-scaling diagram linked below shows how this parameter dictates the size of K for our input blocks. For all of our kernels we use the maximum number of scale factors supported by the hardware, which is 4 (.scale_vec::4X or .block16; these are aliases). 4 scale factors per row/column of the A/B tiles results in K=64 for each of our tiles, because NVFP4 uses blocks of 16 elements per scale factor: 4 * 16 = 64 elements along the k-dimension.
To better understand the constraints on what size tiles can be computed we can look to this table: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-kind-shapes which shows us that for dense NVFP4 the MMA tile for N-dimension is limited to multiples of 8 and M is limited to 128. These can change slightly using something called 2SM or CTA pairs, which will be discussed in future kernels.
[Diagram at https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-block-scaling]
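The shape constraints and the K arithmetic above can be encoded as a quick host-side check. The upper bound of 256 on N is an assumption to confirm against the PTX shape table, and the helper names are hypothetical.

```python
NVFP4_BLOCK = 16  # NVFP4 elements per UE4M3 scale factor


def mma_k(scale_factors_per_row):
    """K covered by one MMA = (scale factors per row) x (elements per scale factor)."""
    return scale_factors_per_row * NVFP4_BLOCK


def check_nvfp4_tile(m, n, k, n_max=256):
    """Validate a dense cta_group::1 NVFP4 MMA shape (n_max=256 is an assumption)."""
    if m != 128:
        raise ValueError("M must be 128 for dense NVFP4 with cta_group::1")
    if n % 8 or not 8 <= n <= n_max:
        raise ValueError("N must be a multiple of 8 in [8, n_max]")
    if k != mma_k(4):
        raise ValueError("K must be 64 with .scale_vec::4X / .block16")
    return m, n, k


def num_k_iterations(K, kB=64):
    """Number of k-loop iterations for a problem of depth K."""
    assert K % kB == 0, "pad K to a multiple of the MMA depth"
    return K // kB


print(mma_k(4), check_nvfp4_tile(128, 64, 64), num_k_iterations(7168))  # 64 (128, 64, 64) 112
```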
At this point we understand the basic inputs and outputs of the tcgen05.mma instruction as well as how to configure the instruction to perform the desired multiplication.
Next we can discuss how data will be moved. Specifically, how input data will be moved from GMEM->SMEM, SMEM->TMEM (for scale factors), and how output data will be moved from TMEM->Regs->GMEM. First I list the instructions used (along with their syntax) as well as a broad description of their functionality. In the code walkthrough section further down I go over exactly how these instructions are configured and used in the code.
GMEM->SMEM:
For this we use TMA (Tensor Memory Accelerator), which was briefly mentioned in the last few iterations of the GEMV kernel. TMA is really a name for the family of cp.async.bulk instructions in PTX. For our implementation we want to copy a multi-dimensional tile, so we use the cp.async.bulk.tensor instructions; plain cp.async.bulk instructions only support contiguous 1D transfers, not multi-dimensional strided transfers. Specifically, the version below allows us to copy from GMEM to DSMEM (the Distributed Shared Memory of a CTA cluster). Copying to DSMEM enables what is referred to as "TMA Multicast", which copies one tensor segment from GMEM to the SMEM of multiple CTAs within a CTA cluster (DSMEM Details):
// global -> shared::cluster
cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism{.multicast}{.cta_group}{.level::cache_hint}
[dstMem], [tensorMap, tensorCoords], [mbar]{, im2colInfo}
{, ctaMask} {, cache-policy}
.dst = { .shared::cluster }
.src = { .global }
.dim = { .1d, .2d, .3d, .4d, .5d }
.completion_mechanism = { .mbarrier::complete_tx::bytes }
.cta_group = { .cta_group::1, .cta_group::2 }
.load_mode = { .tile, .tile::gather4, .im2col, .im2col::w, .im2col::w::128 }
.level::cache_hint = { .L2::cache_hint }
.multicast = { .multicast::cluster }
We also need the variant which loads to a single CTA:
// global -> shared::cta
cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism{.cta_group}{.level::cache_hint}
[dstMem], [tensorMap, tensorCoords], [mbar]{, im2colInfo} {, cache-policy}
.dst = { .shared::cta }
.src = { .global }
.dim = { .1d, .2d, .3d, .4d, .5d }
.completion_mechanism = { .mbarrier::complete_tx::bytes }
.cta_group = { .cta_group::1, .cta_group::2 }
.load_mode = { .tile, .tile::gather4, .im2col, .im2col::w, .im2col::w::128 }
.level::cache_hint = { .L2::cache_hint }
These instructions allow us to copy data from GMEM to SMEM in the format required. In the case of the NVFP4 data in A and B that format is columns of core matrices of appropriate size, and in the case of the scale factors it's the (32x4)x4 layout required by TMEM. Further details on how TMA works are available in the [TMA Details] section, and how it's used in the context of these transfers and data structuring is available in the code walkthrough section further down.
SMEM->TMEM:
As discussed, the mma instruction requires the scale factor data to be resident in TMEM. Thus we must copy the scale factor data from GMEM to SMEM, then from SMEM to TMEM. There is no way to copy data directly from GMEM into TMEM. Assuming the scale factor data is already present in SMEM we can use the tcgen05.cp instruction to copy it into TMEM.
tcgen05.cp.cta_group.shape{.multicast}{.dst_fmt.src_fmt} [taddr], s-desc;
PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-cp
TMEM->Regs:
Once the results have been computed and accumulated in TMEM across the k-dimension and we are ready to store the results back to GMEM we must first move the results out of TMEM and into the register file. There is currently no way to move blocks of data directly from TMEM into SMEM or GMEM, the data must first be moved to registers and then any method of data movement out of the registers can be used. The instruction to move data from TMEM -> Registers is the tcgen05.ld instruction.
tcgen05.ld.sync.aligned.shape1.num{.pack}.b32 r, [taddr];
PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld
I cover the details about both tcgen05.cp and tcgen05.ld in the code walkthrough. What's important to understand now is the broad function of these operations.
Regs->GMEM:
Finally we can use many methods to transfer the data from registers to GMEM. This includes simple coalesced stores, or we could store to SMEM first and initiate a TMA (cp.async.bulk) from SMEM -> GMEM. In this simple GEMM example we just use simple stores. In future kernels we will look at other methods and their potential benefits.
At this point we have a good idea of how data is moved and processed using tcgen05 instructions and tensor cores / tensor memory. Now we can revisit the design of our GEMM kernel with this information in mind:
For a single CTA this is the algorithm flow:
Assume we have allocated SMEM with enough room for the following:
  A-block (mBxkB of nvfp4)
  B-block (kBxnB of nvfp4)
  SFA-block (mBx(kB/16) of e4m3)
  SFB-block ((kB/16)xnB of e4m3)
Allocate TMEM (need room for SFA-block, SFB-block, C-block)
For each k-block (i.e. for every kB sized block along the k-dimension):
  Load SFA-block, SFB-block into SMEM
  Load A-block, B-block into SMEM
  Copy SFA-block, SFB-block into TMEM
  Execute mma
Move C-block into CTA regs
Move C-block from CTA regs to GMEM
Deallocate TMEM
We can think of this algorithm in terms of stages:
Stage  | 0          | 1          | 2       | 3          | 4
Action | GMEM->SMEM | SMEM->TMEM | Compute | TMEM->Regs | Regs->GMEM
Stages 0-2 execute in a loop for each k-iteration (each k-iteration covers a kB-sized block along K; this is the k-loop). CUTLASS refers to this as the "mainloop". Stages 3-4 execute after the k-loop, once the CTA has fully computed the result tile. CUTLASS refers to this as the "epilogue".
For an example setup of how to use CUTLASS for a GEMM operation like this see submission_cutlass.py.
If we were to program this algorithm using just the tcgen05 instructions we've discussed thus far it would very likely fail for various reasons including non-deterministic correctness failures. It may also occasionally produce correct results. My operating systems professor would call these kinds of bugs "Heisenbugs" after Werner Heisenberg's uncertainty principle in quantum mechanics. Race conditions, for example, are a subset of these kinds of non-deterministic bugs. Like in systems programming, and unlike quantum mechanics (luckily for us :), these "Heisenbugs" are very often due to a failure to design software in a way that guarantees data is only consumed or overwritten when it's appropriate based on whatever guarantees the hardware has granted to the software. Without correct software design the output of the program can be impacted by operation timing and external processes.
In our case we have to ensure that no two stages are writing to and reading from the same memory (GMEM, SMEM, TMEM, Regs) at the same time. Specifically we need to create the following guarantee using software synchronization tools:
stage n completes before stage n+1 begins for n = 0 ... 3
stage 2 completes before stage 0 of the next k-iteration begins (the next iteration's loads overwrite the same SMEM buffers that the current MMA reads)
From the above few lines we can see a "producer consumer" structure appear in that stage n produces data that is consumed by stage n + i where i > 0. This producer consumer structure is referred to heavily in many CUDA libraries because it's a very helpful structure to maintain synchronization guarantees in software as kernel algorithms become more complex.
For a more comprehensive overview of memory consistency and synchronization on NVIDIA GPUs see this section: [Memory Consistency and Synchronization Details] as well as the official documentation here: https://docs.nvidia.com/cuda/parallel-thread-execution/#memory-consistency-model. In the below description of how to create the guarantees mentioned above I'll just discuss the synchronization tools necessary and how they function instead of their more intricate and theoretical role in the overall memory consistency philosophy in NVIDIA GPU architecture.
Stage 0 (Produces GMEM data tiles in SMEM for the current k-iteration) -> Stage 1 (Consumes SMEM scale factor tiles): Stage 0 uses TMA to transfer data from GMEM to SMEM, so we need to ensure the transfers have completed and the data is visible in SMEM to all threads participating in stage 1. Luckily there is a synchronization primitive that notifies threads when transfers or operations have completed and works seamlessly with tcgen05 and cp.async.bulk operations. This primitive is called an mbarrier (short for memory barrier), also sometimes referred to in shorthand as an mbar object. In the context of cp.async.bulk instructions, the mbarrier primitive gives software the ability to both track and wait on the completion of data transfers. In particular, we can program an mbarrier object to expect x number of bytes, then pass that mbarrier to all of our TMA operations (note the mbar argument to the cp.async.bulk.tensor operations above). Each TMA operation then notifies that mbarrier of how many bytes were transferred once its transfer completes. I go over mbarriers in more detail in the glossary.
mbarriers also provide the ability to wait on them until the number of expected bytes have arrived. This is key to ensuring the threads in a CTA don't start stage 1 until the TMAs from stage 0 have completed. We use this to force our kernel to halt execution (yielding to other warps on the SM that can run) until the data from GMEM is available in SMEM.
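When one mbarrier tracks all four TMA loads of a k-iteration, the byte count it is programmed to expect (for example with mbarrier.arrive.expect_tx) is simply the sum of the tile sizes. The sketch below assumes unpadded tiles; if your scale-factor tiles are padded (for example to 128 rows), use the padded sizes instead. The helper name is hypothetical.

```python
def stage_tx_bytes(mB, nB, kB, fp4_bits=4, sf_bytes=1, sf_block=16):
    """Bytes the four TMA loads of one k-iteration deliver to a shared mbarrier."""
    a_bytes = mB * kB * fp4_bits // 8            # A-block, two nvfp4 values per byte
    b_bytes = nB * kB * fp4_bits // 8            # B-block
    sfa_bytes = mB * (kB // sf_block) * sf_bytes  # one e4m3 per 16 elements of a row
    sfb_bytes = nB * (kB // sf_block) * sf_bytes
    return a_bytes + b_bytes + sfa_bytes + sfb_bytes


print(stage_tx_bytes(128, 64, 64))  # 6912 = 4096 + 2048 + 512 + 256
```

If the expected count is larger than what the loads actually deliver, the waiting threads never proceed; if it is smaller, they may proceed before all data has landed. Computing it from the same tile parameters the loads use avoids both.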
Stage 1 (Produces scale factors in TMEM) -> Stage 2 (Consumes TMEM scale factors for tcgen05.mma operation): tcgen05.cp instructions are guaranteed to complete in program order before tcgen05.mma operations according to the tcgen05 pipelining rules here: <PTX Link> which means we actually don't need any explicit synchronization mechanism here.
After the k-loop has completed and the CTA has computed the final result in the result TMEM buffer...
Stage 2 (Produces MMA Result for this CTA in TMEM) -> Stage 3 (Consumes TMEM result to copy into registers): In order to ensure the last MMA has completed before we copy the result from TMEM to Regs we can again use an mbarrier. In addition to tracking if x number of bytes have been successfully transferred into SMEM, mbarriers can also track the completion of tcgen05 operations like tcgen05.mma. The tcgen05.commit instruction programs the given mbarrier to track the completion of all prior tcgen05.mma/cp/shift instructions in program order. Similar to what we did with stage 0 and 1 we can have stage 3 wait on this mbarrier which will force it to wait until all tcgen05.mma operations from stage 2 have completed, thus ensuring the data is ready in TMEM before stage 3 copies it into registers using tcgen05.ld.
In addition to the mbarrier we also need something called a fence. PTX has a number of fence types including memory and tcgen05. I discuss the different types of fences in more detail in the glossary. For our purposes here we need a tcgen05 fence which is used to ensure ordering of asynchronous tcgen05 instructions. Specifically, we need the tcgen05.fence::after_thread_sync which does the following per the PTX docs:
An asynchronous tcgen05 operation subsequent to a tcgen05.fence::after_thread_sync is ordered after all the prior tcgen05 and the execution ordering operations.
Because the mbarrier wait is one of those execution ordering operations, placing this fence after the wait means the tcgen05.ld instructions used by stage 3 to load the registers from TMEM can't be reordered ahead of the wait, and so can't execute before the tcgen05.mma operations tracked by the tcgen05.commit from stage 2 have completed. This is important because asynchronous tcgen05 operations may execute and complete in a different order than they were issued.
Stage 3 (Produces results in registers) -> Stage 4 (Consumes results in registers to store to GMEM): We need to ensure the tcgen05.ld instructions have completed before we store the results from the registers to GMEM. PTX lets us wait for the completion of tcgen05.ld operations (and others) via tcgen05.wait. Specifically, we need the tcgen05.wait::ld variant, which waits until all prior tcgen05.ld operations issued by the executing thread have completed.
Lastly, due to how standard GMEM stores are ordered, we know these will complete before meaningful work in the next k-iteration starts.
At this point we've covered all concepts and broad functionality needed to create a basic GEMM kernel using tcgen05 operations (leveraging tensor cores on the Blackwell architecture). Before we dive into a walkthrough of the code to cover some of the finer-grained details, we should discuss a few optimizations that are very powerful and don't change our code all that much.
Basic GEMM Optimizations
Our current kernel design is entirely sequential per CTA. In other words, every stage (referring to the stages of data movement and computation above) both within and across k-iterations happens sequentially. As we've established, the stages must progress sequentially within a single k-iteration to ensure correctness; however, there's no reason we can't overlap work across different k-iterations if each iteration uses its own buffers in SMEM/TMEM.
For example, let's say we use two sets of SMEM buffers (so there will be 8 buffers total per CTA: 2 sets of 4 buffers each for A, B, SFA, and SFB). Now instead of stage 0 of k-iteration 1 having to wait for stage 2 of k-iteration 0 to complete, it can proceed right away with loading the second set of buffers. Likewise, stage 0 of k-iteration 2 can proceed once stage 2 of k-iteration 0 has completed instead of waiting on k-iteration 1. This technique of overlapping different stages of computation is called pipelining. Since we've already used pipelining in the context of hardware and PTX instructions, I'll refer to this kind of pipelining as software pipelining. The software pipeline described has 2 stages, but it can be extended to an N-stage pipeline, and we can design our kernel to take N as a parameter. Below are a few diagrams demonstrating this kind of software pipelining.
Software pipelining diagram (N = 2-stage, sequential vs pipelined):
Sequential (single SMEM buffer set, 1 k-iter at a time):
k-iter 0: [S0][S1][S2]
k-iter 1:             [S0][S1][S2]
k-iter 2:                         [S0][S1][S2]
(S0 = Stage 0: GMEM→SMEM, S1 = Stage 1: SMEM→TMEM, S2 = Stage 2: MMA)
2-stage pipeline (two SMEM buffer sets, alternating):
buf 0: [Load k-iter 0][MMA  k-iter 0][Load k-iter 2][MMA  k-iter 2]
buf 1:                [Load k-iter 1][MMA  k-iter 1]
                      └── overlap ──┘└── overlap ──┘
N-stage pipeline (N sets of SMEM buffers): the producer (TMA warp) can race ahead by up to N k-iters before having to wait on the consumer (MMA warp), giving deeper overlap at the cost of N× SMEM footprint.
┌────────────┐      mbar_tma[s]      ┌────────────┐
│  TMA warp  │ ──── signals on ────► │  MMA warp  │
│ (producer) │                       │ (consumer) │
│            │ ◄─── signals on ───── │            │
└────────────┘      mbar_mma[s]      └────────────┘
 fills SMEM[s]                        consumes SMEM[s]
 (waits if MMA hasn't released)       (waits if TMA hasn't filled)
Two mbarriers per stage are required so neither side overruns the other once the pipeline is full.
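To put rough numbers on the diagrams, here's a small host-side C++ model of the schedule. It's an idealized sketch under loud assumptions, not the kernel: every load takes load_t time units, every MMA step takes mma_t, loads are issued back to back by a single producer, and a load into slot k % N can't start until the MMA of k-iteration k - N has released that slot. pipeline_makespan is a hypothetical helper name.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Idealized timing model of an N-stage software pipeline (hypothetical helper,
// not part of the kernel). Returns the time at which the last MMA finishes.
long pipeline_makespan(int k_iters, int n_stages, long load_t, long mma_t) {
    std::vector<long> load_end(k_iters), mma_end(k_iters);
    for (int k = 0; k < k_iters; ++k) {
        // Producer: the previous load must have been issued, and slot k % N
        // must have been released by the MMA of k-iteration k - N.
        long load_start = (k > 0) ? load_end[k - 1] : 0;
        if (k - n_stages >= 0) load_start = std::max(load_start, mma_end[k - n_stages]);
        load_end[k] = load_start + load_t;
        // Consumer: needs this iteration's data and the previous MMA finished.
        long mma_start = load_end[k];
        if (k > 0) mma_start = std::max(mma_start, mma_end[k - 1]);
        mma_end[k] = mma_start + mma_t;
    }
    return k_iters > 0 ? mma_end[k_iters - 1] : 0;
}
```

With n_stages = 1 the model collapses to the sequential schedule, k_iters * (load_t + mma_t); with n_stages = 2 and equal load and MMA times it reduces to one load followed by back-to-back MMAs, which is the overlap drawn in the 2-stage diagram.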
In future kernels the concept of software pipelining will be extended to get even better overlap between different computational stages of the kernel.
The remaining question is: aside from allocating N sets of SMEM buffers for an N-stage software pipeline, how do we implement this conceptually? The answer is warp specialization [Warp Specialization Details]. In our pre-software-pipelining kernel we only needed 4 warps per CTA so that we could access all of TMEM and perform sufficiently parallelized stores to GMEM from registers. When using tcgen05, most operations only require a single thread or a single warp to initiate the work. Kernels that primarily use tcgen05 therefore don't need many threads, since the threads aren't really doing the computational work themselves anymore; they only initiate that work and track its completion. In pre-Blackwell or non-tcgen05 kernels the amount of work and the problem shape can influence how many threads are launched and in what configuration. In tcgen05 kernels it's the structure of the kernel itself and the kinds of work that need to be done (not the amount or shape of the work) that dictate how many threads are launched per CTA. Warp specialization means identifying the structural pieces of the kernel and assigning each of those pieces to one or a few warps. Each warp then focuses only on the specific task to which it is assigned.
In this case we have 5 computational stages in our kernel, but some of them can be grouped together. Stages 1 and 2 should be grouped together because tcgen05.cp and tcgen05.mma operations need to be performed in a loop if we decide to load multiple MMA tiles at a time from GMEM into SMEM. This will become clearer in the code walkthrough. We also can't overlap tcgen05.cp with tcgen05.mma as defined by the hardware. Stages 3 and 4 should be grouped together because we need to perform tcgen05.ld operations and GMEM stores from registers in a loop to avoid using too many registers per thread (which could cause occupancy issues). The details of tcgen05.ld are discussed in the code walkthrough as well.
We end up with 3 groups of computational tasks in our kernel: loading from GMEM->SMEM; copying scale factors from SMEM->TMEM and performing the MMA; and loading the result from TMEM into registers and storing it to GMEM. Applying warp specialization is now simple: we just assign a warp or set of warps to each of these three tasks. We need four warps for the last task in order to access all of TMEM. The first two tasks only require issuing single-thread or warp-wide operations, so each only needs a single warp. This leaves us with a 6-warp kernel (or 192 threads).
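The sketch below writes that 6-warp assignment down as host-side C++ and checks why four epilogue warps are enough. It assumes TMA_WARP = 0 and MMA_WARP = 1 (the names come from the kernel code later in this section, the values from the warp list that follows), and, as I understand the PTX access restrictions, that warp w may only access the 32 TMEM lanes starting at 32 * (w % 4). The enum and helper functions are hypothetical.

```cpp
#include <cassert>
#include <set>

// Hypothetical role enum for the 6-warp (192-thread) layout described above.
enum class WarpRole { Tma, Mma, Epilogue };

constexpr int TMA_WARP = 0;   // assumed value
constexpr int MMA_WARP = 1;   // assumed value
constexpr int NUM_WARPS = 6;

WarpRole role_for_warp(int warp_id) {
    if (warp_id == TMA_WARP) return WarpRole::Tma;
    if (warp_id == MMA_WARP) return WarpRole::Mma;
    return WarpRole::Epilogue;  // warps 2-5
}

// First TMEM lane a warp may access, assuming warp w is limited to the
// 32-lane quarter starting at 32 * (w % 4).
int tmem_lane_base(int warp_id) { return 32 * (warp_id % 4); }

// True if the epilogue warps together cover all four 32-lane quarters (128 lanes).
bool epilogue_covers_all_lanes() {
    std::set<int> bases;
    for (int w = 0; w < NUM_WARPS; ++w)
        if (role_for_warp(w) == WarpRole::Epilogue) bases.insert(tmem_lane_base(w));
    return bases.size() == 4;
}
```

Warps 2 through 5 map to lane quarters 2, 3, 0 and 1, so together they still reach every TMEM lane even though they aren't warps 0 to 3.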
Let's look at what each warp does and see how it implements the software pipelining described above for an N-stage software pipeline:
Warp 0 (TMA Warp, Stage 0):
=> Kernel setup
=> For k-iter in k-iterations:
    => Wait until SMEM buffers in software pipeline slot (k-iter % N) are free (i.e. not in use by the MMA warp for any previous k-iters)
    => Load data into buffers for k-iter
Warp 1 (MMA Warp, Stages 1 and 2):
=> Kernel setup
=> For k-iter in k-iterations:
    => Wait until SMEM buffers in software pipeline slot (k-iter % N) have their data loaded by the TMA warp
    => Copy scale factors from SMEM -> TMEM
    => Execute MMAs for this k-iter (accumulating in TMEM)
Warps 2-5 (Epilogue Warps, Stages 3 and 4):
=> Kernel setup
=> Wait for warp 1 to complete computation for all k-iterations
=> Move results from TMEM to Regs to GMEM
There are two things to note here. First, we need 2 mbarriers per software pipeline stage to enforce both that the consumer doesn't consume data before it has been produced and that the producer doesn't overwrite data still being consumed (before software pipelining we only needed the former). Second, we can now see how software pipelining has decoupled the work across k-iterations for warp 0 and warp 1. Since all buffers in the software pipeline are initially free, warp 0 can steam ahead and load all of the buffers through k-iter N-1 without waiting for warp 1 to complete any computation. Similarly, warp 1 won't have to wait as long (or at all) for warp 0 to load the data, since those transfers happened while it was still computing the MMAs of a previous k-iter.
There are many more optimizations available, and we will discuss them in the upcoming kernels. For now we lay the foundation for the next two kernels, including the main optimization ideas of warp specialization and software pipelining. In the next section I map the concepts discussed onto the actual code in order to dive a bit deeper into the functional details, and to build background that will be helpful in understanding the subsequent kernels.
Code Walkthrough
In this walkthrough I'll be skipping over parts of the code that should be self-explanatory if one understands basic CUDA/C++.
- Data Details -
Crucial points for understanding the technical details that follow in this section:
Data for A, B, SFA, and SFB are stored in k-major order (so row-major for A/SFA and column-major for B/SFB)
- Kernel Setup -
There are two important parts of the kernel setup: mbarrier initialization and TMEM buffer allocation.
// Setup memory barriers
__shared__ alignas(8) int64_t mbars[PIPE_STAGES * 2 + 1];
const int mbar_addr_tma = static_cast<int>(__cvta_generic_to_shared(mbars));
const int mbar_addr_mma = mbar_addr_tma + PIPE_STAGES * 8; // 8 because each mbar is 64bits = 8B
const int mbar_addr_epi = mbar_addr_mma + PIPE_STAGES * 8;
if (warp_id == 0 && elect_one_sync()) {
for (int i = 0; i < PIPE_STAGES * 2 + 1; i++) {
mbar_init(mbar_addr_tma + i * 8, 1);
}
asm volatile("fence.mbarrier_init.release.cluster;");
} else if (warp_id == 1) {
tcgen05_alloc_tmem<1>(tmem_addr_ptr, TD_MMA_N*2);
}
__syncthreads(); // Ensure all threads have correct TMEM ptrs
In this code snippet we declare and initialize PIPE_STAGES * 2 + 1 mbarriers. As mentioned above we need 2 mbarriers per software pipeline stage to ensure correctness and we need one more additional mbarrier to make the epilogue warps wait for the results to be fully computed in TMEM before proceeding to store them to GMEM. We initialize the mbarriers with a count of 1 because only one thread needs to arrive on each barrier before it can be flipped and allow waiting threads to proceed (for more details on this see Synchronization Details). Only one thread is needed to perform the initializations, which introduces us to another function: elect_one_sync().
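// Returns 1 in exactly one lane of the calling warp (the elected leader), 0 in all others
__device__ uint32_t inline elect_one_sync() {
uint32_t pred = 0;
asm volatile(
"{\n"
".reg .pred %%px;\n"
"elect.sync _|%%px, %1;\n" // '_' discards the elected laneid; %%px is true only in the leader
"@%%px mov.s32 %0, 1;\n"
"}"
: "+r"(pred)
: "r"(0xFFFFFFFF) // membermask: all 32 lanes of the warp
);
return pred;
}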
This function ultimately returns 1 (or true) if the executing thread has been "elected" and 0 (or false) otherwise. For a thread to be "elected" means it's the one unique thread chosen out of a subset of threads in a warp. We do this with the elect.sync instruction, which elects one predicated active leader thread from among the set of threads specified by membermask (in our case the mask includes all 32 threads in the warp). elect.sync can return the laneid of the elected thread in its 32-bit destination operand d; since we don't need it, we pass the sink symbol '_' for d. The predicate destination p is set to True for the leader thread and False for all other threads, and we use it to set pred to 1 only in the leader.
As we look over the rest of the code you will notice we use elect_one_sync() repeatedly, and you would be correct in assuming we may run into issues if a different thread in the warp is chosen each time this gets run. Luckily, election of a leader thread happens deterministically, i.e. the same leader thread is elected for the same membermask every time. The .sync qualifier indicates that this instruction causes the executing thread to wait until all threads in the membermask execute the elect instruction before resuming execution.
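A tiny host-side model makes the determinism argument concrete. The PTX docs only promise that the same membermask always elects the same lane; picking the lowest active lane here is my assumption for illustration, and elected_lane / elect_one_sync_model are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Host-side emulation of elect.sync (hypothetical). PTX only guarantees that the
// same membermask always elects the same lane; choosing the lowest set bit here
// is an assumption made for illustration. Returns -1 if the mask is empty.
int elected_lane(uint32_t membermask) {
    for (int lane = 0; lane < 32; ++lane)
        if (membermask & (1u << lane)) return lane;
    return -1;
}

// Mirrors elect_one_sync(): 1 for the elected lane, 0 for everyone else.
uint32_t elect_one_sync_model(int lane, uint32_t membermask) {
    return lane == elected_lane(membermask) ? 1u : 0u;
}

// Counts how many lanes in the mask would see 1.
int num_elected(uint32_t membermask) {
    int n = 0;
    for (int lane = 0; lane < 32; ++lane)
        if ((membermask >> lane) & 1u) n += static_cast<int>(elect_one_sync_model(lane, membermask));
    return n;
}
```

Because the result is a pure function of the mask, repeated elect_one_sync() calls with the full-warp mask keep choosing the same thread, which is what the rest of the kernel relies on.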
The last step in the initial kernel setup code snippet is TMEM buffer allocation. To understand this we look at tcgen05_alloc_tmem():
// Warp synchronous execution (all threads in a warp execute)
template<int CTA_GROUP>
__device__ void inline tcgen05_alloc_tmem(int *tmem_addr_ptr, const int n_cols) {
// Convert the generic pointer into the 32-bit shared-memory address expected by the PTX [addr] operand
const int tmem_addr_ptr_cvt = static_cast<int>(__cvta_generic_to_shared(tmem_addr_ptr));
asm volatile (
"tcgen05.alloc.cta_group::%2.sync.aligned.shared::cta.b32 [%0], %1;"
:
: "r"(tmem_addr_ptr_cvt), "r"(n_cols), "n"(CTA_GROUP)
);
}
PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-alloc-manage-instructions
The tcgen05.alloc operation is a blocking instruction which dynamically allocates the specified number of columns in the Tensor Memory and writes the address of the allocated Tensor Memory into shared memory at the location specified by address operand dst. Thus tcgen05_alloc_tmem will attempt to allocate n_cols of TMEM, and put the TMEM address of that allocation into the smem address specified by tmem_addr_ptr. There are two important constraints when it comes to TMEM allocation:
1 - Partial column allocations aren't possible; only whole columns can be allocated.
2 - 32 is the minimum number of columns that can be allocated by a single tcgen05.alloc call, and the number of columns in a single allocation must be a power of 2.
In our kernel we really need three TMEM buffers: a result buffer, an SFA buffer, and an SFB buffer. We could choose to make a separate allocation for each of these buffers, but in most kernel configurations we don't gain anything by doing that. Instead, we make one allocation of TD_MMA_N*2 columns (twice the size of the result buffer), which is the smallest power of 2 that fits the result buffer plus the much smaller scale factor buffers in the configurations used here. Then we distribute that allocated buffer to different TMEM pointers to be used for different purposes. This way we still get granular TMEM buffers while reducing the number of TMEM allocation calls, which could become expensive, especially if multiple CTAs share a single SM.
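Here's a host-side sketch of that sizing argument. It assumes a 128-lane fp32 result buffer of TD_MMA_N columns and reuses the per-sub-k-iteration scale factor column strides from the MMA warp code later in this section (4 columns for SFA, MAX(TD_MMA_N / 32, 4) for SFB). It also assumes the TMEM address format of lane in the upper 16 bits and column in the lower 16 bits. Where tmem_addr_sfa and tmem_addr_sfb actually sit inside the allocation isn't shown here, so the hypothetical helpers only cover sizes and address arithmetic.

```cpp
#include <cassert>
#include <cstdint>

// Smallest legal tcgen05.alloc size covering `cols`: a power of two, at least 32.
// (Callers should also check the result against the 512 columns of TMEM.)
int tmem_alloc_cols(int cols) {
    int n = 32;
    while (n < cols) n *= 2;
    return n;
}

// Columns needed per CTA for the result and scale-factor buffers, using the
// per-sub-k-iteration column strides that appear in the MMA warp code.
int tmem_cols_needed(int td_mma_n, int td_smem_k, int td_mma_k) {
    const int sub_k_iters = td_smem_k / td_mma_k;
    const int result_cols = td_mma_n;            // assumed: one fp32 column per output column
    const int sfa_cols = 4 * sub_k_iters;        // one 32x128b copy writes 4 columns
    const int sfb_stride = td_mma_n / 32 > 4 ? td_mma_n / 32 : 4;
    const int sfb_cols = sfb_stride * sub_k_iters;
    return result_cols + sfa_cols + sfb_cols;
}

// Assumed TMEM address layout: lane in bits [31:16], column in bits [15:0].
uint32_t tmem_addr(uint32_t lane, uint32_t col) { return (lane << 16) | col; }
```

For TD_MMA_N = 128 and TD_SMEM_K = 256 the buffers need 160 columns, which rounds up to exactly 2 * TD_MMA_N = 256; for TD_MMA_N = 256 they need 304 and round up to all 512 columns.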
- TMA Warp -
After kernel setup the pseudo-code for the TMA warp looks like this:
=> For k-iter in k-iterations:
    => Wait until SMEM buffers in software pipeline slot (k-iter % N) are free (i.e. not in use by the MMA warp for any previous k-iters)
    => Load data into buffers for k-iter
The actual code to implement this looks like the following:
// TMA thread loops over SMEM tile stages and loads from GMEM->SMEM
if (warp_id == TMA_WARP && elect_one_sync()) {
auto tma_load_stage = [&](const int k_off, const int stage) {
const int k_off_coremat = k_off / WIDTH_COREMAT;
const int a_smem_stage_ptr = a_smem_ptr + stage * A_SMEM_TILESZ;
const int b_smem_stage_ptr = b_smem_ptr + stage * B_SMEM_TILESZ;
const int sfa_smem_stage_ptr = sfa_smem_ptr + stage * SF_SMEM_TILESZ;
const int sfb_smem_stage_ptr = sfb_smem_ptr + stage * SF_SMEM_TILESZ * (TD_MMA_N == 256 ? 2 : 1);
const int mbar_addr_tma_stage = mbar_addr_tma + stage * 8;
tcgen05_3dtma_g2s_ab<1>(a_smem_stage_ptr, &tmap_a, m_off, k_off_coremat, mbar_addr_tma_stage, CacheHintSm100::EVICT_NORMAL);
tcgen05_3dtma_g2s_ab<1>(b_smem_stage_ptr, &tmap_b, n_off, k_off_coremat, mbar_addr_tma_stage, CacheHintSm100::EVICT_NORMAL);
/*
Scale factors are stored in global memory in 4x4x32 chunks, i.e. 512B chunks where each chunk represents a
128x4 chunk of the SF matrix (in M or N xK)
So we calculate the offset in each dimension in terms of these 512B chunks:
k_off / 64 represents the number of 128x4 (512B) chunks along the K dimension which are contiguous (4 * SF_BLOCKS_SIZE = 64)
m/n_off / 128 represents the number of 512B chunks along the M dimension which are strided by K / 64 512B chunks
*/
const uint8_t* sfa_g_ptr = sfa_gmem_base + ((k_off / 64) + (m_off / 128) * (K / 64)) * 512; // ISSUE: These could be just simple bit shifts, adjust if compiler doesn't
const uint8_t* sfb_g_ptr = sfb_gmem_base + ((k_off / 64) + (n_off / 128) * (K / 64)) * 512;
tcgen05_1dtma_g2s_sf(sfa_smem_stage_ptr, sfa_g_ptr, SF_SMEM_TILESZ, mbar_addr_tma_stage, CacheHintSm100::EVICT_NORMAL);
tcgen05_1dtma_g2s_sf(sfb_smem_stage_ptr, sfb_g_ptr, SF_SMEM_TILESZ, mbar_addr_tma_stage, CacheHintSm100::EVICT_NORMAL);
if constexpr (TD_MMA_N == 256) {
tcgen05_1dtma_g2s_sf(sfb_smem_stage_ptr+SF_SMEM_TILESZ, sfb_g_ptr + (K / 64)*512, SF_SMEM_TILESZ, mbar_addr_tma_stage, CacheHintSm100::EVICT_NORMAL);
}
// Signal in mbarrier that we expect SMEM_TILE_SZ bytes to arrive on this mbar object before proceeding to next phase
mbar_arrive_expect(mbar_addr_tma_stage, SMEM_TILE_SZ);
};
// Fill the TMA pipe
for (int stage = 0; stage < PIPE_STAGES && stage * TD_SMEM_K < K; stage++) { // don't prime stages past the end of K
tma_load_stage(stage * TD_SMEM_K, stage);
}
// Cycle through tile stages, loading tiles once no longer in use by the MMA stage
int stage = 0;
for (int k_off = TD_SMEM_K * PIPE_STAGES; k_off < K; k_off += TD_SMEM_K) {
mbar_wait(mbar_addr_mma + stage * 8, (((k_off / TD_SMEM_K) / PIPE_STAGES) - 1) % 2);
tma_load_stage(k_off, stage);
stage = (stage + 1) % PIPE_STAGES;
}
}
We define a lambda function tma_load_stage, which captures everything by reference and takes in k_off (k-iter * the size of a k-block) and stage (which determines which buffers in the software pipeline we use for this k-iteration). In the lambda we find the corresponding buffers and mbarriers for the stage. Notice we also compute k_off_coremat, the offset along the k-dimension for the current k-iteration measured in core matrix widths. Next we initiate the TMA transfers for A, B, SFA, and SFB using the following two functions: tcgen05_3dtma_g2s_ab() and tcgen05_1dtma_g2s_sf().
__device__ void inline tcgen05_1dtma_g2s_sf(int dst, const void *src, int size, int mbar_addr, CacheHintSm100 cache_policy) {
asm volatile(
"cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;"
:
: "r"(dst), "l"(src), "r"(size), "r"(mbar_addr), "l"(cache_policy));
}
template<int CTA_GROUP>
__device__ void inline tcgen05_3dtma_g2s_ab(int dst_smem, const void *tmap_ptr, int mn_off, int k_off_coremat, int mbar_addr, CacheHintSm100 cache_policy) {
asm volatile (
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.cta_group::%0.L2::cache_hint [%1], [%2, {%3, %4, %5}], [%6], %7;"
:
: "n"(CTA_GROUP), "r"(dst_smem), "l"(tmap_ptr), "r"(0), "r"(mn_off), "r"(k_off_coremat), "r"(mbar_addr), "l"(cache_policy)
);
}
We've discussed these instructions previously at a broad level, now we can dive into the technical details of how they work in the context of this kernel.
cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint
This instruction, without the ".tensor" modifier, performs a contiguous transfer of size bytes. The .shared::cta.global modifiers indicate we are copying from global memory into CTA-local SMEM (as opposed to DSMEM). mbarrier::complete_tx::bytes indicates we are passing an mbarrier to which this transaction must report the number of bytes transferred upon completion (more details on this in the glossary). L2::cache_hint allows us to provide hints about which data will or won't be reused, so that evictions can better exploit temporal locality in the cache (see Cache Hint Details).
cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.cta_group::%0.L2::cache_hint
First refer to the [TMA Details] segment for an understanding of TMA descriptors and how these cp.async.bulk.tensor instructions interact with mbarriers. This instruction and wrapper assume we've properly set up a TMA descriptor for the data we are transferring. Most modifiers in the instruction serve the same purpose as in the prior bulk transfer instruction. The main differences are .3d, which indicates the tensor (or block of data) we are transferring has 3 dimensions, and .shared::cluster instead of .shared::cta, which means the destination is addressed in the cluster's shared memory window (DSMEM) rather than only in the CTA's own SMEM. This instruction also allows using shared::cta; however, later we will find that the cluster variant is useful for further optimizations, and when transferring into a single CTA's SMEM rather than multiple CTAs' DSMEM it effectively reduces to the same operation with no performance hit. The last difference for this tensor variant of the TMA transfer is the addition of the int mn_off and int k_off_coremat arguments.
To understand those arguments we look at how the 3d transaction is structured by the tensor map (no-swizzle example):
constexpr uint32_t rank = 3;
uint64_t dim_gmem[rank] = {32, mn_dim_gmem, k_dim_gmem/32};
uint64_t stride_gmem[rank - 1] = {k_dim_gmem/2, 16};
uint32_t dim_smem[rank] = {32, MN_SMEM_TD, K_SMEM_TD/32};
The three dimensions of our matrix in GMEM (dim_gmem) look like: {core matrix width, mn-dimension, number of core matrices along the k-dim}. Also embedded in the tensor map are the sizes along each dimension that we are copying to SMEM/DSMEM (dim_smem): {core matrix width, mn-block-dimension, number of core matrices along the k-block-dim}. We always copy the full width of the core matrix so dim_gmem[0] always equals dim_smem[0], and the other two dimensions reflect the full GMEM dimensions vs the dimensions of a single load tile. Again for a more detailed description of this see [TMA Details].
Going back to the wrapper function tcgen05_3dtma_g2s_ab above, we see that of the three offsets we pass, the first is always 0. This is because we never load a block starting somewhere in the middle of a core matrix. The argument mn_off determines how far along the mn-dimension our tile is offset, and k_off_coremat determines by how many core matrix "chunks" the tile to be loaded is offset in GMEM (this is just the k offset divided by the core matrix width).
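To convince yourself that the 3d view addresses the same bytes as the plain K-major matrix, the sketch below compares the two byte offsets on the host. It assumes 4-bit elements (two per byte), that dim 0 of the tensor map counts elements while the strides are in bytes (as in the no-swizzle tensor map above), and a 32-element core matrix width; all function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Byte offset of 4-bit element (row mn, column k) in a K-major (row-major) matrix
// with K elements per row: two elements per byte.
uint64_t rowmajor_byte(uint64_t mn, uint64_t k, uint64_t K) { return (mn * K + k) / 2; }

// Byte offset of the same element addressed through the 3d tensor-map view
// {c0 in [0, 32), mn, kc} with byte strides {K/2, 16} for dims 1 and 2,
// i.e. dim 0 walks inside one 32-element (16 B) core-matrix row.
uint64_t tmap3d_byte(uint64_t c0, uint64_t mn, uint64_t kc, uint64_t K) {
    return c0 / 2 + mn * (K / 2) + kc * 16;
}

// Box origin coordinates passed to the TMA for a tile at (mn_off, k_off).
struct Coord3 { int c0, mn, kc; };
Coord3 tma_box_origin(int mn_off, int k_off) { return {0, mn_off, k_off / 32}; }
```

Both formulas agree whenever K is a multiple of 32, which is why the first coordinate can stay 0 and the k offset only needs to be expressed in core matrix widths.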
One may wonder why we are allowed to use a contiguous transfer through tcgen05_1dtma_g2s_sf for the scale factors, given they need to have the unusual 32x4x4 layout discussed in the previous section. Luckily, the inputs to the kernel in this competition included separate scale factor tensors that had already been rearranged into the required format, so the oddly laid-out tiles can just be loaded contiguously (see the Data Layout section for further details). On this theme, one quirk of this segment of code is the ugly if constexpr (TD_MMA_N == 256) statement, which is necessary when loading tiles of the SFB matrix when the N-dimension of the tile is 256. It's the structure of the reformatted tensors that necessitates this double TMA issue: the tiles of 32x4x4 scale factors represent contiguous 128x4 tiles of the original scale factor tensors, and consecutive 128-row tiles are K / 64 chunks (of 512 B each) apart. That means 256x4 scale factor tiles are non-contiguous in both the standard scale factor tensors and the reformatted tensors. Another solution is to use a 2d transfer, much in the same way we use a 3d transfer to achieve the core matrix structure for the nvfp4 data. That involves trade-offs that will be discussed in the upcoming kernels (TMA setup overhead, tensor bulk transfers sometimes taking longer than a series of contiguous transfers, etc.).
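The scale factor pointer arithmetic from the TMA lambda can be checked the same way. This host-side sketch restates it, including the shift-based form the ISSUE comment mentions and the offset of the second SFB copy for 256-wide tiles; the helper names are hypothetical and K is assumed to be a multiple of 64.

```cpp
#include <cassert>
#include <cstdint>

// Byte offset into the pre-arranged scale-factor tensor of the 512 B chunk
// (one 32x4x4 block = 128 rows x 4 scale factors) for a tile at (mn_off, k_off).
// Chunks are contiguous along K (64 elements of K per chunk) and consecutive
// 128-row blocks are K / 64 chunks apart.
uint64_t sf_chunk_offset(uint64_t k_off, uint64_t mn_off, uint64_t K) {
    return ((k_off / 64) + (mn_off / 128) * (K / 64)) * 512;
}

// Same computation with shifts (valid because 64, 128 and 512 are powers of two).
uint64_t sf_chunk_offset_shift(uint64_t k_off, uint64_t mn_off, uint64_t K) {
    return ((k_off >> 6) + (mn_off >> 7) * (K >> 6)) << 9;
}

// A 256-row SFB tile needs a second copy starting one 128-row block later.
uint64_t sf_second_half_offset(uint64_t k_off, uint64_t n_off, uint64_t K) {
    return sf_chunk_offset(k_off, n_off, K) + (K / 64) * 512;
}
```

The second-half offset equals the offset of the next 128-row block at the same k_off, which is exactly why a 256-row tile can't be fetched with one contiguous copy.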
Once we've issued all of the data transfers for a k-block we need to ensure the mbarrier we are using for the current software pipeline stage is expecting the correct number of bytes. This will allow the mbarrier to signal to the MMA warp waiting on this data that it can proceed with computation once all data has arrived.
// Signal in mbarrier that we expect SMEM_TILE_SZ bytes to arrive on this mbar object before proceeding to next phase
mbar_arrive_expect(mbar_addr_tma_stage, SMEM_TILE_SZ);
__device__ inline void mbar_arrive_expect(const int mbar_addr, const int size) {
asm volatile(
"mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;"
:
: "r"(mbar_addr), "r"(size)
: "memory"
);
}
See the [Synchronization Details] section for specifics on the arrive-on and expect-tx sub-operations performed by the mbarrier.arrive PTX instruction. In short, mbarrier.arrive.expect_tx tells the mbarrier to expect x more bytes from the previously issued TMA transactions for this k-iteration (where x is the total number of bytes in all tiles of A, B, SFA, and SFB) and then performs an arrive-on operation, which is the single arrival this count-1 barrier is waiting for.
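SMEM_TILE_SZ itself isn't defined in this section, so the following is only a sketch of how the expected byte count could be derived. It assumes the A and B tiles are TD_MMA_M and TD_MMA_N rows by TD_SMEM_K NVFP4 elements, and that SF_SMEM_TILESZ covers TD_SMEM_K / 64 chunks of 512 B (consistent with the sub_k_iter * 512 offsets used by the MMA warp); the helpers are hypothetical.

```cpp
#include <cassert>

// Hypothetical: one SF tile holds TD_SMEM_K / 64 chunks of 512 B.
constexpr int sf_smem_tilesz(int td_smem_k) { return (td_smem_k / 64) * 512; }

// Hypothetical: bytes the TMA mbarrier must expect per pipeline stage.
constexpr int smem_tile_bytes(int td_mma_m, int td_mma_n, int td_smem_k) {
    const int a_bytes = td_mma_m * td_smem_k / 2;  // 2 NVFP4 values per byte
    const int b_bytes = td_mma_n * td_smem_k / 2;
    const int sfa_bytes = sf_smem_tilesz(td_smem_k);
    // 256-wide B tiles need two 128-row SF chunks (the double TMA issue).
    const int sfb_bytes = sf_smem_tilesz(td_smem_k) * (td_mma_n == 256 ? 2 : 1);
    return a_bytes + b_bytes + sfa_bytes + sfb_bytes;
}
```

Multiplying the result by PIPE_STAGES gives the SMEM footprint of the operand buffers for the whole software pipeline.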
This completes the description of the lambda function. Now we can look at how it's used in the k-loop by the TMA warp:
// Fill the TMA pipe
for (int stage = 0; stage < PIPE_STAGES && stage * TD_SMEM_K < K; stage++) { // don't prime stages past the end of K
tma_load_stage(stage * TD_SMEM_K, stage);
}
// Cycle through tile stages, loading tiles once no longer in use by the MMA stage
int stage = 0;
for (int k_off = TD_SMEM_K * PIPE_STAGES; k_off < K; k_off += TD_SMEM_K) {
mbar_wait(mbar_addr_mma + stage * 8, (((k_off / TD_SMEM_K) / PIPE_STAGES) - 1) % 2);
tma_load_stage(k_off, stage);
stage = (stage + 1) % PIPE_STAGES;
}
In the first loop we "prime the pump" and initiate loads for all software pipeline stages. As discussed in the previous section, we don't need to wait in this loop because we aren't overwriting any meaningful data at the start of the k-loop; the data in SMEM is just garbage. This loop is also a large part of why a lambda function is used: it avoids a bunch of conditional code that would make the core logic hard to follow.
In the second loop the only difference is now we need to execute mbar_wait before loading tiles for a k-iteration. This is because we could be overwriting data that is still being consumed by the MMA warp for computation. The mbarriers pointed to by mbar_addr_mma protect the producer (TMA warp) from overwriting data still in use by the consumer (MMA warp).
__device__ void mbar_wait(const int mbar_addr, const int phase) {
uint32_t ticks = 0x989680; // suspend-time hint for try_wait: 10,000,000 ns (value from CUTLASS)
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2;\n" // Acquire semantics assumed here
"@P1 bra.uni DONE;\n" // Add .uni here because there won't be warp divergence
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}"
:
: "r"(mbar_addr), "r"(phase), "r"(ticks)
);
}
In the [Synchronization Details] section I discuss how mbarriers have a binary phase that threads can use to determine whether they should wait or proceed. That phase only flips once the conditions of the mbarrier have been met (i.e. the expected number of arrivals has occurred and all bytes for the tracked async transactions have arrived). This wrapper function is essentially a spin loop that checks whether the phase we're waiting on has completed, meaning the producer or consumer is free to continue execution for this software pipeline stage. The try_wait variant is potentially blocking: if the phase hasn't completed, the executing thread may be suspended until it does or until the time limit given by ticks expires (as opposed to test_wait, which is non-blocking and just returns the current status). This is ideal for us because we want blocked warps to yield to other warps on the SM that can continue executing. The idea is similar to yielding threads in operating systems, where we want to avoid threads spinning on a resource, hogging CPU cycles and inserting bubbles into hardware compute pipelines.
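The phase logic is easier to see in a simplified host-side model. This is not how the hardware is implemented: it's a single-threaded sketch (MbarModel is a hypothetical name) that ignores scopes, memory ordering and timeouts, and only tracks pending arrivals, the transaction byte count, and the number of completed phases.

```cpp
#include <cassert>
#include <cstdint>

// Simplified single-threaded model of an mbarrier (hypothetical, for intuition only).
struct MbarModel {
    int expected_arrivals;   // count passed to mbarrier.init
    int pending_arrivals;    // arrivals still missing in the current phase
    int64_t tx_count = 0;    // outstanding transaction bytes (may dip below 0 in this model)
    uint32_t phase = 0;      // number of completed phases

    explicit MbarModel(int count) : expected_arrivals(count), pending_arrivals(count) {}

    void maybe_complete() {
        if (pending_arrivals == 0 && tx_count == 0) {
            ++phase;                             // the phase flips
            pending_arrivals = expected_arrivals;
        }
    }
    // mbarrier.arrive.expect_tx: raise the tx count, then arrive once.
    void arrive_expect_tx(int64_t bytes) { tx_count += bytes; --pending_arrivals; maybe_complete(); }
    // complete_tx reported by a finished TMA transfer.
    void complete_tx(int64_t bytes) { tx_count -= bytes; maybe_complete(); }
    // try_wait.parity succeeds once the phase with parity p has completed,
    // i.e. when the current phase's parity differs from p.
    bool try_wait_parity(uint32_t p) const { return (phase & 1u) != p; }
};
```

The model lets TMA bytes land before expect_tx is posted (its tx count briefly goes negative), mirroring tma_load_stage, where mbar_arrive_expect is issued after the copies. It also shows that waiting on parity 1 of a fresh barrier succeeds immediately, while waiting on parity 0 blocks until the first phase completes.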
- MMA Warp -
After kernel setup the pseudo-code for the MMA warp looks like this:
=> For k-iter in k-iterations:
    => Wait until SMEM buffers in software pipeline slot (k-iter % N) have their data loaded by the TMA warp
    => Copy scale factors from SMEM -> TMEM
    => Execute MMAs for this k-iter (accumulating in TMEM)
The actual code to implement this looks like the following:
// MMA thread loops over SMEM tile stages, loads TMEM and computes MMA ops
else if (warp_id == MMA_WARP && elect_one_sync()) {
for (int k_off = 0; k_off < K; k_off += TD_SMEM_K) {
const int stage = (k_off / TD_SMEM_K) % PIPE_STAGES;
mbar_wait(mbar_addr_tma + stage * 8, ((k_off / TD_SMEM_K) / PIPE_STAGES) % 2);
const int a_smem_stage_ptr = a_smem_ptr + stage * A_SMEM_TILESZ;
const int b_smem_stage_ptr = b_smem_ptr + stage * B_SMEM_TILESZ;
const int sfa_smem_stage_ptr = sfa_smem_ptr + stage * SF_SMEM_TILESZ;
const int sfb_smem_stage_ptr = sfb_smem_ptr + stage * SF_SMEM_TILESZ * (TD_MMA_N == 256 ? 2 : 1);
const int mbar_addr_mma_stage = mbar_addr_mma + stage * 8;
// Load scale factors SMEM -> TMEM
for (int sub_k_iter = 0; sub_k_iter < TD_SMEM_K / TD_MMA_K; sub_k_iter++) {
uint64_t sfa_desc = make_smem_desc<0, false>(sfa_smem_stage_ptr + (sub_k_iter * 512)); // ISSUE: verify this should input 0 here
uint64_t sfb_desc = make_smem_desc<0, false>(sfb_smem_stage_ptr + (sub_k_iter * 512));
tcgen05_cp<1>(tmem_addr_sfa + 4 * sub_k_iter, sfa_desc);
tcgen05_cp<1>(tmem_addr_sfb + MAX<TD_MMA_N / 32, 4>() * sub_k_iter, sfb_desc);
if constexpr (TD_MMA_N == 256) {
uint64_t sfb_desc2 = make_smem_desc<0, false>(sfb_smem_stage_ptr + SF_SMEM_TILESZ + (sub_k_iter * 512));
tcgen05_cp<1>(tmem_addr_sfb + 8 * sub_k_iter + 4, sfb_desc2);
}
}
// Loop over SMEM tile K-dim
for (int sub_k_iter = 0; sub_k_iter < TD_SMEM_K / TD_MMA_K; sub_k_iter++) {
// Stride computed differently depending on swizzle mode because it changes core matrix shape
uint64_t a_desc, b_desc;
if constexpr (SWIZZLE) {
a_desc = make_smem_desc<TD_MMA_M, SWIZZLE>(a_smem_stage_ptr + sub_k_iter * 32);
b_desc = make_smem_desc<TD_MMA_N, SWIZZLE>(b_smem_stage_ptr + sub_k_iter * 32);
}
else {
a_desc = make_smem_desc<TD_MMA_M, SWIZZLE>(a_smem_stage_ptr + sub_k_iter * TD_MMA_K * TD_MMA_M / 2);
b_desc = make_smem_desc<TD_MMA_N, SWIZZLE>(b_smem_stage_ptr + sub_k_iter * TD_MMA_K * TD_MMA_N / 2);
}
int sfa_tmem = tmem_addr_sfa + 4 * sub_k_iter;
int sfb_tmem;
if constexpr (TD_MMA_N == 256) {
sfb_tmem = tmem_addr_sfb + 8 * sub_k_iter;
} else {
sfb_tmem = tmem_addr_sfb + 4 * sub_k_iter + (n_off%128)/32;
}
tcgen05_mma_nvfp4<1>(tmem_addr_result, a_desc, b_desc, make_instr_desc<TD_MMA_M, TD_MMA_N>(), sfa_tmem, sfb_tmem, k_off + sub_k_iter); // Used as enable-input-d: zero (false) only for the very first MMA, which overwrites the result in TMEM instead of accumulating into it
}
// signal MMA done
tcgen05_commit(mbar_addr_mma_stage);
}
// Signal epilogue to start
tcgen05_commit(mbar_addr_epi);
}
Since all of the code executed by the MMA warp is single-thread execution, we guard the entire block with elect_one_sync().
As the consumer warp, we must first wait for the loads issued by the TMA warp for the current software pipeline stage to complete before proceeding with computation. Thus, we again use mbar_wait, but this time on the mbarriers referenced via mbar_addr_tma, which prevent the consumer warp (MMA) from using data before it has been produced (by the TMA warp).
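As a cross-check of the parity arguments the two warps pass to mbar_wait, here they are restated as host-side C++ (k_iter = k_off / TD_SMEM_K; the helper names are hypothetical). The key fact is that the n-th completion of a barrier (counting from 0) finishes the phase with parity n % 2, and each slot's barriers complete once per trip around the pipeline.

```cpp
#include <cassert>

// Software pipeline slot used by k-iteration k_iter.
int pipe_slot(int k_iter, int pipe_stages) { return k_iter % pipe_stages; }

// MMA warp waits on mbar_tma[slot] for the (k_iter / P)-th fill of that slot.
int consumer_parity(int k_iter, int pipe_stages) { return (k_iter / pipe_stages) % 2; }

// TMA warp (steady-state loop, k_iter >= P) waits on mbar_mma[slot] for the
// MMA of k_iter - P, which is that slot's (k_iter / P - 1)-th release.
int producer_parity(int k_iter, int pipe_stages) { return (k_iter / pipe_stages - 1) % 2; }

// The n-th completion of a barrier (n = 0, 1, ...) finishes the phase with parity n % 2.
int completion_parity(int n) { return n % 2; }
```

Because the priming loop never waits, producer_parity is only evaluated for k_iter >= PIPE_STAGES, so its argument to % never goes negative.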
Next we set up our SMEM, TMEM, and mbarrier pointers as we did in the TMA warp. Now we are ready to copy the scale factor data from SMEM to TMEM using tcgen05_cp:
// Copies 32 rows x 128 bits from the matrix described in SMEM by desc
// into tmem_ptr; .warpx4 replicates those 32 rows into all four 32-lane groups of TMEM
template<int CTA_GROUP>
__device__ void inline tcgen05_cp(int tmem_ptr, uint64_t desc) {
asm volatile (
"tcgen05.cp.cta_group::%2.32x128b.warpx4 [%0], %1;"
:
: "r"(tmem_ptr), "l"(desc), "n"(CTA_GROUP)
);
}
tcgen05.cp.cta_group.shape{.multicast}{.dst_fmt.src_fmt} [taddr], s-desc;
.cta_group = { .cta_group::1, .cta_group::2 }
.src_fmt = { .b6x16_p32 , .b4x16_p64 }
.dst_fmt = { .b8x16 }
.shape = { .128x256b, .4x256b, .128x128b, .64x128b**, .32x128b*** }
.multicast = { .warpx2::02_13** , .warpx2::01_23**, .warpx4*** }
PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-cp
Instruction tcgen05.cp initiates an asynchronous copy operation from shared memory to the location specified by the address operand taddr in the Tensor Memory.
This is where we need to start using matrix descriptors (described in the above sections).
// Complete descriptor with address info
template<int MN_DIM, bool SWIZZLE_128B = false>
__device__ uint64_t inline make_smem_desc(int smem_addr) {
constexpr uint64_t LBO = SWIZZLE_128B ? 1 : MN_DIM*16;
constexpr uint64_t SBO = SWIZZLE_128B ? 8 * 128 : 8 * 16;
constexpr uint64_t SWIZZLE_BITS = SWIZZLE_128B ? 2 : 0;
return encode(smem_addr) | (encode(LBO) << 16) | (encode(SBO) << 32) | (0x1ULL << 46) | (SWIZZLE_BITS << 61);
}
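This function packs the start address, LBO, SBO, a fixed constant (bit 46), and the swizzle mode (bits 61-63) into their associated bit positions of the 64-bit descriptor. The descriptor stores the address and offset fields in 16-byte units, so encode() is expected to drop the low 4 bits of each byte value.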
s-desc is a matrix descriptor (the same kind as discussed previously); however, in this case we aren't using any swizzling, and we hardcode the LBO (Leading Byte Offset) to 0 by passing 0 as MN_DIM. We don't swizzle the scale factors for now because the benefit would be relatively small for the complexity of implementation (scale factors make up only a small percentage of the overall data). LBO is 0 because for the tcgen05.cp instruction we are essentially copying one long column of core matrices. SBO stays the same as for the other non-swizzled data (8 * 16 = 128 bytes).
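encode() isn't shown in this section, so the sketch below assumes the usual matrix-descriptor field encoding: keep the low 18 bits of the byte value and drop the bottom 4 (16 B units). Under that assumption it reproduces make_smem_desc on the host and decodes the fields, based on my reading of the PTX shared memory descriptor layout: start address in bits 0-13, LBO in bits 16-29, SBO in bits 32-45, the fixed 0b001 in bits 46-48, and the swizzle mode in bits 61-63 (2 = 128B swizzle). The _host suffix and the extractor helpers are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Assumed field encoding: low 18 bits of a byte address/offset, in 16 B units.
constexpr uint64_t encode(uint64_t x) { return (x & 0x3FFFF) >> 4; }

// Host copy of make_smem_desc from the kernel, for inspecting bit fields.
template <int MN_DIM, bool SWIZZLE_128B = false>
constexpr uint64_t make_smem_desc_host(uint32_t smem_addr) {
    constexpr uint64_t LBO = SWIZZLE_128B ? 1 : MN_DIM * 16;
    constexpr uint64_t SBO = SWIZZLE_128B ? 8 * 128 : 8 * 16;
    constexpr uint64_t SWIZZLE_BITS = SWIZZLE_128B ? 2 : 0;
    return encode(smem_addr) | (encode(LBO) << 16) | (encode(SBO) << 32) |
           (0x1ULL << 46) | (SWIZZLE_BITS << 61);
}

// Field extractors matching the shifts above (returned as byte values).
constexpr uint64_t desc_start(uint64_t d) { return (d & 0x3FFF) << 4; }
constexpr uint64_t desc_lbo(uint64_t d) { return ((d >> 16) & 0x3FFF) << 4; }
constexpr uint64_t desc_sbo(uint64_t d) { return ((d >> 32) & 0x3FFF) << 4; }
constexpr uint64_t desc_swizzle(uint64_t d) { return d >> 61; }
```

Decoding the scale factor descriptor (MN_DIM = 0) gives an LBO of 0 and an SBO of 128 bytes, matching the description above.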
A logical next question is: how does tcgen05.cp know how much data to copy? The answer lies in the .shape qualifier, which takes the form .{rows}x{bits}b, where rows is the number of TMEM rows written and bits is the number of bits written to each row. Of the available options we choose .32x128b, which copies 32 rows of 128 bits each from SMEM (as a single column, which works because the data has already been reformatted into contiguous 32x4x4 chunks) into 32 rows of TMEM per copy, spanning 128b / 32b per column = 4 columns.

The other important field is the .multicast qualifier. Its legal values depend on the chosen .shape, and it determines how the data is replicated across TMEM rows. Recall that in an earlier section I briefly mentioned that data must sometimes be repeated across TMEM rows, likely due to constraints imposed by hardware. We now see this explicitly: the .multicast qualifier specifies how many copies of the data are written into TMEM rows. In our case .32x128b requires the .warpx4 qualifier, which writes the same 32 rows into each of the four 32-row quadrants, filling all 128 rows of TMEM.
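To make the .32x128b.warpx4 behaviour concrete, here is a small host-side model. It is a sketch of the reading given above (32 rows of four 32-bit words, replicated into each 32-lane quadrant), not a description of how the hardware implements the copy; the Tmem alias and model_cp_32x128b_warpx4 are hypothetical names.

```cpp
#include <array>
#include <cstdint>

constexpr int kRows = 32;               // rows in the .32x128b shape
constexpr int kColsPerCopy = 128 / 32;  // 128-bit rows / 32-bit TMEM cells = 4 columns
constexpr int kLanes = 128;             // TMEM lanes (rows) per SM

using SmemRows = std::array<std::array<uint32_t, kColsPerCopy>, kRows>;
using Tmem = std::array<std::array<uint32_t, kColsPerCopy>, kLanes>;

// Hypothetical model: .warpx4 writes the same 32 source rows into every lane quadrant.
Tmem model_cp_32x128b_warpx4(const SmemRows& smem) {
  Tmem tmem{};
  for (int quadrant = 0; quadrant < 4; ++quadrant)
    for (int r = 0; r < kRows; ++r)
      for (int c = 0; c < kColsPerCopy; ++c)
        tmem[quadrant * kRows + r][c] = smem[r][c];
  return tmem;
}
```

Each 32-bit word holds four FP8 scale factors, so one copy moves 32 x 16 = 512 bytes of scale factors and lands them in 4 TMEM columns on every lane.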
We perform the scale factor copies in a loop of TD_SMEM_K / TD_MMA_K iterations. TD_SMEM_K is the size of the TMA-loaded tiles in the k-dimension, and TD_MMA_K is the size of the k-dimension consumed by a single MMA instruction. TD_SMEM_K must be a multiple of TD_MMA_K, but within that constraint we can experiment with the exact values to find the best configuration for a given problem shape. On sm_100a (the architecture target for the NVIDIA B200), TD_MMA_K can be at most 64 for kind::mxf4nvf4. Thus, if TD_SMEM_K is 256, each TMA-loaded tile holds data for 4 tcgen05.mma instructions. Each tcgen05.cp instruction copies the scale factors for one 32x4x4 block (one 128x4 block in the canonical dimensions), which covers a tile with k = 64, since 4 scale factors span 4*16 = 64 NVFP4 elements. In other words, one tcgen05.cp loads the scale factors for a single MMA, so we loop and issue a tcgen05.cp for each k = 64 chunk of the loaded TMA tile.
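The loop arithmetic above can be checked at compile time. This sketch uses the example value TD_SMEM_K = 256 from the text and assumes a 128-row operand (e.g. A with TD_MMA_M = 128) with one-byte FP8 scale factors; the k* constants are hypothetical names.

```cpp
constexpr int TD_SMEM_K = 256;  // K extent of one TMA-loaded tile (example value)
constexpr int TD_MMA_K  = 64;   // K per tcgen05.mma for kind::mxf4nvf4
constexpr int kSfVec    = 16;   // NVFP4 elements sharing one scale factor

static_assert(TD_SMEM_K % TD_MMA_K == 0, "TD_SMEM_K must be a multiple of TD_MMA_K");

constexpr int kMmasPerTile    = TD_SMEM_K / TD_MMA_K;   // loop trip count
constexpr int kSfPerRowPerMma = TD_MMA_K / kSfVec;      // scale factors per row per MMA
constexpr int kSfBytesPerMma  = 128 * kSfPerRowPerMma;  // 128 rows x 1-byte scale factors
constexpr int kCpBytes        = 32 * 128 / 8;           // bytes moved by one .32x128b copy

static_assert(kSfBytesPerMma == kCpBytes,
              "one tcgen05.cp covers a 128-row operand's scale factors for one MMA");
```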
The second for loop iterates over the TD_SMEM_K-wide tile loaded by the TMA warp in chunks of TD_MMA_K, issuing a tcgen05.mma for each chunk and accumulating the result in TMEM.
tcgen05_mma_nvfp4<1>(tmem_addr_result, a_desc, b_desc, make_instr_desc<TD_MMA_M, TD_MMA_N>(), sfa_tmem, sfb_tmem, k_off + sub_k_iter);
// Issued by a single thread. Computes D = A * B (+ D when enable_input_d != 0) for one
// TD_MMA_K chunk, with A/B scaled by the block-16 scale factors in TMEM at sfa_tmem / sfb_tmem.
template<int CTA_GROUP>
__device__ inline void tcgen05_mma_nvfp4(int d_tmem, uint64_t a_desc, uint64_t b_desc, uint32_t i_desc,
                                         int sfa_tmem, int sfb_tmem, int enable_input_d) {
  asm volatile (
    "{\n"
    ".reg .pred p;\n"
    "setp.ne.b32 p, %4, 0;\n"  // p = (enable_input_d != 0): accumulate into D
    "tcgen05.mma.cta_group::%7.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p;\n"
    "}\n"
    :
    : "r"(d_tmem), "l"(a_desc), "l"(b_desc), "r"(i_desc), "r"(enable_input_d),
      "r"(sfa_tmem), "r"(sfb_tmem), "n"(CTA_GROUP)
  );
}
We pass in the TMEM address of our result, matrix descriptors for the A and B tiles in SMEM, the compile-time instruction descriptor (discussed earlier), the TMEM addresses of the matching scale factor chunks, and k_off + sub_k_iter as enable_input_d. That value is zero only for the first MMA, so the first MMA overwrites the TMEM result buffer instead of adding to whatever it held, and every later MMA accumulates into it. We don't need to worry about one tcgen05.mma corrupting the result of a previous one, because tcgen05.mma instructions issued by the same thread into the same accumulator are guaranteed to execute in order (the hardware pipelining rules discussed earlier).
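The enable_input_d behaviour can be modelled on the CPU. The hypothetical k_loop_model below is a plain C++ sketch, not the device code: it leaves stale values in the accumulator and shows that, because the first chunk is issued with enable_input_d == 0 (mirroring the setp.ne test on k_off + sub_k_iter), no explicit zeroing of TMEM is needed. It assumes K is a multiple of chunk_k and ignores scale factors.

```cpp
#include <vector>

// Each "MMA" computes D = A_k * B_k + (enable_input_d ? D : 0) for one K chunk.
std::vector<float> k_loop_model(const std::vector<float>& A,  // M x K, row-major
                                const std::vector<float>& B,  // K x N, row-major
                                int M, int N, int K, int chunk_k) {
  std::vector<float> D(M * N, 12345.0f);  // stale accumulator contents
  for (int k0 = 0; k0 < K; k0 += chunk_k) {
    const bool enable_input_d = (k0 != 0);  // false only for the first chunk
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n) {
        float acc = enable_input_d ? D[m * N + n] : 0.0f;
        for (int k = k0; k < k0 + chunk_k; ++k) acc += A[m * K + k] * B[k * N + n];
        D[m * N + n] = acc;
      }
  }
  return D;
}
```

If the first chunk were issued with enable_input_d set, the stale 12345.0f would leak into every output element.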
One important aside is that accumulation happens in TMEM in FP32, so each result element occupies a full 32-bit TMEM column of its row, and in the epilogue we need to convert these FP32 values to FP16 before storing them to GMEM.
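A quick budget check for the FP32 accumulator, as a sketch: TMEM has 128 lanes of 512 32-bit columns (128 * 512 * 4 B = 256 KB), and with TD_MMA_M = 128 accumulator row m lives in lane m with one column per element, so a 128 x N tile needs N columns. The rounding helper assumes allocations must be a power of two of at least 32 columns; all names are hypothetical.

```cpp
constexpr int kTmemLanes = 128;  // one lane per accumulator row when TD_MMA_M = 128
constexpr int kTmemCols  = 512;  // 32-bit cells per lane
static_assert(kTmemLanes * kTmemCols * 4 == 256 * 1024, "TMEM is 256 KB per SM");

// FP32 accumulator for a 128 x N tile: one column per output element in each lane.
constexpr int acc_cols(int mma_n) { return mma_n; }

// Assumption for illustration: round a column request up to a power of two >= 32.
constexpr int round_alloc(int cols) { return cols <= 32 ? 32 : 2 * round_alloc((cols + 1) / 2); }

static_assert(acc_cols(256) <= kTmemCols, "the largest accumulator used here fits");
```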
Additionally, it's worth reiterating that the MMA shapes we can compute are limited by this table: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-kind-shapes. It constrains our choice of TD_MMA_M/N/K for the kernel, and those choices in turn determine the various shapes and instruction variants used within the kernel.
The last step in the MMA warp is to notify other warps when relevant MMAs have finished computing. We have to notify the TMA warp when we have finished consuming the data for a given software pipeline stage, and we need to notify the epilogue warps when all computation has finished and the results are ready to be consumed.
In the TMA Warp section we discussed how, after the first pass through the software pipeline, the TMA warp must wait on the associated mbar_addr_mma mbarrier to ensure that the data is safe to overwrite (i.e. no longer being consumed by the MMA warp). This code is how we signal that mbarrier so the TMA warp knows it can proceed:
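// Signal the TMA warp that this stage's SMEM buffers can be overwritten
tcgen05_commit(mbar_addr_mma_stage);

// Arrives on the mbarrier once all prior async tcgen05 ops issued by this thread complete
__device__ inline void tcgen05_commit(const int mbar_addr) {
  asm volatile(
    "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
    :
    : "r"(mbar_addr)
    : "memory"
  );
}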
PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit
The instruction tcgen05.commit is an asynchronous instruction which makes the mbarrier object, specified by the address operand mbar, track the completion of all the prior asynchronous tcgen05 operations (including tcgen05.mma) initiated by the executing thread. Upon the completion of the tracked asynchronous tcgen05 operations, the mbarrier is signaled to complete (assuming the count of the mbarrier was initialized to one, which it always will be for this code. See Synchronization Details for more).
At the end of each k-iteration we call tcgen05_commit(mbar_addr_mma_stage); to signal to the TMA Warp that the MMA warp is finished with the current pipeline stage and that data is safe to be overwritten.
Once we've completed computation across all k-iterations and the result in TMEM is ready to be written back to GMEM, we signal the epilogue warps via tcgen05_commit(mbar_addr_epi);, which, as we will see in the next section, allows them to proceed.
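Per-stage signalling like this is usually paired with a stage index and a phase parity per barrier. The sketch below shows one common way to derive them for a NUM_STAGES-deep ring, assuming parity-based mbarrier waits and exactly one completion per barrier per use (see Synchronization Details for the kernel's actual scheme); all function names are hypothetical.

```cpp
#include <cstdint>

// Which SMEM buffer / mbarrier pair iteration k_iter uses.
constexpr int stage_of(int k_iter, int num_stages) { return k_iter % num_stages; }

// Parity the MMA warp waits on for the TMA "data ready" barrier of this stage:
// each pass around the ring completes one more phase of that barrier.
constexpr uint32_t full_parity(int k_iter, int num_stages) {
  return static_cast<uint32_t>((k_iter / num_stages) & 1);
}

// The TMA warp waits on the MMA back-edge barrier only once the ring has wrapped...
constexpr bool tma_must_wait(int k_iter, int num_stages) { return k_iter >= num_stages; }

// ...for the phase completed by the commit from iteration k_iter - num_stages
// (only meaningful when tma_must_wait is true).
constexpr uint32_t empty_parity(int k_iter, int num_stages) {
  return static_cast<uint32_t>((k_iter / num_stages - 1) & 1);
}
```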
- Epilogue Warps -
After kernel setup, the pseudo-code for the Epilogue warps looks like this:
=> Wait for MMA warp to complete computation for all k-iterations
=> Move results from TMEM to Regs to GMEM
The code:
// All warps aside from the two for TMA/MMA are dedicated to the epilogue
else if (warp_id < NUM_WARPS - 2) {
  // Wait for the MMA warp's final tcgen05.commit (single use, so phase 0)
  mbar_wait(mbar_addr_epi, 0);
  asm volatile("tcgen05.fence::after_thread_sync;");
  // Load MMA results into regs (TMEM -> Regs)
  // Each warp loads 16 rows per tcgen05_ld; with 128 rows and 4 warps, each warp is responsible for 32 rows
  // Each thread holds TD_MMA_N / 2 FP32 values per tcgen05_ld
  float results[TD_MMA_N / 2];
  int rows_per_warp = TD_MMA_M / (NUM_WARPS - 2);
  for (int sub_m = 0; sub_m < rows_per_warp / 16; sub_m++) {
    // TMEM address: lane (row) in the upper 16 bits, column in the lower 16 bits
    if constexpr (TD_MMA_N == 256) {
      tcgen05_ld<16, 256, 32>(results, tmem_addr_result + (((warp_id * rows_per_warp) + sub_m * 16) << 16));
    }
    else if constexpr (TD_MMA_N == 128) {
      tcgen05_ld<16, 256, 16>(results, tmem_addr_result + (((warp_id * rows_per_warp) + sub_m * 16) << 16));
    }
    else if constexpr (TD_MMA_N == 64) {
      tcgen05_ld<16, 256, 8>(results, tmem_addr_result + (((warp_id * rows_per_warp) + sub_m * 16) << 16));
    }
    asm volatile("tcgen05.wait::ld.sync.aligned;");
    // Convert FP32 -> FP16 in registers and store straight to GMEM (Regs -> GMEM, no SMEM staging)
    // Each 16x256b block covers 8 TMEM columns (256b / 32b), hence TD_MMA_N / 8 iterations
    for (int i = 0; i < TD_MMA_N / 8; i++) {
      const int m_offset = m_off + warp_id * rows_per_warp + sub_m * 16 + lane_id / 4;
      const int n_offset = n_off + i * 8 + (lane_id % 4) * 2;
      reinterpret_cast<half2 *>(c_ref + (m_offset)*N + n_offset)[0] = __float22half2_rn({results[i * 4], results[i * 4 + 1]});
      reinterpret_cast<half2 *>(c_ref + (m_offset + 8)*N + n_offset)[0] = __float22half2_rn({results[i * 4 + 2], results[i * 4 + 3]});
    }
  }
}
At the end of the previous section we saw how the MMA warp signals an mbarrier on which the epilogue warps are waiting once the computation phase is complete. The first step of the epilogue code is then to wait on this mbarrier with mbar_wait(mbar_addr_epi, 0);. This barrier only flips once so hard-coding the phase to 0 is fine (see Synchronization Details for more on mbarriers and phases).
Next we insert a fence so that the following tcgen05_ld calls are ordered after the mbarrier wait, which is what guarantees that the MMA results they read are complete: asm volatile("tcgen05.fence::after_thread_sync;");
The final block of code is what actually moves the data from TMEM into registers then to GMEM. To facilitate TMEM -> Registers we need the tcgen05.ld instruction:
This is a templated function so we can compile in different versions based on the compile time parameters chosen for our kernel. Below is just one example implementation.
template<int LANES, int WIDTH, int REPT>
__device__ void inline tcgen05_ld(float* regs, int tmem_addr);
template<>
__device__ void inline tcgen05_ld<16, 256, 16>(float* regs, int tmem_addr) {
asm volatile (
"tcgen05.ld.sync.aligned.16x256b.x16.b32 { %0, %1, %2, %3, %4, %5, %6, %7, "
" %8, %9, %10, %11, %12, %13, %14, %15, "
" %16, %17, %18, %19, %20, %21, %22, %23, "
" %24, %25, %26, %27, %28, %29, %30, %31, "
" %32, %33, %34, %35, %36, %37, %38, %39, "
" %40, %41, %42, %43, %44, %45, %46, %47, "
" %48, %49, %50, %51, %52, %53, %54, %55, "
" %56, %57, %58, %59, %60, %61, %62, %63}, [%64];"
: "=f"(regs[0]), "=f"(regs[1]), "=f"(regs[2]), "=f"(regs[3]),
"=f"(regs[4]), "=f"(regs[5]), "=f"(regs[6]), "=f"(regs[7]),
"=f"(regs[8]), "=f"(regs[9]), "=f"(regs[10]), "=f"(regs[11]),
"=f"(regs[12]), "=f"(regs[13]), "=f"(regs[14]), "=f"(regs[15]),
"=f"(regs[16]), "=f"(regs[17]), "=f"(regs[18]), "=f"(regs[19]),
"=f"(regs[20]), "=f"(regs[21]), "=f"(regs[22]), "=f"(regs[23]),
"=f"(regs[24]), "=f"(regs[25]), "=f"(regs[26]), "=f"(regs[27]),
"=f"(regs[28]), "=f"(regs[29]), "=f"(regs[30]), "=f"(regs[31]),
"=f"(regs[32]), "=f"(regs[33]), "=f"(regs[34]), "=f"(regs[35]),
"=f"(regs[36]), "=f"(regs[37]), "=f"(regs[38]), "=f"(regs[39]),
"=f"(regs[40]), "=f"(regs[41]), "=f"(regs[42]), "=f"(regs[43]),
"=f"(regs[44]), "=f"(regs[45]), "=f"(regs[46]), "=f"(regs[47]),
"=f"(regs[48]), "=f"(regs[49]), "=f"(regs[50]), "=f"(regs[51]),
"=f"(regs[52]), "=f"(regs[53]), "=f"(regs[54]), "=f"(regs[55]),
"=f"(regs[56]), "=f"(regs[57]), "=f"(regs[58]), "=f"(regs[59]),
"=f"(regs[60]), "=f"(regs[61]), "=f"(regs[62]), "=f"(regs[63])
: "r"(tmem_addr)
);
}
tcgen05.ld.sync.aligned.shape1.num{.pack}.b32 r, [taddr];
.shape1 = { .16x64b, .16x128b, .16x256b, .32x32b }
.num = { .x1, .x2, .x4, .x8, .x16, .x32, .x64, .x128 }
.pack = { .pack::16b }
PTX Link: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld
Instruction tcgen05.ld asynchronously loads data from the Tensor Memory at the location specified by the 32-bit address operand taddr into the destination register r, collectively across all threads of the warps. The .shape qualifier and the .num qualifier together determines the total dimension of the data which is loaded from the Tensor Memory. The .shape qualifier indicates the base dimension of data to be accessed. The .num qualifier indicates the repeat factor on the base dimension resulting in the total dimension of the data that is accessed.
Sticking with the example, we look at .shape1 = .16x256b and .num = .x16, which means loading from tensor memory using the .16x256b base shape repeated 16 times across the columns of TMEM. Below is the diagram showing how a single warp loads data from TMEM into the registers of its threads for this shape:
[Insert diagram from https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b]
So each .16x256b base access, executed collectively by all threads in a warp, loads 16 rows x 8 columns (this is where the name 16x256b comes from: 16 rows, and 256b = 8 columns of 32 bits each in TMEM). With the .x16 modifier, a single tcgen05.ld therefore loads 16 rows by 128 columns of TMEM per warp, with each thread holding 64 of the 16*128 values in its registers. This is why we declare a local array float results[TD_MMA_N / 2]; to hold the results of the tcgen05.ld operation (64 floats for TD_MMA_N = 128).
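A host-side model of the .16x256b.xR register layout can make the mapping concrete. The per-register coordinates below are read off the epilogue indexing shown earlier (registers 0-1 in lane t / 4, registers 2-3 eight lanes further down, two adjacent columns starting at 2 * (t % 4), each repetition 8 columns to the right); treat the PTX figure as the authoritative source. Coord, frag_16x256b, regs_per_thread and covers_exactly_once are hypothetical names.

```cpp
#include <vector>

struct Coord { int lane; int col; };  // position relative to the warp's 16-lane slab

// Register r of thread t for .16x256b.xR: r / 4 picks the repetition (8 columns each),
// (r % 4) / 2 picks the upper or lower 8 lanes, r % 2 picks the column within the pair.
constexpr Coord frag_16x256b(int t, int r) {
  return Coord{(t / 4) + ((r % 4) / 2) * 8, (r / 4) * 8 + (t % 4) * 2 + (r % 2)};
}

constexpr int regs_per_thread(int rept) { return 4 * rept; }

// True if the 32 threads together touch every cell of the 16 x (8 * rept) slab exactly once.
inline bool covers_exactly_once(int rept) {
  std::vector<int> hits(16 * 8 * rept, 0);
  for (int t = 0; t < 32; ++t)
    for (int r = 0; r < regs_per_thread(rept); ++r) {
      Coord c = frag_16x256b(t, r);
      ++hits[c.lane * 8 * rept + c.col];
    }
  for (int h : hits)
    if (h != 1) return false;
  return true;
}
```

Note that regs_per_thread(TD_MMA_N / 8) equals TD_MMA_N / 2, matching the size of results.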
There are different formats for tcgen05.ld which change the number of values per thread and the format of the load from TMEM. Those formats can be seen in the PTX docs, where I sourced the above graphic: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
In this particular case we have 4 warps, each loading 16 rows from TMEM per tcgen05.ld, which covers only half of the 128 TMEM rows. For most of our kernels we choose the MMA M dimension to be 128, so we need to load all 128 rows into registers. That is why the loads sit inside a loop: each warp issues rows_per_warp / 16 tcgen05.ld instructions to cover its 32 rows, and the 4 epilogue warps together cover all 128 rows:
for (int sub_m = 0; sub_m < rows_per_warp / 16; sub_m++) {
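The lane offsets in this loop come from how TMEM addresses are packed: lane in bits 31-16 and column in bits 15-0, which is why the kernel adds (row << 16) to tmem_addr_result. The sketch below mirrors that arithmetic on the host and also encodes the tcgen05.ld rule that a warp may only access the 32-lane quadrant selected by warp_id % 4. The helper names are hypothetical.

```cpp
#include <cstdint>

constexpr uint32_t tmem_addr(uint32_t base, int lane, int col) {
  return base + (static_cast<uint32_t>(lane) << 16) + static_cast<uint32_t>(col);
}
constexpr int lane_of(uint32_t addr) { return static_cast<int>(addr >> 16); }
constexpr int col_of(uint32_t addr)  { return static_cast<int>(addr & 0xFFFF); }

// First lane read by epilogue warp `warp_id` for sub-tile `sub_m` (16 lanes per tcgen05.ld).
constexpr int start_lane(int warp_id, int rows_per_warp, int sub_m) {
  return warp_id * rows_per_warp + sub_m * 16;
}

// A 16-lane access starting at first_lane must stay inside the warp's quadrant
// [32 * (warp_id % 4), 32 * (warp_id % 4) + 31].
constexpr bool access_in_quadrant(int warp_id, int first_lane) {
  return first_lane >= 32 * (warp_id % 4) && first_lane + 15 <= 32 * (warp_id % 4) + 31;
}
```

With rows_per_warp = 32 and epilogue warps 0-3, every access stays in its own quadrant, which is another reason the epilogue is spread across exactly four warps.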
Next, because tcgen05.ld is asynchronous, we need to wait for the loads to complete before we convert the data to FP16 and store it to GMEM:
asm volatile("tcgen05.wait::ld.sync.aligned;");
As discussed earlier, this instruction waits specifically for those tcgen05.ld operations to complete.
Then the final loop takes the values in the results registers and stores them to GMEM. Here we store two FP16 values at a time (one half2) for better store coalescing:
for (int i = 0; i < TD_MMA_N / 8; i++) {
const int m_offset = m_off + warp_id * rows_per_warp + sub_m * 16 + lane_id / 4;
const int n_offset = n_off + i * 8 + (lane_id % 4) * 2;
reinterpret_cast<half2 *>(c_ref + (m_offset)*N + n_offset)[0] = __float22half2_rn({results[i * 4], results[i * 4 + 1]});
reinterpret_cast<half2 *>(c_ref + (m_offset + 8)*N + n_offset)[0] = __float22half2_rn({results[i * 4 + 2], results[i * 4 + 3]});
}
In this GEMM section we covered how tcgen05 tensor cores and TMEM on the Blackwell architecture can accelerate matrix multiply, including the specifics of the PTX instructions used and basic algorithmic strategies and optimizations. This background will be key as all future kernels will derive from this basic GEMM kernel. Below I've included a list of resources that I used while learning about the Blackwell architecture and optimizing GEMM on GPUs.
GEMM Summary
Unlike the GEMV section, this section doesn't iterate through versioned kernels — the goal was to lay down the foundational tcgen05 / TMA / TMEM machinery that the next two kernels build on. The summary below is therefore organized by the building blocks introduced rather than by version.
Pipeline Stage Summary:
┌────────┬─────────────────┬──────────────────────────────────────┬─────────────────────────────────────────────────────────────┐
│ Stage │ Action │ PTX Instruction │ Notes │
├────────┼─────────────────┼──────────────────────────────────────┼─────────────────────────────────────────────────────────────┤
│ 0 │ GMEM -> SMEM │ cp.async.bulk.tensor.{2d,3d} │ TMA load of A/B/SFA/SFB tiles, signals mbarrier on arrive │
├────────┼─────────────────┼──────────────────────────────────────┼─────────────────────────────────────────────────────────────┤
│ 1 │ SMEM -> TMEM │ tcgen05.cp.cta_group::N.32x128b │ SF tiles only; A/B stay in SMEM and are read by MMA via desc│
├────────┼─────────────────┼──────────────────────────────────────┼─────────────────────────────────────────────────────────────┤
│ 2 │ Compute │ tcgen05.mma.kind::mxf4nvf4.block_sc. │ Issued by single thread, accumulates in TMEM (FP32) │
├────────┼─────────────────┼──────────────────────────────────────┼─────────────────────────────────────────────────────────────┤
│ 3 │ TMEM -> Regs │ tcgen05.ld.sync.aligned.16x256b.xN │ Warp-collective, requires 4 warps to cover M=128 │
├────────┼─────────────────┼──────────────────────────────────────┼─────────────────────────────────────────────────────────────┤
│ 4 │ Regs -> GMEM │ standard st.global / half2 stores │ FP32 -> FP16 conversion in registers before store │
└────────┴─────────────────┴──────────────────────────────────────┴─────────────────────────────────────────────────────────────┘
Synchronization Primitives Summary:
┌──────────────────────────────────┬──────────────────────────────────────────────────────────────────────────────────────────┐
│ Primitive                        │ Where it's used                                                                          │
├──────────────────────────────────┼──────────────────────────────────────────────────────────────────────────────────────────┤
│ mbarrier (TMA)                   │ Producer (TMA warp) signals via expect_tx; consumer (MMA warp) waits before computing    │
├──────────────────────────────────┼──────────────────────────────────────────────────────────────────────────────────────────┤
│ mbarrier (MMA->TMA back-edge)    │ Consumer signals when SMEM buffer is free; producer waits to avoid overwriting           │
├──────────────────────────────────┼──────────────────────────────────────────────────────────────────────────────────────────┤
│ mbarrier (MMA->Epilogue)         │ Signaled via tcgen05.commit once all MMAs in k-loop have completed                       │
├──────────────────────────────────┼──────────────────────────────────────────────────────────────────────────────────────────┤
│ tcgen05.fence::after_thread_sync │ Orders later async tcgen05 ops after prior thread sync (e.g. the epilogue mbarrier wait) │
├──────────────────────────────────┼──────────────────────────────────────────────────────────────────────────────────────────┤
│ tcgen05.wait::ld                 │ Blocks executing thread until prior tcgen05.ld ops complete (TMEM -> Regs ready)         │
├──────────────────────────────────┼──────────────────────────────────────────────────────────────────────────────────────────┤
│ elect_one_sync                   │ Picks one deterministic leader thread in a warp for single-thread PTX ops                │
└──────────────────────────────────┴──────────────────────────────────────────────────────────────────────────────────────────┘
Warp Specialization Summary (6-warp kernel, 192 threads):
┌─────────────┬───────────────┬───────────────────────────────────────────────────────────────────────────────────────────────┐
│ Warps       │ Specialty     │ Work                                                                                          │
├─────────────┼───────────────┼───────────────────────────────────────────────────────────────────────────────────────────────┤
│ Warps 0-3   │ Epilogue      │ Wait for MMA, tcgen05.ld TMEM->Regs, convert FP32->FP16, store to GMEM (4 warps cover M=128)  │
├─────────────┼───────────────┼───────────────────────────────────────────────────────────────────────────────────────────────┤
│ Warp 4 or 5 │ TMA           │ Issue TMA loads of A/B/SFA/SFB into stage-N SMEM buffers, wait on MMA back-edge mbar          │
├─────────────┼───────────────┼───────────────────────────────────────────────────────────────────────────────────────────────┤
│ Warp 4 or 5 │ MMA           │ tcgen05.cp scale factors SMEM->TMEM, tcgen05.mma to accumulate result, commit on completion   │
└─────────────┴───────────────┴───────────────────────────────────────────────────────────────────────────────────────────────┘
(The epilogue branch is selected by warp_id < NUM_WARPS - 2, so warps 0-3 are the epilogue warps; the last two warps, 4 and 5, are the TMA and MMA warps, one each.)
Broad Lessons:
The biggest mental shift relative to pre-tcgen05 GEMM kernels is that the thread is no longer the primary unit of compute. Most tcgen05 operations are issued by a single elected thread per warp, the actual work is performed asynchronously by dedicated hardware units (tensor cores, TMA engine, TMEM), and the warps' job becomes to orchestrate data flow and synchronization between those units. This makes warp specialization a natural fit and decouples how many threads to launch from how much arithmetic work needs to be done.
The hardest part of this section to discover by reading the PTX docs alone was data layout — specifically the "core matrix" concept for SMEM tiles consumed by tcgen05.mma and the (32 x 4) x 4 repeated layout for scale factors in TMEM. Neither is named in the official documentation as a first-class concept (see 07_PTX_lessons.txt). The matrix descriptor's LBO/SBO fields only make sense once the core matrix is understood.
Synchronization correctness is the dominant source of non-deterministic ("Heisen-") bugs in tcgen05 kernels. Producer-consumer mbarriers, fences, and waits must be matched to the exact data flow between the asynchronous units; the hardware does not protect against e.g. an MMA reading SF data that hasn't been copied to TMEM yet or an epilogue reading TMEM before the last MMA committed.
Links and Further Reading
[1] - Modular article
[2] - Explanation of Tensor Memory