MLP: Attention in a Trench Coat
March 26, 2025


Jamie Dborin

In the realm of natural language processing, transformer models have revolutionized how we approach language tasks. At the heart of the transformer architecture lie two fundamental operations: attention and multilayer perceptron (MLP). While attention has received significant optimization efforts, the MLP has often been overlooked or considered as something fundamentally different. However, a closer examination reveals striking similarities between these two operations. In this article, I argue that MLPs can be interpreted as a specialized form of attention—essentially "attention in a trench coat"—and leverage this insight to develop more efficient kernels for MLP operations.

The Attention Mechanism: A Quick Refresher

Attention is the mechanism that allows transformers to focus on different parts of the input sequence when producing each output token. In its core implementation, attention involves computing similarities between queries (Q) and keys (K) to create an attention matrix, applying a softmax, and then using these weights to blend values (V).

A schematic representation of the attention mechanism in transformers. The Q and K matrices are multiplied to create an attention matrix, which is then multiplied with V. The matrices are aligned such that each output tile is computed from the dot product of the vectors above it and to its left.
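To make these three steps concrete, here is a minimal pure-Python sketch of single-head attention that materializes the full score matrix. It assumes plain lists of rows for Q, K and V and adds the standard 1/sqrt(d) scaling of scaled dot-product attention; the helper names are illustrative rather than taken from any library.

```python
import math


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both stored as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    peak = max(row)                      # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def naive_attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V, materializing the full N x N score matrix."""
    d = len(q[0])
    scores = matmul(q, transpose(k))     # N x N attention matrix
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, v)            # blend the values with the weights
```

With two identical keys, the query gives both values equal weight, so `naive_attention([[1.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]])` returns `[[2.0, 3.0]]`.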


The Flash Attention kernel, released by Tri Dao in 2022, revolutionized this process by avoiding the explicit creation of the full N×N attention matrix in GPU global memory. This optimization unlocked training and inference with much longer context lengths by removing the need to fully materialize the O(N²) attention matrix.
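The trick that lets Flash Attention skip the N×N matrix is the online softmax: keys and values are visited block by block, and a running maximum and running denominator rescale the partial output as new blocks arrive. The sketch below shows that idea for a single query in pure Python; `block_size` is a hypothetical parameter, and the real kernel processes tiles of queries in GPU shared memory rather than Python lists.

```python
import math


def streaming_attention_row(q, k, v, block_size=2):
    """Attention output for one query, visiting keys/values in blocks.

    Only a running max (m), a running softmax denominator (l) and an
    unnormalized output accumulator (acc) are kept, never a full row of scores.
    """
    scale = 1.0 / math.sqrt(len(q))
    m = float("-inf")
    l = 0.0
    acc = [0.0] * len(v[0])
    for start in range(0, len(k), block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, key)) for key in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)   # rescales earlier partial results (0.0 on the first block)
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pj * vj[c] for pj, vj in zip(p, v_blk))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]            # normalize once at the end
```

Whatever `block_size` is chosen, the result matches a direct softmax over all the scores up to floating-point rounding, which is what allows the kernel to tile the computation freely.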

The MLP Module

The MLP module in transformers has traditionally been viewed as an operation that acts on each token independently, in contrast to attention, which operates across tokens. A typical MLP consists of three steps:

  1. Projecting input vectors to a larger dimension
  2. Applying a non-linearity
  3. Projecting back to the original dimension

In older models such as the original Transformer, this used the ReLU non-linearity (BERT and GPT-2 use the closely related GELU):

A schematic representation of a ReLU-based MLP module used in transformer models. The up_proj tensor projects the input to a larger vector space, before performing a non-linearity and projecting it back down using the down_proj tensor.
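The three steps of a ReLU MLP can be sketched for a single token vector as follows. The sketch stores each weight as a (input dim × output dim) list of rows so that x @ W reads naturally; note that PyTorch's `nn.Linear` stores its weight transposed, as (out_features, in_features).

```python
def matvec(x, w):
    """Row vector x (length d_in) times matrix w (d_in x d_out) -> length d_out."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def relu_mlp(x, up_proj, down_proj):
    h = matvec(x, up_proj)               # 1. project up: d_model -> d_ff
    h = [max(0.0, v) for v in h]         # 2. element-wise ReLU
    return matvec(h, down_proj)          # 3. project back down: d_ff -> d_model
```

For example, `relu_mlp([1.0, -1.0], [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])` returns `[1.0, 2.0]`: only the first intermediate unit survives the ReLU, so only the first row of `down_proj` contributes.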


Modern models like Llama and Qwen use the more sophisticated SwiGLU activation, which incorporates a data-dependent gating mechanism:

A schematic representation of a SwiGLU-based MLP module used in transformer models. Here the non-linearity is data-dependent: it is gated by the activations produced by a third tensor, gate_proj, which acts on the inputs in parallel with the up_proj tensor.
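Using the same conventions, a minimal SwiGLU sketch computes the gate and up projections from the same input, gates with SiLU (v · sigmoid(v)), and projects back down. The names `gate_proj`, `up_proj` and `down_proj` mirror the caption; the (input dim × output dim) weight layout is again an assumption of the sketch.

```python
import math


def matvec(x, w):
    """Row vector x (length d_in) times matrix w (d_in x d_out) -> length d_out."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def silu(v):
    return v / (1.0 + math.exp(-v))      # v * sigmoid(v)


def swiglu_mlp(x, gate_proj, up_proj, down_proj):
    gate = matvec(x, gate_proj)          # gate activations
    up = matvec(x, up_proj)              # up activations, from the same input x
    h = [silu(g) * u for g, u in zip(gate, up)]   # data-dependent gating
    return matvec(h, down_proj)          # project back down
```

Unlike ReLU, the gate can suppress or amplify each intermediate unit depending on the input: a unit whose gate activation is 0 contributes nothing, because `silu(0.0)` is 0.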

Unveiling the Disguise: MLP as Attention

Looking at these operations side by side reveals something fascinating: MLPs and attention mechanisms are structurally very similar. Let's draw the parallel:

  • In attention, the K^T tensor transforms queries into attention scores
  • In MLP, the up_proj matrix projects inputs into a higher dimension
  • In attention, the V tensor maps attention weights to output values
  • In MLP, the down_proj matrix maps activated values back to output space
  • The softmax in attention and the non-linearity in MLP both serve to modulate the influence of different values

This suggests a reinterpretation of the MLP layer: MLPs are essentially attention mechanisms operating on a fixed KV cache. The key difference is that attention creates dynamic K and V tensors at runtime from the input, whereas MLP layers operate on learned parameters that remain fixed after training.

In other words, an MLP is like an attention mechanism that always attends to the same set of "tokens" (its weight matrices) regardless of the input. It's attention wearing a trench coat, pretending to be something else entirely!
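To make the analogy concrete, the sketch below writes one generic routine: score each key against the input, post-process the scores, then blend the values. An MLP falls out by plugging in fixed keys (the columns of `up_proj`) and fixed values (the rows of `down_proj`) with ReLU as the post-processing step. It is a pure-Python illustration with hypothetical helper names. It omits the 1/sqrt(d) scaling, and softmax normalizes across keys whereas ReLU acts on each score independently.

```python
import math


def generic_attention(q, keys, values, score_fn):
    """out = sum_j w_j * v_j, where w = score_fn([q . k_j for each key])."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    weights = score_fn(scores)
    return [sum(w * v[c] for w, v in zip(weights, values)) for c in range(len(values[0]))]


def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def relu(scores):
    return [max(0.0, s) for s in scores]


def mlp_as_attention(x, up_proj, down_proj):
    # The "KV cache" is fixed after training: d_ff keys and d_ff values.
    keys = [list(col) for col in zip(*up_proj)]   # columns of up_proj
    values = down_proj                            # rows of down_proj
    return generic_attention(x, keys, values, relu)
```

Passing `softmax` instead of `relu`, and keys and values computed from the input sequence, turns the same routine back into (unscaled) attention.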

Flash Optimization for MLPs

If MLPs are secretly just attention operations, why don't we apply the same fusion techniques that made Flash Attention so successful? The answer lies in the scaling properties and computational bottlenecks:

  1. Attention's computational cost grows quadratically with sequence length, since the K and V tensors grow with the input and every query is scored against every key
  2. MLP operations use fixed-size matrices regardless of sequence length

However, during inference, when we're often memory-bandwidth bound rather than compute-bound, kernel fusion becomes tremendously valuable for MLPs too. By fusing operations, we can avoid redundant memory transfers between GPU global memory and compute units.
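A back-of-the-envelope calculation shows why small-batch inference is memory-bound. The helper below estimates the arithmetic intensity (FLOPs per byte moved) of a single projection. It assumes FP32 values and that every operand crosses global memory exactly once; the function name is illustrative.

```python
def matmul_arithmetic_intensity(batch, d_in, d_out, bytes_per_elem=4):
    """FLOPs per byte of global-memory traffic for (batch x d_in) @ (d_in x d_out).

    Assumes ideal caching: inputs and weights read once, output written once.
    """
    flops = 2 * batch * d_in * d_out                   # one multiply + one add per MAC
    bytes_moved = bytes_per_elem * (batch * d_in       # activations in
                                    + d_in * d_out     # weight matrix
                                    + batch * d_out)   # activations out
    return flops / bytes_moved
```

For the Llama 8B up projection (4096 → 14336), `matmul_arithmetic_intensity(1, 4096, 14336)` is just under 0.5 FLOPs per byte and `matmul_arithmetic_intensity(64, 4096, 14336)` is about 31, because reading the weight matrix dominates the traffic until the batch grows. GPUs can perform far more than 0.5 FLOPs per byte of DRAM bandwidth, so at small batch sizes the compute units spend most of their time waiting on memory.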

Looking at the SwiGLU operation step by step, we can see significant opportunities to reduce memory movement through fusion:


A representation of the operations that need to happen during a SwiGLU-MLP. Operations in blue are memory movements, where data is moved to and from GPU global memory. Operations in yellow are the actual computational steps. Fusing the steps together significantly reduces redundant memory movement and leaves mostly the compute steps. This is merely a schematic representation and does not reflect the time that each step takes!
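The blue memory movements can be counted directly. The sketch below tallies the extra global-memory traffic caused by materializing the SwiGLU intermediates. It assumes a fully unfused pipeline in FP32 (separate gate and up matmuls, a separate SiLU-and-multiply step, then the down matmul); the helper names are hypothetical.

```python
def swiglu_intermediate_bytes(batch, d_ff, bytes_per_elem=4):
    """Global-memory traffic spent on SwiGLU intermediates in an unfused pipeline.

    A fully fused kernel keeps these tensors on chip, so this traffic disappears
    (weights and inputs still have to be read).
    """
    elems = batch * d_ff
    gate_up_write = 2 * elems      # gate and up results written by their matmuls
    gate_up_read = 2 * elems       # read back by the SiLU * multiply step
    h_write = elems                # h = silu(gate) * up written out
    h_read = elems                 # read back by the down_proj matmul
    return bytes_per_elem * (gate_up_write + gate_up_read + h_write + h_read)


def swiglu_weight_bytes(d_model, d_ff, bytes_per_elem=4):
    """Bytes of gate_proj, up_proj and down_proj weights, read once per forward pass."""
    return bytes_per_elem * 3 * d_model * d_ff
```

With the Llama 8B sizes, `swiglu_intermediate_bytes(1, 14336)` is 344,064 bytes, while `swiglu_weight_bytes(4096, 14336)` is 704,643,072 bytes. The intermediate traffic grows linearly with batch size while the weight traffic stays fixed, so the share of memory movement that fusion removes grows as the batch grows.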

Building a FlashMLP Kernel

With this reframing of MLPs as attention-like operations, we can adapt the principles from Flash Attention to optimize MLP operations. Just as Flash Attention avoids materializing the full attention matrix, our FlashMLP can avoid materializing large intermediate tensors.

Learning from Flash Attention

We'll use the Triton DSL, the same tool that powers many efficient attention implementations, to build our optimized MLP kernel. Our starting point is PyTorch's flex-attention kernel, which performs competitively with hand-tuned flash attention implementations.

First Attempt: Mimicking Flex-Attention

Our first approach attempts to copy the computation pattern from flex-attention, where query vectors from a single head are loaded and dot products are performed with K and V tensors:

A schematic of the computation performed by a single threadblock in the flex-attention kernel. Squares in blue are loaded to fast GPU shared memory. The Q tensor is kept in shared memory throughout the computation. It is dotted against heads from the K tensor, softmax is applied via online-softmax, and then multiplied by a head of the V tensor, which is accumulated into the output tensor.

However, this approach falls short for MLPs because the inner dimension is much larger (e.g., 4096 for Llama 8B's hidden dimension versus 64 or 128 for a typical attention head dimension). This means we cannot keep the entire input vector in fast shared memory during computation:

A schematic of the computation performed by a single threadblock in a flash-ReLU implementation that uses the same compute model as the flex-attention kernel. Note that we cannot keep the input vectors in shared memory because they are too large. We must loop over the inner dimension of the input and up_proj tensors, thereby losing the benefit of keeping the input tensor in fast shared memory.
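The size mismatch is easy to quantify. Assuming a hypothetical tile of 64 tokens per threadblock in FP32, the sketch below compares the on-chip footprint of a query tile for one attention head (head dimension 128, as in Llama 8B) with that of the full hidden vectors (4096) that the MLP's inner dimension requires.

```python
def tile_bytes(rows, cols, bytes_per_elem=4):
    """Bytes needed to hold a rows x cols FP32 tile in on-chip shared memory."""
    return rows * cols * bytes_per_elem


BLOCK_M = 64                             # hypothetical tokens per threadblock
HEAD_DIM = 128                           # attention head dimension in Llama 8B
HIDDEN = 4096                            # hidden (model) dimension in Llama 8B

q_tile = tile_bytes(BLOCK_M, HEAD_DIM)   # what flex-attention keeps resident
x_tile = tile_bytes(BLOCK_M, HIDDEN)     # what an MLP kernel would need to keep

print(f"attention Q tile: {q_tile // 1024} KiB, MLP input tile: {x_tile // 1024} KiB")
```

The attention tile needs 32 KiB, while the MLP input tile needs 1024 KiB, 32 times more and well beyond the few hundred kilobytes of shared memory at most that a single threadblock can use on current NVIDIA GPUs.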

This results in a much slower kernel than the baseline torch implementation:


The results of the Flash-ReLU implementation above. The plot compares a naive PyTorch implementation against an autotuned, fused Triton implementation, run on an RTX 3060 GPU. The hidden and intermediate sizes match those of the Llama 8B model, 4096 and 14336 respectively. The batch size is the number of rows (tokens) in the input tensor. The fused kernel is significantly slower than the native PyTorch implementation.


Improved Approach: Keeping Activations in Shared Memory

A more effective strategy is to keep the temporary tensors (post-non-linearity) in shared memory longer and perform a dot product across the entire row of the down_proj (V-equivalent) tensor:

A schematic of the computation performed by a single threadblock in a flash-ReLU implementation that uses a compute model that keeps the temporary activation in memory for longer and performs a dot across the entire row of the V tensor.

This modification significantly improves performance: the kernel matches the naive implementation at small batch sizes but falls behind at larger ones.

The results of the faster Flash-ReLU implementation above. The Flash-ReLU is about as fast as the native torch implementation at small batch sizes, but becomes slower at larger batch sizes. 
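The loop structure of the improved kernel can be sketched in pure Python. In this reading of the schematic, each (token block, intermediate block) pair stands in for one threadblock: it accumulates its activation tile over chunks of the hidden dimension, applies ReLU while the tile is still "on chip", and immediately multiplies it against the full rows of down_proj. The block sizes are hypothetical, and the sketch simply adds each pair's partial output into `out`. The text does not describe how the Triton kernel combines partial results across threadblocks; atomic adds or a separate reduction pass are possible choices.

```python
def flash_relu_mlp(x, up_proj, down_proj, block_m=2, block_i=2, block_k=2):
    """Tiled ReLU MLP computing relu(x @ up_proj) @ down_proj.

    x: n x d_model, up_proj: d_model x d_ff, down_proj: d_ff x d_out.
    Each (token block, intermediate block) pair models one threadblock.
    """
    n, d_model = len(x), len(x[0])
    d_ff, d_out = len(up_proj[0]), len(down_proj[0])
    out = [[0.0] * d_out for _ in range(n)]
    for m0 in range(0, n, block_m):
        rows = range(m0, min(m0 + block_m, n))
        for i0 in range(0, d_ff, block_i):
            cols = range(i0, min(i0 + block_i, d_ff))
            # Activation tile for these tokens and intermediate columns,
            # accumulated chunk by chunk over the inner (hidden) dimension.
            h = [[0.0] * len(cols) for _ in rows]
            for k0 in range(0, d_model, block_k):
                ks = range(k0, min(k0 + block_k, d_model))
                for a, r in enumerate(rows):
                    for b, c in enumerate(cols):
                        h[a][b] += sum(x[r][k] * up_proj[k][c] for k in ks)
            # Apply ReLU to the tile, then dot it with the entire rows of
            # down_proj and add this block's contribution to the output.
            for a, r in enumerate(rows):
                for b, c in enumerate(cols):
                    act = max(0.0, h[a][b])
                    for j in range(d_out):
                        out[r][j] += act * down_proj[c][j]
    return out
```

Python gives no speedup here; the point is that the full n × d_ff activation matrix never exists at once, and the result matches `relu(x @ up_proj) @ down_proj` for any choice of block sizes.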

Extending to SwiGLU

With this faster kernel in hand, we can tackle the SwiGLU operation, which requires loading and computing with the gate_proj matrix in addition to up_proj and down_proj:

A schematic of the computation performed by a single threadblock in a flash-SwiGLU implementation that uses a compute model that keeps the temporary activation in memory for longer and performs a dot across the entire row of the V tensor. The big difference between this and the flash-ReLU computation above is the need to also compute the gate activation. We get to reuse the same block from the input to compute both the gate and up tensor activations.

A critical optimization here is that we can reuse the same block of memory from the input tensor to compute both the gate and up tensor activations. This gives our fusion kernel an even greater advantage over naive implementations and the partially fused SiLU+multiply kernels used in systems like vLLM.
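Extending the same pure-Python sketch to SwiGLU shows where the input reuse comes from: each chunk of x is fetched once per step of the inner loop and feeds both the gate and the up accumulators. As before, this is an illustrative model with hypothetical block sizes, not the Triton kernel itself.

```python
import math


def flash_swiglu_mlp(x, gate_proj, up_proj, down_proj, block_m=2, block_i=2, block_k=2):
    """Tiled SwiGLU MLP computing (silu(x @ gate_proj) * (x @ up_proj)) @ down_proj."""
    n, d_model = len(x), len(x[0])
    d_ff, d_out = len(up_proj[0]), len(down_proj[0])
    out = [[0.0] * d_out for _ in range(n)]
    for m0 in range(0, n, block_m):
        rows = range(m0, min(m0 + block_m, n))
        for i0 in range(0, d_ff, block_i):
            cols = range(i0, min(i0 + block_i, d_ff))
            g = [[0.0] * len(cols) for _ in rows]   # gate tile
            u = [[0.0] * len(cols) for _ in rows]   # up tile
            for k0 in range(0, d_model, block_k):
                ks = range(k0, min(k0 + block_k, d_model))
                # Load one chunk of x once and use it for both projections.
                x_chunk = [[x[r][k] for k in ks] for r in rows]
                for a in range(len(rows)):
                    for b, c in enumerate(cols):
                        g[a][b] += sum(xv * gate_proj[k][c] for xv, k in zip(x_chunk[a], ks))
                        u[a][b] += sum(xv * up_proj[k][c] for xv, k in zip(x_chunk[a], ks))
            for a, r in enumerate(rows):
                for b, c in enumerate(cols):
                    gv = g[a][b]
                    act = gv / (1.0 + math.exp(-gv)) * u[a][b]   # silu(gate) * up
                    for j in range(d_out):
                        out[r][j] += act * down_proj[c][j]
    return out
```

Compared with running the gate and up projections as separate matmuls, the x chunk is read half as often, which is the extra saving the fused kernel gets on top of never writing the intermediates to global memory.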

Performance Results

Our FlashSwiGLU implementation consistently outperforms the native PyTorch SwiGLU implementation across various batch sizes:


The results of the Flash-SwiGLU implementation above. The fused SwiGLU kernel is consistently faster than the native torch implementation across all batch sizes.

What's Next for Flash MLP

This exercise demonstrates that reimagining MLPs as attention mechanisms can lead to meaningful performance improvements through specialized kernels. To make these kernels more broadly useful, several extensions are necessary:

  1. Support for more data types (FP16, BF16) beyond the current FP32 implementation
  2. Auto-tuning for different model architectures, GPU types, and tensor shapes
  3. Integration with quantized models (4-bit, 3-bit, 2-bit parameters)
  4. Compatibility with additional features like token skipping for mixture-of-experts models

Conclusion

By recognizing that MLPs are essentially "attention in a trench coat"—attention mechanisms operating on fixed parameters—we gain new insights into optimizing these crucial components of transformer models. The FlashMLP kernel demonstrates that the same principles that made Flash Attention so successful can be applied to MLP operations, resulting in faster inference and reduced memory usage. As large language models continue to grow in size and importance, these optimizations will help make AI more efficient and accessible across a wider range of hardware.
