diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index f9148db665..1d13483c0c 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -42,6 +42,7 @@ from .create_conv2d import create_conv2d
 from .create_norm import get_norm_layer, create_norm_layer
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
+from .diff_attention import DiffAttention
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path, calculate_drop_path_rates
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
 from .evo_norm import (
diff --git a/timm/layers/diff_attention.py b/timm/layers/diff_attention.py
new file mode 100644
index 0000000000..daa2a67cdc
--- /dev/null
+++ b/timm/layers/diff_attention.py
@@ -0,0 +1,175 @@
+"""Differential Attention
+
+Paper: 'Differential Transformer' - https://arxiv.org/abs/2410.05258
+
+Reference impl: https://github.com/microsoft/unilm/tree/master/Diff-Transformer
+
+Hacked together by / Copyright 2024, Ross Wightman
+"""
+import math
+from typing import Optional, Type
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .attention import maybe_add_mask
+from .config import use_fused_attn
+from .norm import RmsNorm
+
+
+class DiffAttention(nn.Module):
+    """Differential Attention module.
+
+    Computes attention as the difference between two softmax attention maps, which helps
+    cancel out noise and promotes sparse attention patterns. The module splits Q and K
+    into two groups, computes separate attention maps, and subtracts one from the other
+    scaled by a learnable lambda parameter.
+
+    The attention output is computed as:
+        Attn = softmax(Q1 @ K1^T) - lambda * softmax(Q2 @ K2^T)
+        Output = Attn @ V
+
+    Supports both fused (scaled_dot_product_attention) and manual implementations.
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            scale_norm: bool = False,
+            proj_bias: bool = True,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            depth: int = 0,
+            dual_lambda: bool = False,
+            device=None,
+            dtype=None,
+    ) -> None:
+        """Initialize the DiffAttention module.
+
+        Args:
+            dim: Input dimension of the token embeddings.
+            num_heads: Number of attention heads.
+            qkv_bias: Whether to use bias in the query, key, value projections.
+            qk_norm: Whether to apply normalization to query and key vectors.
+            scale_norm: Whether to apply normalization before the output projection.
+            proj_bias: Whether to use bias in the output projection.
+            attn_drop: Dropout rate applied to the attention weights.
+            proj_drop: Dropout rate applied after the output projection.
+            norm_layer: Normalization layer constructor (defaults to RmsNorm).
+            depth: Block depth index, used to compute depth-dependent lambda_init.
+            dual_lambda: If True, use simplified dual scalar lambda parameterization
+                (2 params). If False, use the paper's original formulation with
+                lambda_q/k vectors (4 * head_dim params).
+        """
+        super().__init__()
+        dd = {'device': device, 'dtype': dtype}
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        if norm_layer is None:
+            norm_layer = RmsNorm
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads // 2
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
+        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_drop_p = attn_drop
+        self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.dual_lambda = dual_lambda
+        if dual_lambda:
+            self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
+            self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
+            self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None
+        else:
+            self.lambda_a = self.lambda_b = None
+            self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+            self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
+
+        self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd)
+
+        self.lambda_init = 0.8
+        self.set_lambda_init(depth)
+        self.reset_parameters()
+
+    def set_lambda_init(self, depth: int):
+        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
+
+    def reset_parameters(self):
+        if self.dual_lambda:
+            nn.init.zeros_(self.lambda_a)
+            nn.init.zeros_(self.lambda_b)
+        else:
+            nn.init.normal_(self.lambda_q1, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_k1, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_q2, mean=0, std=0.1)
+            nn.init.normal_(self.lambda_k2, mean=0, std=0.1)
+
+    def _compute_lambda(self) -> torch.Tensor:
+        if self.lambda_a is not None:
+            lambda_1 = torch.exp(self.lambda_a)
+            lambda_2 = torch.exp(self.lambda_b)
+        else:
+            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
+            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
+        return lambda_1 - lambda_2 + self.lambda_init
+
+    def forward(
+            self,
+            x: torch.Tensor,
+            attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        B, N, C = x.shape
+
+        q, k, v = self.qkv(x).chunk(3, dim=2)
+        q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        lambda_full = self._compute_lambda().type_as(q)
+
+        if self.fused_attn:
+            q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
+            k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
+            q1, q2 = q.unbind(2)
+            k1, k2 = k.unbind(2)
+
+            dropout_p = self.attn_drop_p if self.training else 0.0
+            attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p)
+            attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p)
+
+            x = attn1 - lambda_full * attn2
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = maybe_add_mask(attn, attn_mask)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+
+            attn = attn.view(B, self.num_heads, 2, N, N)
+            attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
+            x = attn @ v
+
+        x = self.sub_norm(x)
+        x = x * (1 - self.lambda_init)
+        x = x.transpose(1, 2).reshape(B, N, C)
+
+        x = self.norm(x)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
diff --git a/timm/models/eva.py b/timm/models/eva.py
index 761fee7644..fd7b0b8c3c 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -389,6 +389,7 @@ def __init__(
             attn_head_dim: Optional[int] = None,
             device=None,
             dtype=None,
+            **kwargs,
     ):
         """ Initialize the post-norm EVA transformer block.
diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py
index 4e8e691b61..aa4fff0544 100644
--- a/timm/models/naflexvit.py
+++ b/timm/models/naflexvit.py
@@ -132,6 +132,7 @@ class NaFlexVitCfg:
     act_layer: Optional[str] = None  # Activation layer for MLP blocks
     block_fn: Optional[str] = None  # Transformer block implementation class name
     mlp_layer: Optional[str] = None  # MLP implementation class name
+    attn_layer: Optional[str] = None  # Attention layer implementation (e.g., 'attn', 'diff')
 
     # EVA-specific parameters
     attn_type: str = 'standard'  # Attention type: 'standard', 'eva', 'rope'
@@ -289,13 +290,15 @@ def get_block_fn(cfg: NaFlexVitCfg) -> Callable:
     else:
         # Standard ViT block
         block_fn = cfg.block_fn or Block
+        block_kwargs = {}
         if cfg.scale_mlp_norm or cfg.scale_attn_inner_norm:
             # param names differ between EVA vs non-EVA block types
-            block_fn = partial(
-                block_fn,
-                scale_mlp_norm=cfg.scale_mlp_norm,
-                scale_attn_norm=cfg.scale_attn_inner_norm
-            )
+            block_kwargs['scale_mlp_norm'] = cfg.scale_mlp_norm
+            block_kwargs['scale_attn_norm'] = cfg.scale_attn_inner_norm
+        if cfg.attn_layer:
+            block_kwargs['attn_layer'] = cfg.attn_layer
+        if block_kwargs:
+            block_fn = partial(block_fn, **block_kwargs)
 
     return block_fn
@@ -1214,6 +1217,7 @@ def __init__(
                 norm_layer=norm_layer,
                 act_layer=act_layer,
                 mlp_layer=mlp_layer,
+                depth=i,
                 **dd,
             )
             for i in range(cfg.depth)
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 1a31d8ddd5..b0d1b24298 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -47,6 +47,7 @@
 )
 from timm.layers import (
     Attention,
+    DiffAttention,
     AttentionPoolLatent,
     PatchEmbed,
     Mlp,
@@ -79,6 +80,49 @@
 
 _logger = logging.getLogger(__name__)
 
+ATTN_LAYERS = {
+    '': Attention,
+    'attn': Attention,
+    'diff': DiffAttention,
+}
+
+
+def _create_attn(
+        attn_layer: LayerType,
+        dim: int,
+        num_heads: int,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        scale_norm: bool = False,
+        proj_bias: bool = True,
+        attn_drop: float = 0.,
+        proj_drop: float = 0.,
+        norm_layer: Optional[Type[nn.Module]] = None,
+        depth: int = 0,
+        **kwargs,
+) -> nn.Module:
+    if isinstance(attn_layer, str):
+        attn_layer = ATTN_LAYERS.get(attn_layer, None)
+        assert attn_layer is not None, f'Unknown attn_layer: {attn_layer}'
+
+    # Only pass depth to attention layers that use it
+    if issubclass(attn_layer, DiffAttention):
+        kwargs['depth'] = depth
+
+    return attn_layer(
+        dim,
+        num_heads=num_heads,
+        qkv_bias=qkv_bias,
+        qk_norm=qk_norm,
+        scale_norm=scale_norm,
+        proj_bias=proj_bias,
+        attn_drop=attn_drop,
+        proj_drop=proj_drop,
+        norm_layer=norm_layer,
+        **kwargs,
+    )
+
+
 class Block(nn.Module):
     """Transformer block with pre-normalization."""
 
@@ -99,6 +143,8 @@ def __init__(
             act_layer: Type[nn.Module] = nn.GELU,
             norm_layer: Type[nn.Module] = LayerNorm,
             mlp_layer: Type[nn.Module] = Mlp,
+            attn_layer: LayerType = Attention,
+            depth: int = 0,
             device=None,
             dtype=None,
     ) -> None:
@@ -118,12 +164,15 @@ def __init__(
             act_layer: Activation layer.
             norm_layer: Normalization layer.
             mlp_layer: MLP layer.
+            attn_layer: Attention layer type (class or string).
+            depth: Block index, passed to attention layer for depth-dependent init.
         """
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
         self.norm1 = norm_layer(dim, **dd)
-        self.attn = Attention(
+        self.attn = _create_attn(
+            attn_layer,
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
@@ -133,7 +182,8 @@ def __init__(
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
-            **dd
+            depth=depth,
+            **dd,
         )
         self.ls1 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -175,14 +225,17 @@ def __init__(
             act_layer: Type[nn.Module] = nn.GELU,
             norm_layer: Type[nn.Module] = LayerNorm,
             mlp_layer: Type[nn.Module] = Mlp,
-            device = None,
-            dtype = None,
+            attn_layer: LayerType = Attention,
+            depth: int = 0,
+            device=None,
+            dtype=None,
     ) -> None:
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
         self.init_values = init_values
 
-        self.attn = Attention(
+        self.attn = _create_attn(
+            attn_layer,
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,
@@ -192,6 +245,7 @@ def __init__(
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
+            depth=depth,
             **dd,
         )
         self.norm1 = norm_layer(dim, **dd)
@@ -246,7 +300,9 @@ def __init__(
             drop_path: float = 0.,
             act_layer: Type[nn.Module] = nn.GELU,
             norm_layer: Type[nn.Module] = LayerNorm,
-            mlp_layer: Optional[Type[nn.Module]] = None,
+            mlp_layer: Optional[Type[nn.Module]] = None,  # not used
+            attn_layer: Optional[LayerType] = None,  # not used
+            depth: int = 0,  # not used
             device = None,
             dtype = None,
     ) -> None:
@@ -351,8 +407,10 @@ def __init__(
             act_layer: Type[nn.Module] = nn.GELU,
             norm_layer: Type[nn.Module] = LayerNorm,
             mlp_layer: Type[nn.Module] = Mlp,
-            device = None,
-            dtype = None
+            attn_layer: LayerType = Attention,
+            depth: int = 0,
+            device=None,
+            dtype=None,
     ) -> None:
         dd = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -362,7 +420,8 @@ def __init__(
         for _ in range(num_parallel):
             self.attns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim, **dd)),
-                ('attn', Attention(
+                ('attn', _create_attn(
+                    attn_layer,
                     dim,
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,
@@ -372,6 +431,7 @@ def __init__(
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
                     norm_layer=norm_layer,
+                    depth=depth,
                     **dd,
                 )),
                 ('ls', LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()),
@@ -482,6 +542,7 @@ def __init__(
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
+            attn_layer: LayerType = Attention,
             device=None,
             dtype=None,
     ) -> None:
@@ -592,6 +653,8 @@ def __init__(
                 norm_layer=norm_layer,
                 act_layer=act_layer,
                 mlp_layer=mlp_layer,
+                attn_layer=attn_layer,
+                depth=i,
                 **dd,
             )
             for i in range(depth)])
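
A minimal usage sketch of the new layer and the attn_layer plumbing in vision_transformer.Block, assuming the patch above is applied; the tensor sizes, num_heads, and depth index are arbitrary example values, not part of the patch.

import torch

from timm.layers import DiffAttention
from timm.models.vision_transformer import Block

x = torch.randn(2, 197, 384)  # (batch, tokens, dim)

# Direct use: dim is split internally as num_heads * 2 * head_dim (here head_dim = 32).
attn = DiffAttention(dim=384, num_heads=6, depth=3)
print(attn(x).shape)  # expected: torch.Size([2, 197, 384])

# Via a ViT block, selecting the attention layer by its string key in ATTN_LAYERS.
blk = Block(dim=384, num_heads=6, attn_layer='diff', depth=3)
print(blk(x).shape)  # expected: torch.Size([2, 197, 384])

The string key path mirrors how NaFlexVitCfg.attn_layer is forwarded through get_block_fn, so the same 'diff' value should select DiffAttention in both model families.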