Drawing the flash attention animation
Some time ago, I was studying how Flash Attention works.
The main visual material available was the pyramid visualization from the Flash Attention paper.
I wanted an animated version of it, but I could not find a good one out there.[1] So I made my own.
I made the animation manually in macOS Keynote.[2] As of this writing, it appears to rank second in Google search results.
I also wrote a Quora answer explaining how Flash Attention works.
I asked in the GPU Mode Discord for opinions on my work.
gau.nernst replied with the following comments, which I greatly appreciate.
i wanted to comment that the loop order was reversed. but upon checking, turns out FA1 used this loop ordering, but FA2 reversed it (and I only read the FA2 paper lmao)
so in FA2, iterating along K/V is the inner loop, iterating along Q/O is the outer loop, which is implemented as 1 threadblock handling 1 Q/O tile
yea I think FA3 and FA4 also follow the FA2’s general design, but optimized for Hopper and Blackwell respectively
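To make the loop-order difference from that comment concrete, here is a minimal pure-Python sketch. It is not the actual kernel code. It assumes a single head, no masking, and no dropout, and the function names (naive_attention, flash_fa1_order, flash_fa2_order, _online_update) are hypothetical. Both tiled versions fold one K/V tile at a time into a running softmax state (running max m, running sum l, unnormalized output acc) and divide by l only at the end. That deferred normalization is a simplification: FA1 as published rescales O on every step. The two functions differ only in which loop is outer. In the FA1 order, the K/V tiles form the outer loop, so the running state of every query row has to persist across outer iterations (in FA1 it lives in HBM). In the FA2 order, each Q/O tile is finished completely before moving on, which is why one threadblock can own one Q/O tile.

```python
import math


def naive_attention(Q, K, V):
    """Reference: softmax(Q K^T / sqrt(d)) V, one query row at a time."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / l
                    for c in range(len(V[0]))])
    return out


def _online_update(q, K_tile, V_tile, m, l, acc):
    """Fold one K/V tile into the running softmax state of one query row."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K_tile]
    m_new = max(m, max(s))
    # Rescale earlier partial results to the new max; exp(-inf) == 0 on the first tile.
    correction = math.exp(m - m_new)
    p = [math.exp(x - m_new) for x in s]
    l_new = l * correction + sum(p)
    acc_new = [a * correction + sum(pj * v[c] for pj, v in zip(p, V_tile))
               for c, a in enumerate(acc)]
    return m_new, l_new, acc_new


def flash_fa1_order(Q, K, V, block=2):
    """FA1 loop order: outer loop over K/V tiles, inner loop over Q/O tiles."""
    n_q, d_v = len(Q), len(V[0])
    # Running state for every query row must survive the whole outer loop.
    m = [-math.inf] * n_q
    l = [0.0] * n_q
    acc = [[0.0] * d_v for _ in range(n_q)]
    for k0 in range(0, len(K), block):          # outer: K/V tiles
        K_tile, V_tile = K[k0:k0 + block], V[k0:k0 + block]
        for q0 in range(0, n_q, block):         # inner: Q/O tiles
            for i in range(q0, min(q0 + block, n_q)):
                m[i], l[i], acc[i] = _online_update(Q[i], K_tile, V_tile,
                                                    m[i], l[i], acc[i])
    return [[a / l[i] for a in acc[i]] for i in range(n_q)]


def flash_fa2_order(Q, K, V, block=2):
    """FA2 loop order: outer loop over Q/O tiles, inner loop over K/V tiles."""
    n_q, d_v = len(Q), len(V[0])
    out = [None] * n_q
    for q0 in range(0, n_q, block):             # outer: Q/O tiles (one threadblock each)
        for i in range(q0, min(q0 + block, n_q)):
            # State only needs to live while this Q/O tile is being processed.
            m, l, acc = -math.inf, 0.0, [0.0] * d_v
            for k0 in range(0, len(K), block):  # inner: K/V tiles
                m, l, acc = _online_update(Q[i], K[k0:k0 + block],
                                           V[k0:k0 + block], m, l, acc)
            out[i] = [a / l for a in acc]
    return out
```

Because both orders apply exactly the same per-row updates in the same K/V order, they produce the same result as naive attention (up to floating-point rounding). The choice between them is about parallelism and memory traffic, not correctness.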
The explanation may be incomplete and some details may not be fully correct, but I still hope this makes Flash Attention slightly easier for you to understand.