This benchmark evaluates whether frontier LLMs can implement high-performance GPU kernels from a mathematical description. No hand-holding, no fill-in-the-blank templates, no pytest. Just: here's the math, here's a GPU, make it work and make it fast.
This requires materializing an (N × N) attention matrix, where N is the sequence
length. At N=16384 with 16 heads in bfloat16, that matrix alone is **8 GB** —
making it impractical for long sequences.

Your job: **implement attention without materializing the full N×N matrix, then optimize it to the absolute limit of the hardware.**

You have an **NVIDIA H100 GPU**. You will write code, test it on
real hardware, and optimize relentlessly until you run out of time.

## Interface

Edit **`/app/repo/solution.py`**. Implement a single function:

```
def efficient_attention(Q, K, V, is_causal=False) -> O
```

- **Q**: `(B, Nq, D)` — queries, float32 or bfloat16, CUDA
- **K**: `(B, Nk, D)` — keys
- **V**: `(B, Nk, D)` — values
- **is_causal**: bool — when True, mask positions where query_index < key_index
- **O**: `(B, Nq, D)` — output, same dtype as Q

Must support `torch.autograd` (backward pass must produce correct dQ, dK, dV).
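As a small-N correctness oracle, a naive reference that *does* materialize the score matrix can be sketched as below. This is an illustrative NumPy sketch only (the actual verifier reference and your solution operate on CUDA torch tensors); it follows the causal-mask convention above of masking positions where query_index < key_index.

```python
import numpy as np

def naive_attention(Q, K, V, is_causal=False):
    # Q: (B, Nq, D), K/V: (B, Nk, D). Materializes the full (Nq, Nk)
    # score matrix, so it is only usable at small N as a reference.
    scale = 1.0 / np.sqrt(Q.shape[-1])
    S = np.einsum('bqd,bkd->bqk', Q, K) * scale
    if is_causal:
        q_idx = np.arange(Q.shape[1])[:, None]
        k_idx = np.arange(K.shape[1])[None, :]
        S = np.where(q_idx < k_idx, -np.inf, S)  # mask future keys
    S = S - S.max(axis=-1, keepdims=True)        # stabilized softmax
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)
    return np.einsum('bqk,bkd->bqd', P, V)
```

With `is_causal=True`, query 0 can attend only to key 0, so `O[:, 0, :]` equals `V[:, 0, :]` exactly; that makes a convenient spot check.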

You are free to use **any approach**: Triton kernels, `torch.compile`, custom
autograd functions, whatever works. You may create additional `.py` files in
`/app/repo/` and import from them.

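The standard way to avoid the N×N matrix is online-softmax tiling, as in FlashAttention: walk over K/V in blocks while maintaining a running row maximum, a running softmax denominator, and an unnormalized output accumulator. A minimal single-head NumPy sketch of the forward pass (illustrative only; the real solution needs a fused Triton kernel plus a backward pass):

```python
import numpy as np

def tiled_attention(Q, K, V, block_n=64):
    # Q: (Nq, D), K/V: (Nk, D). Never forms the full (Nq, Nk) matrix:
    # only one (Nq, block_n) score tile is alive at any time.
    Nq, D = Q.shape
    scale = 1.0 / np.sqrt(D)
    m = np.full((Nq, 1), -np.inf)        # running row max
    l = np.zeros((Nq, 1))                # running softmax denominator
    acc = np.zeros((Nq, D))              # unnormalized output accumulator
    for start in range(0, K.shape[0], block_n):
        Kb, Vb = K[start:start + block_n], V[start:start + block_n]
        s = (Q @ Kb.T) * scale           # (Nq, block_n) tile of scores
        m_new = np.maximum(m, s.max(axis=1, keepdims=True))
        alpha = np.exp(m - m_new)        # rescale old state to the new max
        p = np.exp(s - m_new)
        l = alpha * l + p.sum(axis=1, keepdims=True)
        acc = alpha * acc + p @ Vb
        m = m_new
    return acc / l                       # normalize once at the end
```

The peak live memory is O(Nq · block_n) for the score tile plus O(Nq · D) for the state, independent of Nk.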
## Constraints

1. Must be **memory-efficient**: run at B=16, N=16384, D=64, bf16 on a single H100 without OOM.
2. Must be **correct**: match naive attention output within rtol=1e-2, atol=1e-2.
3. Must support **backward pass**: gradients dQ, dK, dV must be correct.
4. Must support **causal masking**.
5. **You must write your own implementation from scratch.** You may use PyTorch and Triton as building blocks, but the attention algorithm itself must be yours.

## Anti-Cheat Policy (READ THIS)

The verifier performs comprehensive cheat detection. If ANY of the following are
detected, your score is **instantly zero** with no partial credit:

**Banned libraries and functions:**
- `torch.nn.functional.scaled_dot_product_attention` (including via `F.scaled_dot_product_attention`)
- `flash_attn` library (including any submodule)
…kernel launch overhead via persistent kernel (loop over Q tiles in one launch)
should save ~0.5ms from grid scheduling overhead."

#### Step 3: Implement the change

Make ONE change at a time. Not five. ONE.

#### Step 4: Measure and log

```
bash /app/tools.sh bench --quick
bash /app/tools.sh ncu
```

Compare numbers before and after. Was your hypothesis correct? Write it down.
If the change regressed performance, **revert it immediately** (`git checkout -- solution.py`).

#### Repeat until time runs out.

**Optimizations to explore (ordered by typical impact):**

1. **exp2/log2 instead of exp/log** — the hardware `exp2` instruction is roughly 2x faster than `exp` on the H100's special function units. Fold the `1/log(2)` factor into the softmax scale.
2. **Separate fwd/bwd benchmarks** — find which is the bottleneck. Backward is usually 2-3x slower.
3. **Persistent kernels** — instead of 1 tile per program, loop over multiple tiles per SM. Eliminates grid launch overhead and improves SM utilization. This is doable in Triton 3.1.0 with a `while` loop.
4. **Swizzled tile scheduling** — remap `program_id` to a Z-order or swizzled pattern for better L2 cache locality across thread blocks.
5. **Split-K backward** — partition the K dimension across thread blocks for the dQ kernel, then reduce. Helps when N >> BLOCK_N.
6. **Tile size tuning** — MUST be guided by ncu occupancy and shared memory data, not random sweeps.
7. **Warp count tuning** — 4 vs 8 warps changes occupancy and register pressure. ncu tells you which is limiting.
8. **Pipeline stages** — `num_stages` controls software pipelining depth. 1-3 only on Triton 3.1.0 (4 crashes).

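The exp2 rewrite in item 1 rests on the identity exp(x) = exp2(x · log2 e); in a kernel you multiply the softmax scale by the constant once and then call the exp2 intrinsic on the scores. A quick NumPy check of the identity:

```python
import numpy as np

LOG2E = 1.4426950408889634  # log2(e) == 1/ln(2)

# exp(x) == exp2(x * log2(e)); in a Triton kernel the LOG2E factor is
# folded into the softmax scale so the hot loop only calls exp2.
x = np.linspace(-20.0, 20.0, 401)
assert np.allclose(np.exp(x), np.exp2(x * LOG2E))
```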
**DO NOT do random parameter sweeps.** Every config change must be motivated
by profiling data and roofline math. If you can't explain WHY a particular
block size should be better based on arithmetic intensity, kernel timing, and
H100 peak throughput, don't try it.

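For the swizzled tile scheduling item above, the grouped remapping from the Triton matmul tutorial is one concrete pattern: consecutive program ids are packed into groups of `group_m` tile rows, so blocks launched close together reuse the same K/V tiles in L2. A pure-Python sketch of the index math (function and parameter names are illustrative):

```python
def swizzle(pid, grid_m, grid_n, group_m):
    """Remap a linear program id to (pid_m, pid_n) in grouped order.

    Consecutive pids walk down a group of `group_m` tile rows before
    moving to the next column, improving L2 reuse between neighboring
    thread blocks.
    """
    width = group_m * grid_n                          # programs per full group
    group_id = pid // width
    first_pid_m = group_id * group_m
    group_size = min(grid_m - first_pid_m, group_m)   # last group may be short
    pid_m = first_pid_m + (pid % width) % group_size
    pid_n = (pid % width) // group_size
    return pid_m, pid_n

# Sanity check: every tile in a 3x5 grid is visited exactly once.
tiles = [swizzle(p, 3, 5, 2) for p in range(3 * 5)]
assert len(set(tiles)) == 15
```

Inside a Triton kernel the same arithmetic runs on `tl.program_id(0)` at the top of the kernel.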
### Phase 3: Submit (do this early, then keep optimizing)

Submit your current best **as soon as correctness is established**, then keep
optimizing and re-submitting. Do not wait until the end.
| Check | Weight | Criterion |
|---|---|---|
| backward correctness | 0.10 | gradients match reference |
| causal masking | 0.05 | correct with is_causal=True |
| memory efficiency | 0.05 | runs at N=16384 without OOM, doesn't materialize N×N |
| timing | 0.70 | fwd+bwd latency — **lower is better, no cap** |

Timing score = {{reference_timing_ms}} / your_ms. **There is no ceiling — faster is always better.**
Correctness is pass/fail. Timing requires all correctness checks to pass.

**Correctness alone is worth almost nothing. Performance is everything.**
A correct implementation at 80ms is a failure. A correct implementation at 5ms is exceptional.

## Tips

- **Run `triton_smoke.py` first.** Verify Triton works before writing flash attention.
- **Profile before optimizing.** `bash /app/tools.sh profile` shows per-kernel CUDA time. Compute roofline numbers. Don't guess.
- **ncu may be available** depending on your environment. Run `bash /app/tools.sh ncu` — it auto-detects access (tries sudo if needed) and falls back to torch.profiler if ncu can't run. On bare-metal servers ncu gives hardware counters (occupancy, throughput, stall reasons). On cloud sandboxes use torch.profiler + manual roofline math.
- **Start simple.** Get a correct implementation first, then immediately start profiling.
- **Read the papers.** The references in `reference.md` describe efficient approaches. The FlashAttention-2 paper has the key algorithmic insights.
- **Use float32 accumulators.** Accumulating in lower precision will cause precision errors in softmax and dot products.
- **Separate fwd and bwd optimization.** Use `--forward-only` and `--backward-only` to find which is the bottleneck and focus there.
- **The backward pass is typically 2-3x slower than forward.** If your backward is >3x your forward, there's room to improve it.
- **Think about data reuse.** Each byte loaded from HBM should do as much compute as possible. Larger tiles = more reuse = closer to compute-bound.

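A back-of-envelope roofline for the target shape shows why this kernel should end up compute-bound, and roughly what "the limit of the hardware" means in milliseconds. The peak figures below are approximate published H100 SXM specs, not measured values:

```python
# Roofline estimate for B=16, N=16384, D=64 in bf16 (forward pass only).
B, N, D, BYTES = 16, 16384, 64, 2
flops = 4 * B * N * N * D            # 2*B*N^2*D for Q@K^T + 2*B*N^2*D for P@V
ideal_io = 4 * B * N * D * BYTES     # ideal: read Q, K, V once, write O once
ai = flops / ideal_io                # arithmetic intensity = N/2 FLOP/byte

PEAK_FLOPS = 990e12                  # approx. H100 SXM dense bf16 throughput
PEAK_BW = 3.35e12                    # approx. HBM3 bandwidth, bytes/s
ridge = PEAK_FLOPS / PEAK_BW         # ~295 FLOP/byte ridge point

assert ai > ridge                    # well past the ridge: compute-bound
print(f"AI = {ai:.0f} FLOP/B, ideal fwd time = {flops / PEAK_FLOPS * 1e3:.2f} ms")
```

With perfect tiling the forward pass sits around 8192 FLOP/byte, far past the ~295 FLOP/byte ridge, so tuning should chase tensor-core utilization rather than bandwidth; at peak compute the forward pass alone is on the order of a millisecond for this shape.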
## Known Triton 3.1.0 Limitations on Hopper (H100)

- **`num_stages=4` crashes the compiler** ("SharedEncodingAttr builder when MMAEncodingAttr is Hopper has not been implemented yet"). Use `num_stages <= 3`.
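One way to respect this limit while autotuning is to cap the stage axis of the config sweep. A hypothetical config list (block sizes and warp counts here are placeholders; pick values your ncu data justifies):

```python
import triton

# Cap num_stages at 3: num_stages=4 trips the Hopper compiler bug above.
CONFIGS = [
    triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_warps=w, num_stages=s)
    for bm, bn in [(128, 64), (64, 64)]
    for w in (4, 8)
    for s in (2, 3)
]
```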