
No cuBLAS: A Hand-Written CUDA BERT That Matches PyTorch


The honest framing up front: this is a learning project, and “beating PyTorch” needs an asterisk I will get to. What I set out to do was write every kernel of a real model by hand in CUDA, no cuBLAS and no CUTLASS, and find out how far hand-written code lands from the library that has had a decade of tuning. The answer was closer than I expected.

The model I chose is all-MiniLM-L6-v2, a 6-layer BERT that turns a sentence into a 384-dimensional embedding. Small, fixed shapes, no generation loop. Why this model? No particular reason. It’s still fairly popular on Hugging Face, and it’s one I spent a great deal of time experimenting with earlier this year during some fine-tuning passes at work. So, why not?

Hardware

The hardware is one RTX 3090. PyTorch on the GPU is cuBLAS underneath, so this is really a question about how close I can get to cuBLAS by hand. The hardware ended up mattering more than I expected. I figured the difference between something like an H100 and a consumer card would be invisible at the level of CUDA: I’d mostly be composing operations, and maybe the API or the hardware would handle a specific op differently depending on the card’s capabilities. That is mostly true, but I didn’t expect to end up down a rabbit hole reading about how consumer cards are intentionally nerfed. I knew there was something going on around export controls; I just assumed it was limited to commercial-grade hardware and didn’t touch consumer cards.

As a side note, I can’t help but wonder now whether game engines really take advantage of this stuff. I’d guess they do. That may be another side quest in my near future: take a real rendering pass and optimize it for my specific card.

What the forward pass computes

The whole model is six identical layers between an embedding lookup and a pooling step. The shapes never change, so I baked them in as compile-time constants and let the compiler unroll against them.

// all-MiniLM-L6-v2, fixed at compile time
constexpr int HIDDEN = 384;
constexpr int NUM_LAYERS = 6;
constexpr int NUM_HEADS = 12;
// 12 heads * 32 = 384
constexpr int HEAD_DIM = 32;
// 4 * HIDDEN, the feed-forward width
constexpr int FFN_DIM = 1536;
// max positions; the runtime can shrink this, which matters later
constexpr int SEQ_LEN = 128;
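Because the shapes are constants, the relationships between them can be checked when the code compiles, so a typo in one constant breaks the build instead of quietly producing a wrong embedding. A minimal sketch; `TILE`, `QKV_DIM`, and `seq_is_valid` are hypothetical names I’m introducing here, not part of the project’s code.

```cpp
// all-MiniLM-L6-v2 shapes, repeated so this snippet stands alone
constexpr int HIDDEN = 384;
constexpr int NUM_HEADS = 12;
constexpr int HEAD_DIM = 32;
constexpr int FFN_DIM = 1536;
constexpr int SEQ_LEN = 128;
// hypothetical helper constants, not from the original code
constexpr int TILE = 16;             // wmma fragment edge
constexpr int QKV_DIM = 3 * HIDDEN;  // width of the fused Q/K/V projection

// the heads must tile the hidden dimension exactly
static_assert(NUM_HEADS * HEAD_DIM == HIDDEN, "heads must cover HIDDEN");
// the feed-forward layer is 4x wider than the hidden state
static_assert(FFN_DIM == 4 * HIDDEN, "FFN_DIM should be 4 * HIDDEN");
// every GEMM dimension must be a whole number of 16x16 tiles
static_assert(HIDDEN % TILE == 0 && FFN_DIM % TILE == 0, "K must tile");
static_assert(HEAD_DIM % TILE == 0 && SEQ_LEN % TILE == 0, "heads/seq must tile");
static_assert(QKV_DIM % TILE == 0, "fused QKV width must tile");

// a runtime seq (the effective length) must still fit these assumptions
constexpr bool seq_is_valid(int seq) {
  return seq > 0 && seq <= SEQ_LEN && seq % TILE == 0;
}
```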

The input is a tokenized sentence: input_ids and an attention mask, both (batch, seq) integers. The output is one (batch, HIDDEN) vector per sentence. The top level is short:

// ids, types, mask are (batch * seq); emb is the (batch, HIDDEN) output
void bert_embed(Workspace &ws, const ModelWeights &w, const int32_t *ids,
                const int32_t *types, const int32_t *mask, __half *emb,
                int batch, int seq) {
  // embeddings + 6 encoder layers -> hidden states (batch * seq, HIDDEN)
  bert_encode(ws, w, ids, types, mask, ws.hidden, batch, seq);
  // average the unmasked token vectors -> (batch, HIDDEN)
  launch_mean_pool(ws.hidden, mask, emb, seq, HIDDEN, batch);
  // L2-normalize each row so cosine similarity is a dot product
  launch_l2_normalize(emb, HIDDEN, batch);
}

Inside bert_encode, the six layers ping-pong between two buffers so nothing reallocates:

// word + position + token-type lookup, then a layernorm
launch_embedding(ids, types, w.word, w.pos, w.type, ws.summed, seq, HIDDEN, batch);
launch_layernorm(ws.summed, w.emb_ln_w, w.emb_ln_b, cur, batch * seq, HIDDEN, eps);
// every layer reads cur and writes nxt, then they swap
for (int i = 0; i < NUM_LAYERS; i++) {
  encoder_layer(ws, cur, w.layer(i), mask, nxt, batch, seq);
  std::swap(cur, nxt);
}

Everything in fp16. The whole rest of the post is making those kernels fast.

Memory: allocate once

The first thing that is easy to get wrong is allocation. Every buffer the forward needs is sized from the constants, so I allocate one Workspace up front and reuse it for every call and every layer. No cudaMalloc in the hot path. This sounds obvious but it was worth a measurable chunk, and more importantly it is what makes a persistent, no-surprises pipeline possible.

// q/k/v projection, scores, attention output, ffn scratch, ping-pong buffers
struct Workspace {
  DeviceBuffer qkv, scores, merged, attn_proj;
  DeviceBuffer inter, ffn_proj, hidden, summed, ping, pong;
};
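Sizing the buffers from the constants is simple arithmetic. A sketch of the fp16 element counts, assuming every buffer is sized once for the worst case (max_batch, SEQ_LEN). The shapes follow the ones described in this post (QKV is (batch * seq, 3 * HIDDEN), scores are (batch * heads, seq, seq)), but `WorkspaceSizes`, `workspace_sizes`, and the exact per-buffer shapes are my assumptions.

```cpp
#include <cassert>
#include <cstddef>

constexpr int HIDDEN = 384, NUM_HEADS = 12, FFN_DIM = 1536, SEQ_LEN = 128;

// hypothetical element counts for each Workspace buffer (fp16 elements)
struct WorkspaceSizes {
  size_t qkv, scores, merged, attn_proj, inter, ffn_proj, hidden, summed, ping, pong;
  size_t total_bytes() const {
    size_t elems = qkv + scores + merged + attn_proj + inter + ffn_proj +
                   hidden + summed + ping + pong;
    return elems * 2;  // sizeof(__half) == 2
  }
};

// size everything once for the largest batch the caller will ever pass
WorkspaceSizes workspace_sizes(int max_batch) {
  assert(max_batch > 0);
  const size_t rows = size_t(max_batch) * SEQ_LEN;  // tokens in a full batch
  WorkspaceSizes s;
  s.qkv = rows * 3 * HIDDEN;                                     // fused Q, K, V
  s.scores = size_t(max_batch) * NUM_HEADS * SEQ_LEN * SEQ_LEN;  // per-head scores
  s.merged = s.attn_proj = rows * HIDDEN;  // attention output and its projection
  s.inter = rows * FFN_DIM;                // feed-forward intermediate
  s.ffn_proj = s.hidden = s.summed = s.ping = s.pong = rows * HIDDEN;
  return s;
}
```

At a batch of 64 that comes to 108 MiB of fp16 scratch under these assumptions, which a 24 GB card absorbs easily, and it is paid once rather than per call.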

The matmul is the whole ballgame

Every dense layer is a matrix multiply, and profiling with nsys put 70% of the time there, the same share as in PyTorch’s profile. So the matmul is the project. One convention to get out of the way: PyTorch stores a linear layer’s weight as (out_features, in_features) and computes input @ weight^T. So my kernel computes C = A @ B^T, where B is the (N, K) weight exactly as stored. Read as a column-major (K, N) matrix with leading dimension K, that same row-major buffer is exactly B^T, so the weight feeds the tensor cores with no transpose.
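The layout argument is easy to check on the CPU. This is a plain reference, not the kernel: `matmul_abt` computes C = A @ B^T from a row-major (N, K) weight, and `matmul_via_colmajor` gets the same answer by reading the weight’s memory as a column-major (K, N) matrix with leading dimension K, which is the reading a col_major `matrix_b` fragment applies. All three function names are mine.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// C (M x N) = A (M x K) @ B^T, with B stored row-major as (N, K), like a
// PyTorch linear weight of shape (out_features, in_features)
std::vector<float> matmul_abt(const std::vector<float> &A, const std::vector<float> &B,
                              int M, int N, int K) {
  assert(A.size() == size_t(M) * K && B.size() == size_t(N) * K);
  std::vector<float> C(size_t(M) * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        C[m * N + n] += A[m * K + k] * B[n * K + k];  // both rows walk contiguously in k
  return C;
}

// element (k, n) of a column-major (K, N) matrix with leading dimension K
float colmajor_at(const std::vector<float> &buf, int k, int n, int K) {
  return buf[n * K + k];
}

// the same product written as an ordinary A @ Bt, with Bt read column-major
// straight out of the row-major weight buffer: no transpose is materialized
std::vector<float> matmul_via_colmajor(const std::vector<float> &A,
                                       const std::vector<float> &B, int M, int N, int K) {
  assert(A.size() == size_t(M) * K && B.size() == size_t(N) * K);
  std::vector<float> C(size_t(M) * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        C[m * N + n] += A[m * K + k] * colmajor_at(B, k, n, K);
  return C;
}
```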

The naive version is one thread per output element looping over K, reading from global memory each time: 7 ms per sentence. Shared-memory tiling, where a block stages a tile of each input and reuses it, took that to 2.6 ms. Then came the tensor cores, the dedicated matrix hardware the 3090 has. The wmma API works on 16x16x16 fragments:

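#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// one warp per 16x16 tile of C = A @ B^T (illustrative signature); launch with
// grid(N/16, M/16) and 32 threads. A is (M, K), B is the (N, K) weight, C is (M, N)
__global__ void wmma_gemm(const __half *A, const __half *B, __half *C, int N, int K) {
  int tile_row = blockIdx.y * 16, tile_col = blockIdx.x * 16;
  // A tile feeds matrix_a; B row-major is B^T as col-major, feeds matrix_b
  wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> a;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> b;
  // accumulate in fp16 (more on why later)
  wmma::fragment<wmma::accumulator, 16, 16, 16, __half> c;
  wmma::fill_fragment(c, __float2half(0.0f));
  // walk K in steps of 16, multiplying and accumulating tiles
  for (int k0 = 0; k0 < K; k0 += 16) {
    wmma::load_matrix_sync(a, A + tile_row * K + k0, K);
    wmma::load_matrix_sync(b, B + tile_col * K + k0, K);
    wmma::mma_sync(c, a, b, c);
  }
  // write the finished tile back, row-major with leading dimension N
  wmma::store_matrix_sync(C + tile_row * N + tile_col, c, N, wmma::mem_row_major);
}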

That alone reached 0.91 ms for a single sentence. The biggest jump after it was cp.async, the Ampere instruction that copies global to shared without going through registers and without blocking the warp. You double-buffer: while the tensor cores work the current slice, the next one is already in flight.

// issue the async copy of the next K-slice into the other shared buffer
__pipeline_memcpy_async(&As[next][i], &A[...], 16);
__pipeline_commit();
// wait until only the previous copy is outstanding, then compute on it
__pipeline_wait_prior(1);
__syncthreads();
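The ordering is the whole trick, so here it is without any CUDA: a hypothetical `double_buffer_schedule` that only records the order in which a two-stage loop issues copies and consumes slices. Slice t lives in buffer t % 2, and the copy of slice t + 1 is issued before slice t is computed, which is what `__pipeline_wait_prior(1)` permits: wait until at most the newest committed copy is still pending. In a real kernel a `__syncthreads()` after the compute step also keeps a buffer from being overwritten while other warps are still reading it; this sketch runs sequentially, so it needs no barrier.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Records the order of operations in a 2-stage (double-buffered) K loop.
// "load t->b": slice t is copied into shared buffer b (cp.async in CUDA).
// "compute t@b": the tensor cores consume slice t from buffer b.
std::vector<std::string> double_buffer_schedule(int num_slices) {
  assert(num_slices >= 0);
  std::vector<std::string> events;
  if (num_slices == 0) return events;
  // prologue: the first copy is issued before the loop starts
  events.push_back("load 0->0");
  for (int t = 0; t < num_slices; ++t) {
    int cur = t % 2, next = (t + 1) % 2;
    // issue the next slice's copy into the other buffer first...
    if (t + 1 < num_slices)
      events.push_back("load " + std::to_string(t + 1) + "->" + std::to_string(next));
    // ...so it is in flight while this slice is computed
    events.push_back("compute " + std::to_string(t) + "@" + std::to_string(cur));
  }
  return events;
}
```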

The last GEMM gain was register blocking: each warp computes a 64x32 output tile, holding eight accumulator fragments in registers at once, which raises the work done per shared-memory load. Together with the fusions covered below, these took the batched forward from 0.225 down to about 0.10 ms per sentence.
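The payoff of register blocking is countable. A hypothetical helper: for a warp tile of `tm` x `tn` fragments, each 16-wide step along K loads `tm` A fragments and `tn` B fragments and issues `tm * tn` MMAs, because each A fragment is reused across its whole row of tiles and each B fragment down its column. A 64x32 tile is 4 x 2 fragments, so 6 loads feed 8 MMAs (and 8 accumulators live in registers), against 2 loads per MMA for a lone 16x16 tile.

```cpp
#include <cassert>

struct KStepCost {
  int fragment_loads;  // shared-memory -> register fragment loads
  int mmas;            // tensor-core mma_sync calls
};

// cost of one 16-wide K step for a warp tile of (16 * tm) x (16 * tn) outputs
KStepCost warp_tile_cost(int tm, int tn) {
  assert(tm > 0 && tn > 0);
  // each A fragment is reused across tn MMAs, each B fragment across tm
  return {tm + tn, tm * tn};
}
```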

Linear, with the epilogue doing extra work

A “linear” is the matmul plus a bias add, and sometimes a gelu. Both are cheap elementwise ops, and the wrong move is to launch a separate kernel for each. I fold them into the matmul’s epilogue, the code that runs once on each output as it is written out, while the value is still in a register.

// out = x @ w^T + bias, gelu applied in the same kernel when requested
launch_matmul(x, w, out, m, n, k, bias, /*gelu=*/true);
// runs per output element, in registers, before the single store
__device__ float epilogue(float v, const __half *bias, int col, bool gelu) {
  if (bias)
    v += __half2float(bias[col]);
  // exact gelu, the same one torch uses
  if (gelu)
    v *= 0.5f * (1.0f + erff(v * 0.70710677f));
  return v;
}
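For the parity tests, a host-side twin of the epilogue is handy. This sketch mirrors the device function above with `std::erf`; `epilogue_ref` is a hypothetical name, and the bias is a float array rather than `__half` so it compiles without the CUDA headers. The constant is 1/sqrt(2), so this is the exact erf-based GELU, not the tanh approximation.

```cpp
#include <cassert>
#include <cmath>

// host reference for the fused epilogue: optional bias add, then exact GELU
float epilogue_ref(float v, const float *bias, int col, bool gelu) {
  assert(col >= 0);
  if (bias)
    v += bias[col];
  if (gelu)
    v *= 0.5f * (1.0f + std::erf(v * 0.70710677f));  // x * Phi(x)
  return v;
}
```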

That deletes a full kernel launch per feed-forward block. The same idea applies to the residual add: instead of a separate pass, the layernorm kernel takes the residual as an input and adds it while it already has the row loaded.

Embeddings

The embedding step is three table lookups summed per token. The only subtlety is that position resets at each sequence boundary in a batch.

// idx ranges over (batch * seq * HIDDEN); recover which token and channel
int row = idx / HIDDEN;
int h = idx % HIDDEN;
// position is the offset within this sentence, not the global row
int pos = row % seq;
// word vector + position vector + token-type vector
float v = word[ids[row] * HIDDEN + h] + posemb[pos * HIDDEN + h] +
          type[types[row] * HIDDEN + h];
out[idx] = __float2half(v);
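A host version of the same indexing makes the boundary behaviour easy to test: with sentences of length `seq` flattened into one (batch * seq) stream, row `seq` must pick up position 0 again. The tables are float here and `embed_ref` is a hypothetical name; the real kernel reads fp16 tables.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// out[(row, h)] = word[ids[row]] + pos[row % seq] + type[types[row]]
std::vector<float> embed_ref(const std::vector<int32_t> &ids, const std::vector<int32_t> &types,
                             const std::vector<float> &word, const std::vector<float> &pos,
                             const std::vector<float> &type, int seq, int hidden) {
  assert(ids.size() == types.size() && seq > 0 && ids.size() % seq == 0);
  std::vector<float> out(ids.size() * hidden);
  for (size_t idx = 0; idx < out.size(); ++idx) {
    size_t row = idx / hidden;   // which token in the flattened batch
    int h = int(idx % hidden);   // which channel
    int p = int(row % seq);      // position restarts for every sentence
    out[idx] = word[ids[row] * hidden + h] + pos[p * hidden + h] +
               type[types[row] * hidden + h];
  }
  return out;
}
```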

Attention, without the bookkeeping

Attention is where the layout work pays off. The textbook implementation projects the input three times (query, key, value), splits each into 12 heads, runs scaled dot-product attention per head, then merges the heads back. Those split and merge steps are pure data shuffling, and I removed them entirely.

First, the three projections become one. I concatenate the Q, K, and V weight matrices at load time, so a single matmul produces all of QKV stacked along the column dimension, shape (batch * seq, 3 * HIDDEN).

// one projection instead of three: qkv is (batch*seq, 3*HIDDEN)
linear(hidden, w.qkv_w, w.qkv_b, ws.qkv, batch * seq, 3 * HIDDEN, HIDDEN);

Then the score kernel reads each head’s query and key straight out of that buffer with a strided tensor-core load, no copy into a per-head buffer. The row stride is 3 * HIDDEN and a column offset picks the Q slice or the K slice.

// block z is the global head; b is the batch index, hl the head within it
int b = z / NUM_HEADS, hl = z % NUM_HEADS;
// point at this head's Q inside the fused buffer: column q_off, head hl
const __half *q = qkv + (b * seq) * qkv_stride + q_off + hl * HEAD_DIM;
// load a 16x16 Q tile with the fused row stride
wmma::load_matrix_sync(af, q + row * qkv_stride + d0, qkv_stride);
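The offset arithmetic is what replaces the split kernel, so it is worth checking in isolation. A sketch, assuming Q, K, and V sit at column offsets 0, HIDDEN, and 2 * HIDDEN of each fused row (the concatenation order is my assumption; the snippet above only calls the offset `q_off`). Element d of head h for token t of sentence b is read in place, and the per-head copy a split kernel would have made never exists. `Part` and `qkv_index` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

constexpr int HIDDEN = 384, NUM_HEADS = 12, HEAD_DIM = 32;
constexpr int QKV_STRIDE = 3 * HIDDEN;  // row stride of the fused buffer

enum class Part { Q = 0, K = 1, V = 2 };  // assumed concatenation order

// flat index into the (batch * seq, 3 * HIDDEN) buffer for one head element
size_t qkv_index(Part part, int b, int t, int head, int d, int seq) {
  assert(head >= 0 && head < NUM_HEADS && d >= 0 && d < HEAD_DIM);
  const size_t row = size_t(b) * seq + t;                    // token row in the batch
  const int col = int(part) * HIDDEN + head * HEAD_DIM + d;  // slice + head + dim
  return row * QKV_STRIDE + col;
}
```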

Scores come out as (batch * heads, seq, seq). The softmax is one warp per row using shuffle reductions, and it folds the mask in directly: a padding key never enters the max or the sum, so it contributes zero weight without a separate masking pass.

// each warp owns one row of the score matrix
float m = -1e30f;
for (int j = lane; j < seq; j += 32)
  // skip padding keys entirely
  if (mask[j])
    m = fmaxf(m, __half2float(row[j]));
// reduce the max across the warp with shuffles, no shared memory
m = warp_max(m);
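The rest of the row follows the same pattern: the sum of exp(x - max) also skips padding keys, and padding keys end up with zero weight. A host sketch of that masked softmax (mine, not the kernel, and single-threaded rather than warp-parallel) shows why no separate masking pass is needed. It also guards the degenerate row where every key is padding, which would otherwise divide by zero; whether the kernel ever sees such a row is not something I’m assuming.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// softmax over one score row, ignoring keys whose mask is 0
std::vector<float> masked_softmax(const std::vector<float> &scores,
                                  const std::vector<int32_t> &mask) {
  assert(scores.size() == mask.size());
  float m = -1e30f;
  for (size_t j = 0; j < scores.size(); ++j)
    if (mask[j]) m = std::fmax(m, scores[j]);  // max over real keys only
  float sum = 0.0f;
  std::vector<float> p(scores.size(), 0.0f);   // padding keys stay at 0
  for (size_t j = 0; j < scores.size(); ++j)
    if (mask[j]) {
      p[j] = std::exp(scores[j] - m);
      sum += p[j];
    }
  if (sum > 0.0f)                              // all-padding row: leave zeros
    for (float &x : p) x /= sum;
  return p;
}
```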

The context step multiplies those weights by V and writes the result straight into the merged (batch * seq, HIDDEN) layout that the output projection expects. No merge kernel. With the projection fused, the split and merge gone, and the scores held in shared memory and registers no longer than they must be, the attention block is four launches instead of nine.

Layernorm and pooling

Layernorm is a per-row reduction, and like the softmax it runs one warp per row with shuffles, no shared memory and no block barrier. It also folds in the residual add for free.

// mean and variance over the row, computed with warp-shuffle reductions
float mean = warp_sum(sum) / float(D);
float rstd = rsqrtf(warp_sum(sq) / float(D) + eps);
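A host reference for the fused residual-plus-layernorm, the kind of thing a parity test compares against. It is a sketch: float instead of fp16, one row at a time, two passes (mean first, then squared deviations), and `add_layernorm_ref` is a hypothetical name. The snippet above doesn’t show how `sq` is accumulated, so don’t read this as the kernel’s exact arithmetic.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// out = layernorm(x + residual) * gamma + beta, for one row of length D.
// The residual add happens while the row is loaded, not in a separate pass.
std::vector<float> add_layernorm_ref(const std::vector<float> &x, const std::vector<float> &residual,
                                     const std::vector<float> &gamma, const std::vector<float> &beta,
                                     float eps) {
  const size_t D = x.size();
  assert(D > 0 && residual.size() == D && gamma.size() == D && beta.size() == D);
  std::vector<float> v(D);
  float sum = 0.0f;
  for (size_t i = 0; i < D; ++i) {
    v[i] = x[i] + residual[i];  // fused residual add
    sum += v[i];
  }
  const float mean = sum / float(D);
  float sq = 0.0f;              // sum of squared deviations
  for (size_t i = 0; i < D; ++i) sq += (v[i] - mean) * (v[i] - mean);
  const float rstd = 1.0f / std::sqrt(sq / float(D) + eps);
  std::vector<float> out(D);
  for (size_t i = 0; i < D; ++i) out[i] = (v[i] - mean) * rstd * gamma[i] + beta[i];
  return out;
}
```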

Pooling is the mean of the unmasked token vectors, which is the sentence embedding before normalization.

// average only the real tokens; padding rows have mask 0
float denom = count > 0 ? float(count) : 1.0f;
out[d] = __float2half(acc / denom);
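Pooling and the L2 normalization from `bert_embed`, together as a host sketch: average the unmasked token rows, then scale to unit length so cosine similarity becomes a plain dot product. `pool_and_normalize` is a hypothetical name, and returning a zero vector for a zero-norm row is my choice, not necessarily what `launch_l2_normalize` does.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// mean of the unmasked token vectors of one sentence, then L2-normalize.
// tokens is (seq, hidden) row-major; mask[t] == 0 marks padding.
std::vector<float> pool_and_normalize(const std::vector<float> &tokens,
                                      const std::vector<int32_t> &mask, int hidden) {
  const size_t seq = mask.size();
  assert(hidden > 0 && tokens.size() == seq * hidden);
  std::vector<float> out(hidden, 0.0f);
  int count = 0;
  for (size_t t = 0; t < seq; ++t) {
    if (!mask[t]) continue;  // padding never enters the mean
    ++count;
    for (int d = 0; d < hidden; ++d) out[d] += tokens[t * hidden + d];
  }
  const float denom = count > 0 ? float(count) : 1.0f;
  float norm2 = 0.0f;
  for (float &x : out) {
    x /= denom;
    norm2 += x * x;
  }
  const float inv = norm2 > 0.0f ? 1.0f / std::sqrt(norm2) : 0.0f;
  for (float &x : out) x *= inv;  // unit length: cosine == dot product
  return out;
}
```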

The encoder layer, assembled

With those pieces, a layer is just two blocks, and the whole forward is six of them.

void encoder_layer(Workspace &ws, const __half *hidden, const LayerWeights &w,
                   const int32_t *mask, __half *out, int batch, int seq) {
  // self-attention sublayer, writes into ws.attn_out
  attention_block(ws, hidden, w.attn, mask, ws.attn_out, batch, seq);
  // feed-forward sublayer, writes the layer output
  ffn_block(ws, ws.attn_out, w.ffn, out, batch, seq);
}

Where it ended up

Everything below is measured. I kept a parity test against the PyTorch reference for every kernel, so each speedup is a real change with the embedding still matching at cosine 0.9999.

Batched throughput, 64 sentences at a time, same input:

| sequence length | mine (ms/sentence) | PyTorch (ms/sentence) |
|---|---|---|
| 128 (padded) | 0.090 | 0.092 |
| 16 (the real length) | 0.015 | 0.025 |

A tie at full length, about 1.7x faster at the real length of the test sentence. The climb that got there, each step a single commit, measured at batch 64:

| step | ms/sentence | the idea |
|---|---|---|
| batched baseline (simple wmma) | 0.225 | one warp per 16x16 tile |
| cp.async pipelined matmul | 0.140 | overlap load with math |
| fuse QKV, drop split/merge | 0.118 | one projection, read it directly |
| warp-per-row layernorm | 0.109 | shuffle reduction |
| fuse gelu into the matmul | 0.102 | one less kernel per layer |
| register-block attention scores | 0.101 | bigger warp tile |
| fp16 accumulate | 0.091 | 2x tensor rate, parity |
| skip the padding | 0.015 | process the real length |

The trick that closed the gap with cuBLAS

For a long time my GEMM sat about 13% behind, and the thing that closed the gap is specific to consumer GPUs. A GeForce card runs fp16 * fp16 -> fp16 tensor operations at twice the rate of fp16 * fp16 -> fp32. It is a deliberate cap to protect the datacenter parts. cuBLAS sensibly accumulates in fp32 by default, because for a general library that is the safe choice. But this is a 6-layer model whose output I compare by cosine, and it tolerates fp16 accumulation fine: end-to-end cosine stayed at 0.9999. The change is one line, the accumulator fragment in the snippet above going from float to __half.

That is the lever. I did not out-engineer cuBLAS. I made a precision tradeoff it declines to make by default, the tradeoff happens to be free for this task, and it bought back exactly the rate PyTorch was leaving on the table.
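Why would a general library refuse this by default? A crude host emulation makes the risk concrete. `round_to_fp16` is a hypothetical helper that rounds a value to fp16’s 11 significant bits (ignoring overflow past 65504 and subnormals), and the accumulator is rounded after every single addition. That is stricter than a tensor core, whose rounding inside one MMA I am not modelling. Summing 4096 products of 1.0 stalls at 2048, because above 2048 the spacing between fp16 values is 2 and adding 1 rounds back down. A model whose sums stay well away from that regime, as this one presumably does given the 0.9999 cosine, never notices; a library can’t assume that about its callers.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Round x to the nearest value with an 11-bit significand (fp16 precision).
// Simplified: ignores fp16's exponent range (overflow above 65504, subnormals).
double round_to_fp16(double x) {
  if (x == 0.0) return 0.0;
  int e;
  (void)std::frexp(x, &e);                       // x = m * 2^e, 0.5 <= |m| < 1
  const double scale = std::ldexp(1.0, 11 - e);  // 2^(11 - e)
  return std::nearbyint(x * scale) / scale;      // ties to even in the default mode
}

// dot product with the running sum rounded to fp16 after every add
double dot_fp16_accumulate(const std::vector<double> &a, const std::vector<double> &b) {
  assert(a.size() == b.size());
  double acc = 0.0;
  for (size_t i = 0; i < a.size(); ++i)
    acc = round_to_fp16(acc + round_to_fp16(a[i] * b[i]));
  return acc;
}

// the same dot product with a wide (here: double) accumulator
double dot_wide_accumulate(const std::vector<double> &a, const std::vector<double> &b) {
  assert(a.size() == b.size());
  double acc = 0.0;
  for (size_t i = 0; i < a.size(); ++i) acc += a[i] * b[i];
  return acc;
}
```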

The things that did not work

This is the part I would have wanted to read before starting, because I burned the most time here.

Hand-blocked shared-memory GEMM. Twice. The textbook move after the simple tensor-core kernel is to stage big tiles of both inputs into shared memory for the whole block to reuse. Both times it was slower than the simple one-warp-per-tile version. The reason is L2: the simple kernel’s “redundant” global reads are mostly cache hits, so it already has the reuse, and explicit staging just adds barriers and cuts occupancy.

Deep software pipelines. cuBLAS uses 4- and 5-stage pipelines. My correct 3-stage version was slower than the 2-stage double buffer on every shape. My K is small, 384 or 1536, so two buffers already hide the latency, and more stages spend shared memory and occupancy for nothing. The deep pipelines are tuned for the general case, not for short, fat matmuls.

Fused flash attention. I wrote the obvious version: one block per head, keep the whole score matrix in shared memory, do scores then softmax then the value product in one kernel so the scores never touch global memory. Slower. The score matrix is 32 KB of shared memory, which drops occupancy to one block per SM, and the three serialized phases inside the block lose more than the global traffic they save, which L2 was caching anyway. The flash kernels that win, like the one PyTorch calls, block over the keys and run an online softmax specifically to keep occupancy up.
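The online softmax those kernels rely on fits in a few lines. A sketch for one query row that processes keys one at a time (the real kernels apply the same update per tile of keys, and I’m not reproducing their blocking): keep a running max `m`, a running denominator `l`, and a running weighted sum of V, and when a larger score shows up, rescale the old sum and denominator by exp(m_old - m_new). Nothing of size seq x seq is ever stored, which is what lets such kernels keep shared memory small and occupancy up. The function name is mine.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One query row of attention, streaming over keys in a single pass.
// scores[j] is the (already scaled) q.k_j; values is (n_keys, dim) row-major.
std::vector<float> online_softmax_weighted_sum(const std::vector<float> &scores,
                                               const std::vector<float> &values, int dim) {
  assert(!scores.empty() && dim > 0 && values.size() == scores.size() * dim);
  float m = -INFINITY, l = 0.0f;  // running max and running denominator
  std::vector<float> acc(dim, 0.0f);
  for (size_t j = 0; j < scores.size(); ++j) {
    const float m_new = std::fmax(m, scores[j]);
    const float correction = std::exp(m - m_new);  // rescales old terms; exp(-inf) = 0
    const float w = std::exp(scores[j] - m_new);
    l = l * correction + w;
    for (int d = 0; d < dim; ++d)
      acc[d] = acc[d] * correction + w * values[j * dim + d];
    m = m_new;
  }
  for (float &x : acc) x /= l;  // normalize once at the end
  return acc;
}
```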

The pattern across all three: my intuition said “less memory traffic is better,” and on these small, fixed shapes the real constraint was occupancy and the cache, not bandwidth. Profiling told me this every time and I kept needing to relearn it.

The asterisk, and why short sentences are honest

The 1.7x at the real length needs care. The test sentence is 9 tokens, and the seq=128 benchmark forces both implementations to pad to 128 and process 119 tokens of nothing. Those padding tokens get masked out in attention and averaged away in pooling, so the embedding is identical whether you run 16 tokens or 128. Skipping them is correct, and any real inference engine does it. sentence-transformers itself pads to the longest item in a batch, not to 128. I made seq a runtime argument so the forward can process the real length, rounded up to the 16 the tensor cores want.

// the sentence is 9 real tokens; round to the wmma tile and skip the rest
int eff = ((n_real + 15) / 16) * 16;
bert_embed(ws, w, ids, types, mask, emb, batch, eff);
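For a whole batch, the effective length is the longest sentence rounded up to a tile. A sketch, assuming the tokenizer has already truncated anything past SEQ_LEN; `effective_len` and `score_entries` are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

constexpr int SEQ_LEN = 128;  // compile-time maximum
constexpr int TILE = 16;      // wmma tile edge

// the seq to pass for a batch: its longest real sentence, rounded up to a
// whole tile (inputs are assumed already truncated to SEQ_LEN)
int effective_len(const std::vector<int> &real_lengths) {
  assert(!real_lengths.empty());
  const int longest = *std::max_element(real_lengths.begin(), real_lengths.end());
  assert(longest > 0 && longest <= SEQ_LEN);
  return (longest + TILE - 1) / TILE * TILE;
}

// entries per head's score matrix grow with seq^2; linear layers grow with seq
long long score_entries(int seq) { return 1LL * seq * seq; }
```

Going from 128 to 16 cuts the rows every linear layer processes by 8x and the entries of each score matrix by 64x, which is why the short-sentence numbers move so much for both engines.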

So I built a second benchmark that runs PyTorch at every length, and PyTorch gets the same speedup from processing fewer tokens. The length win is a “do less work” win, and it helps both engines. It is not me beating cuBLAS.

But there is a real result hiding in it. At seq=16, my forward is 0.015 ms and PyTorch is 0.025. At that size the arithmetic is tiny, so the forward is overhead-bound, kernel launches and dispatch and fixed per-call costs, and there I am genuinely leaner. I fuse gelu, bias, and the residual into other kernels, while PyTorch launches each separately and pays Python dispatch on every op. Fewer kernels, no framework tax, fixed shapes, one preallocated workspace. When the math gets small enough that overhead is the bill, hand-written wins. So it may be a bit gimmicky, but I still win! First place is first place baby, pick a top tier!

One tiled sentence is a weak benchmark, so I ran 2000 real sentences from the STS benchmark, length-sorted with each batch padded to its longest, exactly what sentence-transformers does. Forward only, tokenization outside the timed region on both sides. My engine does 54,000 sentences per second, PyTorch does 39,000, and the embeddings match at cosine 1.00000. The sentences average 15 tokens, so this is the overhead-bound regime, and the leaner pipeline wins by the margin you would expect. Padding the whole corpus to 128 instead drops me to 11,000, a clean measurement of the cost of the padding the effective-length path skips.
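The corpus numbers come down to how many token rows each strategy pushes through the encoder. A sketch of length-sorted batching that counts processed tokens; it is a counting model, not a timing model, and says nothing about the fixed per-batch overhead that dominates at short lengths. `padded_tokens` is a hypothetical helper, and the plain ascending sort may not match sentence-transformers’ exact ordering.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// total tokens processed when sentences are sorted by length, chunked into
// batches, and each batch padded to its longest member rounded up to round_to
// (round_to = 1 pads to the longest exactly)
long long padded_tokens(std::vector<int> lengths, int batch_size, int round_to) {
  assert(batch_size > 0 && round_to > 0);
  std::sort(lengths.begin(), lengths.end());  // neighbours have similar lengths
  long long total = 0;
  for (size_t start = 0; start < lengths.size(); start += batch_size) {
    const size_t end = std::min(lengths.size(), start + size_t(batch_size));
    const int longest = lengths[end - 1];     // sorted, so the last is longest
    const int padded = (longest + round_to - 1) / round_to * round_to;
    total += 1LL * padded * int(end - start);
  }
  return total;
}
```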

That is the actual conclusion, and it is more specific than “I beat PyTorch.” At full length I reach parity with cuBLAS by spending a precision tradeoff it declines to make. At short length I pull ahead on plumbing, because a hand-written pipeline has less of it. The GEMM itself, at equal precision, I could not beat, and the failed experiments are the evidence for why: that gap is the multi-stage, autotuned, swizzled machinery that is the entire point of CUTLASS, and reproducing it by hand is the project, not a tweak.

What I really learned

I don’t love this kind of work. It’s alright; I could see doing it for a while without losing my mind, and it’s certainly more interesting than some things I’ve worked on at the day job over my career. CUDA in particular feels pretty crusty. I’m way more interested in things like Mojo, my own project cljrs, or Parrot (which I’ve actually landed some contributions to recently). I’ve been a programming language nerd for a long time, and CUDA is nowhere near as expressive as what I’d want day to day. At the same time, I do want full control over performance, with nothing left on the table.

There just has to be something better than this. I lean towards Mojo, and one of my goals is to rewrite this project in Mojo soon. Before that, though, I’ll probably throw Claude at porting it to Rust, Swift, or some other language with CUDA support, to see how the performance of a naive translation compares. My biggest fear is that I just burned a month digging into CUDA and it will all leave my brain just as quickly as it arrived. I’ve always been maximally “use it or lose it”, and I don’t see myself obsessively writing CUDA non-stop in the near future. I already have to triple-buffer internally just to remember wtf fmaxf is, and I wrote my last commit a few hours ago. Six months from now I’ll have to google what GEMM even means (I’m exaggerating, but this feels somewhat true).

Was it worth it? Maybe. I think it’s a neat project, and if nothing else I proved I could get it done if I had to. If I had to do this for work, though, I would probably just slam it with agents until the profiling step, and I’d be using libraries anyway, which would make this a ton easier. It was an interesting foray into optimizing for a specific card. In particular, I feel like I could write a forward pass for Qwen-Coder models if I ever needed to squeeze out maximum throughput. In general, though, mature frameworks like PyTorch or JAX can’t take most of the approaches I’ve used here. They JIT compile their code, and I assume that means they take something about the system into account. It’s easy to show that for one particular inference pass on one card you can do better (I just did that), but if you need to support many different pieces of hardware, that isn’t an option.

I also fully expect someone to put up a PR at some point showing that with 3 lines of PyTorch code you can actually get another 100x speedup. So yeah, the time spent was maybe a bit dubious.

The code is on GitHub.