Cross-Attention

By Arpit Tripathi, Founder

Cross-attention is an attention mechanism in which the queries come from one sequence while the keys and values come from a different sequence, letting a model condition one set of representations on another, such as a decoder attending to an encoder or text attending to an image.

What is Cross-Attention?

Cross-attention is a form of attention in which one sequence attends to a different sequence. The queries are computed from the first sequence, while the keys and values are computed from the second. This lets the model condition its output on external information: a translation decoder attends to the encoded source sentence, a vision language model lets text attend to image features, and an image generator lets the image being formed attend to a text prompt.

The distinction is simply where the inputs come from. In self-attention, queries, keys, and values all derive from the same sequence, so tokens attend to each other within one input. In cross-attention, queries come from one source and keys and values from another, so the two sequences interact. Both use the same scaled dot-product attention introduced in the 2017 paper 'Attention Is All You Need'.

  • Queries come from one sequence; keys and values from another.
  • Self-attention uses one sequence for all three; cross-attention bridges two.
  • Used in encoder-decoder Transformers, multimodal models, and diffusion models.
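
To make the distinction concrete, here is a minimal sketch using PyTorch's nn.MultiheadAttention. The tensor names and sizes (src, tgt, 10 and 4 tokens, width 32) are illustrative assumptions, not values from any particular model. The same module call performs either kind of attention; only the arguments change.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

src = torch.randn(1, 10, 32)  # hypothetical source sequence, 10 tokens
tgt = torch.randn(1, 4, 32)   # hypothetical target sequence, 4 tokens

# Self-attention: queries, keys, and values all come from tgt.
self_out, self_w = attn(query=tgt, key=tgt, value=tgt)

# Cross-attention: queries from tgt, keys and values from src.
cross_out, cross_w = attn(query=tgt, key=src, value=src)

print(self_out.shape, self_w.shape)    # shapes (1, 4, 32) and (1, 4, 4)
print(cross_out.shape, cross_w.shape)  # shapes (1, 4, 32) and (1, 4, 10)
```

The sketch reuses one module for both calls only to highlight that the computation is identical. A real model would normally give its self-attention and cross-attention layers separate weights.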

How cross-attention works

Cross-attention uses the same scaled dot-product formula as self-attention. The model projects the first sequence into query vectors Q and the second sequence into key vectors K and value vectors V. It scores each query against every key with a dot product, divides the scores by the square root of the key dimension so that large dot products do not saturate the softmax and shrink its gradients, applies a softmax over the keys to get attention weights, and uses those weights to take a weighted sum of the value vectors.

The output has one vector per query position, so a cross-attention layer produces a representation aligned to the first sequence but enriched with information pulled from the second. In practice this runs as multi-head attention, where several attention computations operate in parallel on different learned projections and their results are concatenated and passed through a final linear projection, letting the model attend to different kinds of relationships at once.

CrossAttention(Q, K, V) = softmax( (Q Kᵀ) / √dₖ ) V
Scaled dot-product attention. In cross-attention Q comes from one sequence while K and V come from another; dₖ is the key dimension.
  • Scores queries against keys, scales, applies softmax, and weights the values.
  • Output length matches the query sequence; its content is drawn from the other sequence.
  • Runs as multi-head attention to capture several relationship types in parallel.
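
The steps above can be sketched in plain Python with no libraries. This is an illustration under stated assumptions rather than a production implementation: Q, K, and V are taken as already projected, the toy numbers are hand-picked, and the helper names softmax and cross_attention are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V for lists of row vectors."""
    d_k = len(K[0])
    outputs, all_weights = [], []
    for q in Q:
        # One score per key: dot product divided by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
                  for k in K]
        weights = softmax(scores)
        # Weighted sum of the value vectors.
        out = [sum(w * v[j] for w, v in zip(weights, V))
               for j in range(len(V[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights

# Toy inputs (assumptions for illustration only).
Q = [[1.0, 0.0], [0.0, 1.0]]               # 2 query positions
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 context positions
V = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

out, weights = cross_attention(Q, K, V)
print(len(out), len(weights[0]))  # 2 3
```

There is one output vector per query (2) and one weight per context position (3), which is the shape relationship described above.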

Cross-attention in code

The snippet below shows a minimal cross-attention layer. Note that the query input and the key/value input are separate arguments, which is the only structural difference from self-attention.

```python
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x_query, x_context):
        # queries from x_query; keys and values from x_context
        out, weights = self.attn(query=x_query,
                                 key=x_context,
                                 value=x_context)
        return out, weights

layer = CrossAttention(dim=512)
text = torch.randn(1, 20, 512)    # e.g. decoder / text tokens
image = torch.randn(1, 196, 512)  # e.g. encoder / image tokens
out, _ = layer(text, image)
print(out.shape)                  # torch.Size([1, 20, 512])
```
A minimal multi-head cross-attention layer using PyTorch.
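
In practice the context sequence is often padded to a fixed length, and the padding should not receive attention. The following is a hedged sketch of one way to handle that with the key_padding_mask argument of nn.MultiheadAttention; the batch size, sequence lengths, and the choice of which positions count as padding are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

queries = torch.randn(2, 5, 64)  # batch of 2, 5 query tokens each
context = torch.randn(2, 7, 64)  # batch of 2, up to 7 context tokens each

# True marks context positions to ignore. The second example is assumed
# to have 4 real context tokens followed by 3 padding slots.
pad_mask = torch.zeros(2, 7, dtype=torch.bool)
pad_mask[1, 4:] = True

out, weights = attn(queries, context, context, key_padding_mask=pad_mask)

print(out.shape)      # torch.Size([2, 5, 64])
print(weights.shape)  # torch.Size([2, 5, 7]), averaged over heads
print(weights[1, :, 4:].sum().item())  # attention mass on the padded slots
```

The masked positions are excluded before the softmax, so each query's weights still sum to one over the remaining context tokens.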

Where cross-attention is used

Cross-attention appears wherever a model must fuse two information streams. In the original Transformer, each decoder layer has a cross-attention block so the decoder attends to the encoder's representation of the input. This is how machine translation models and encoder-decoder speech recognition models such as Whisper align their output to the input. In text-to-image diffusion models, cross-attention injects the text prompt into the image generation process so that the generated image reflects the prompt.
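
As a rough illustration of where the cross-attention block sits inside a decoder layer, here is a simplified sketch. The class name DecoderBlock, the sizes, and the omission of dropout are assumptions; it follows the general self-attention, cross-attention, feed-forward ordering rather than reproducing any specific model's code.

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """Hypothetical, simplified decoder layer (no dropout)."""
    def __init__(self, dim=64, heads=4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, tgt, memory):
        # Causal mask: True above the diagonal blocks later positions.
        n = tgt.size(1)
        causal = torch.triu(
            torch.ones(n, n, dtype=torch.bool, device=tgt.device), diagonal=1)
        sa, _ = self.self_attn(tgt, tgt, tgt, attn_mask=causal,
                               need_weights=False)
        x = self.norm1(tgt + sa)
        # Cross-attention: decoder states query the encoder output.
        ca, _ = self.cross_attn(x, memory, memory, need_weights=False)
        x = self.norm2(x + ca)
        return self.norm3(x + self.ffn(x))

block = DecoderBlock()
memory = torch.randn(1, 12, 64)  # stand-in for the encoder output
tgt = torch.randn(1, 5, 64)      # stand-in for the decoder inputs
print(block(tgt, memory).shape)  # torch.Size([1, 5, 64])
```

Only the self-attention step is causally masked. The cross-attention step can see the whole encoder output, because the full input is available before decoding starts.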

Multimodal and vision language models use cross-attention to let one modality read another, for example letting language tokens attend to visual features. The attention weights also give a rough view into the model's behavior in these settings: for each output position they show how much weight each part of the source received, which is why cross-attention maps are often visualized, although they are an indication rather than a full explanation of what the model is doing.

  • Encoder-decoder Transformers: the decoder attends to the encoded input.
  • Text-to-image diffusion: the image attends to the text prompt.
  • Multimodal models: one modality attends to another's features.
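
The two streams do not need to share a feature width. Below is a minimal sketch, with hypothetical sizes, of image-latent tokens attending to text-prompt embeddings of a different width, using the kdim and vdim arguments of nn.MultiheadAttention. Passing average_attn_weights=False returns one attention map per head, which is the kind of data an attention-map visualization would start from.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 64 image-latent tokens of width 320 attend to
# 77 text tokens of width 768.
image_dim, text_dim = 320, 768
attn = nn.MultiheadAttention(embed_dim=image_dim, num_heads=8,
                             kdim=text_dim, vdim=text_dim,
                             batch_first=True)

latents = torch.randn(1, 64, image_dim)  # queries: the image being formed
prompt = torch.randn(1, 77, text_dim)    # keys and values: text embeddings

out, weights = attn(latents, prompt, prompt, average_attn_weights=False)

print(out.shape)      # torch.Size([1, 64, 320])
print(weights.shape)  # torch.Size([1, 8, 64, 77]), one map per head
```

The output keeps the image stream's length and width, while each row of the weights tensor is a distribution over the 77 prompt tokens.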

Key takeaways

  • Cross-attention computes queries from one sequence and keys and values from another, letting one stream condition on the other.
  • It uses the same scaled dot-product attention as self-attention; only the source of the inputs differs.
  • Encoder-decoder Transformers, text-to-image diffusion models, and multimodal models all rely on cross-attention to fuse information.
  • The output length matches the query sequence while drawing content from the second sequence.

Frequently asked questions

What is cross-attention?
Cross-attention is an attention mechanism where the queries come from one sequence and the keys and values come from another. It lets a model condition one representation on a different source, such as a decoder attending to an encoder or text attending to image features.

How is cross-attention different from self-attention?
In self-attention, queries, keys, and values all come from the same sequence, so tokens attend within one input. In cross-attention, queries come from one sequence and keys and values from another, letting two different sequences interact while using the same attention math.

Where is cross-attention used?
Cross-attention is used in encoder-decoder Transformers (the decoder attends to the encoded input), in text-to-image diffusion models (the image attends to the text prompt), and in multimodal or vision language models where one modality attends to another's features.

How does cross-attention fuse two modalities?
It projects one modality into queries and the other into keys and values, scores the queries against the keys, and takes a weighted sum of the values. The result aligns with the query modality but carries information pulled from the second, fusing the two.

Does cross-attention use the same formula as self-attention?
Yes. Both use scaled dot-product attention: the softmax of Q times K transposed divided by the square root of the key dimension, multiplied by V. The only difference is that cross-attention draws its keys and values from a different sequence than its queries.