
MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source]

Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.

Multi-Head Attention is defined as:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O

where \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V).
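
To make the head-splitting concrete, the following is a minimal sketch of the computation above in plain tensor ops. It assumes batch-first inputs, a combined (embed_dim, embed_dim) projection weight per role, and embed_dim divisible by num_heads; it illustrates the formula and is not the module's actual implementation.

import math
import torch

def naive_multihead_attention(q, k, v, w_q, w_k, w_v, w_o, num_heads):
    # q: (N, L, E), k and v: (N, S, E); each projection weight is (E, E).
    N, L, E = q.shape
    head_dim = E // num_heads

    # Project, then split the embedding dimension across heads: (N, num_heads, seq, head_dim).
    def project(x, w):
        return (x @ w).view(N, -1, num_heads, head_dim).transpose(1, 2)

    qh, kh, vh = project(q, w_q), project(k, w_k), project(v, w_v)

    # Scaled dot-product attention, computed independently per head.
    scores = qh @ kh.transpose(-2, -1) / math.sqrt(head_dim)   # (N, num_heads, L, S)
    attn = torch.softmax(scores, dim=-1)
    heads = attn @ vh                                          # (N, num_heads, L, head_dim)

    # Concat(head_1, ..., head_h) W^O
    return heads.transpose(1, 2).reshape(N, L, E) @ w_o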

forward() will use the optimized implementation described in FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness if all of the following conditions are met:

  • self-attention is being computed (i.e., query, key, and value are the same tensor; this restriction will be loosened in the future)

  • inputs are batched (3D) with batch_first==True

  • Either autograd is disabled (using torch.inference_mode or torch.no_grad) or no tensor argument requires_grad

  • training is disabled (using .eval())

  • add_bias_kv is False

  • add_zero_attn is False

  • batch_first is True and the input is batched

  • kdim and vdim are equal to embed_dim

  • if a NestedTensor is passed, neither key_padding_mask nor attn_mask is passed

  • autocast is disabled

If the optimized implementation is in use, a NestedTensor can be passed for query/key/value to represent padding more efficiently than using a padding mask. In this case, a NestedTensor will be returned, and an additional speedup proportional to the fraction of the input that is padding can be expected.
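
As a rough sketch (not a guarantee that the optimized kernel is actually selected), the conditions above are typically satisfied in an inference-only, batch-first self-attention setup such as:

>>> import torch
>>> import torch.nn as nn
>>> mha = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True).eval()
>>> x = torch.rand(4, 10, 256)               # batched (N, L, E) self-attention input
>>> with torch.no_grad():
...     out, _ = mha(x, x, x, need_weights=False)
>>> out.shape
torch.Size([4, 10, 256])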

Parameters:
  • embed_dim – Total dimension of the model.

  • num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

  • dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).

  • bias – If specified, adds bias to input / output projection layers. Default: True.

  • add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.

  • add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.

  • kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).

  • vdim – Total number of features for values. Default: None (uses vdim=embed_dim).

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

Examples:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
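
A slightly more complete variant of the example, spelling out the shapes with batch_first=True; the sizes below are arbitrary illustrations:

>>> import torch
>>> import torch.nn as nn
>>> embed_dim, num_heads = 64, 4
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
>>> query = torch.rand(2, 5, embed_dim)      # (N, L, E_q)
>>> key = torch.rand(2, 7, embed_dim)        # (N, S, E_k)
>>> value = torch.rand(2, 7, embed_dim)      # (N, S, E_v)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> attn_output.shape, attn_output_weights.shape
(torch.Size([2, 5, 64]), torch.Size([2, 5, 7]))
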
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source]
Parameters:
  • query (Tensor) – Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False or (N, L, E_q) when batch_first=True, where L is the target sequence length, N is the batch size, and E_q is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.

  • key (Tensor) – Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False or (N, S, E_k) when batch_first=True, where S is the source sequence length, N is the batch size, and E_k is the key embedding dimension kdim. See “Attention Is All You Need” for more details.

  • value (Tensor) – Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False or (N, S, E_v) when batch_first=True, where S is the source sequence length, N is the batch size, and E_v is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

  • key_padding_mask (Optional[Tensor]) – If specified, a mask of shape (N, S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be (S). Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding key value.

  • need_weights (bool) – If specified, returns attn_output_weights in addition to attn_output. Default: True.

  • attn_mask (Optional[Tensor]) – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. A 2D mask will be broadcast across the batch while a 3D mask allows a different mask for each entry in the batch. Binary and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. If both attn_mask and key_padding_mask are supplied, their types should match.

  • is_causal (bool) – If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask. Default: False.

  • average_attn_weights (bool) – If True, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads).

Return type:

Tuple[Tensor, Optional[Tensor]]

Outputs:
  • attn_output - Attention outputs of shape (L, E) when input is unbatched, (L, N, E) when batch_first=False or (N, L, E) when batch_first=True, where L is the target sequence length, N is the batch size, and E is the embedding dimension embed_dim.

  • attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads of shape (L, S) when input is unbatched or (N, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. If average_attn_weights=False, returns attention weights per head of shape (num_heads, L, S) when input is unbatched or (N, num_heads, L, S).

Note

The batch_first argument is ignored for unbatched inputs.
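
To make the mask semantics above concrete, the sketch below builds a boolean causal attn_mask and a boolean key_padding_mask for batch-first inputs; the shapes and the padded position are illustrative assumptions:

>>> import torch
>>> import torch.nn as nn
>>> N, L, S, E = 2, 5, 5, 64
>>> mha = nn.MultiheadAttention(E, num_heads=4, batch_first=True)
>>> q, kv = torch.rand(N, L, E), torch.rand(N, S, E)
>>> causal = torch.triu(torch.ones(L, S, dtype=torch.bool), diagonal=1)   # True = may not attend
>>> padding = torch.zeros(N, S, dtype=torch.bool)
>>> padding[:, -1] = True                    # treat the last key position as padding
>>> out, weights = mha(q, kv, kv, attn_mask=causal, key_padding_mask=padding)
>>> weights.shape                            # (N, L, S), averaged across heads by default
torch.Size([2, 5, 5])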

merge_masks(attn_mask, key_padding_mask, query)[source]

Determine the mask type and combine masks if necessary. If only one mask is provided, that mask and the corresponding mask type are returned. If both masks are provided, they are both expanded to shape (batch_size, num_heads, seq_len, seq_len), combined with logical OR, and mask type 2 is returned.

Parameters:
  • attn_mask – attention mask of shape (seq_len, seq_len), mask type 0

  • key_padding_mask – padding mask of shape (batch_size, seq_len), mask type 1

  • query – query embeddings of shape (batch_size, seq_len, embed_dim)

Returns:

  • merged_mask – the merged mask

  • mask_type – the merged mask type (0, 1, or 2)

Return type:

Tuple[Optional[Tensor], Optional[int]]
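
A minimal sketch of calling merge_masks directly with batch-first query embeddings (the dtype of the merged mask is left unspecified here, as it depends on the mask dtypes passed in):

>>> import torch
>>> import torch.nn as nn
>>> mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
>>> q = torch.rand(2, 5, 64)                 # (batch_size, seq_len, embed_dim)
>>> attn_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
>>> key_padding_mask = torch.zeros(2, 5, dtype=torch.bool)
>>> merged_mask, mask_type = mha.merge_masks(attn_mask, key_padding_mask, q)
>>> merged_mask.shape, mask_type
(torch.Size([2, 4, 5, 5]), 2)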
