1 change: 1 addition & 0 deletions timm/layers/__init__.py
@@ -150,6 +150,7 @@
from .trace_utils import _assert, _float_to_int
from .typing import LayerType, PadType, disable_compiler
from .weight_init import (
is_meta_device,
trunc_normal_,
trunc_normal_tf_,
variance_scaling_,
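For context, is_meta_device is newly exported from weight_init; the name suggests a guard for tensors or modules that live on PyTorch's meta device (shape and dtype only, no storage), where in-place weight init is impossible. A hypothetical stand-in in plain PyTorch, shown only to illustrate the concept and not the timm implementation:

import torch
import torch.nn as nn

def params_on_meta(module: nn.Module) -> bool:
    # True if any parameter or buffer has no real storage (meta device),
    # in which case init code should skip in-place writes.
    tensors = list(module.parameters()) + list(module.buffers())
    return any(t.device.type == 'meta' for t in tensors)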
28 changes: 19 additions & 9 deletions timm/layers/blur_pool.py
@@ -46,22 +46,24 @@ def __init__(
self.stride = stride
self.pad_mode = pad_mode
self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
self.register_buffer(
'filt',
self._create_blur_filter(device=device, dtype=dtype),
persistent=False,
)

def _create_blur_filter(self, device=None, dtype=None) -> torch.Tensor:
"""Create the blur filter tensor."""
# (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
coeffs = torch.tensor(
[comb(filt_size - 1, k) for k in range(filt_size)],
[comb(self.filt_size - 1, k) for k in range(self.filt_size)],
device='cpu',
dtype=torch.float32,
) / (2 ** (filt_size - 1)) # normalise so coefficients sum to 1
) / (2 ** (self.filt_size - 1)) # normalise so coefficients sum to 1
blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
if channels is not None:
if self.channels is not None:
blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)

self.register_buffer(
'filt',
blur_filter.to(device=device, dtype=dtype),
persistent=False,
)
return blur_filter.to(device=device, dtype=dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding, mode=self.pad_mode)
@@ -73,6 +75,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
weight = self.filt
return F.conv2d(x, weight, stride=self.stride, groups=channels)

def init_non_persistent_buffers(
self,
device: torch.device,
dtype: torch.dtype,
) -> None:
"""Initialize non-persistent buffers."""
self.register_buffer('filt', self._create_blur_filter(device=device, dtype=dtype), persistent=False)


def _normalize_aa_layer(aa_layer: LayerType) -> Callable[..., nn.Module]:
"""Map string shorthands to callables (class or partial)."""
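A usage sketch for the new hook, assuming the usual meta-device workflow: construct the module without storage, materialize it with to_empty(), then rebuild non-persistent buffers such as filt, which load_state_dict cannot restore because they are excluded from checkpoints. The call sequence below is an assumption about intended use, not code from this PR:

import torch
from timm.layers import BlurPool2d

with torch.device('meta'):
    pool = BlurPool2d(channels=64, filt_size=3, stride=2)  # no real storage yet

pool = pool.to_empty(device='cpu')  # allocate uninitialized storage
# Non-persistent buffers are absent from the state_dict, so recreate them explicitly.
pool.init_non_persistent_buffers(device=torch.device('cpu'), dtype=torch.float32)

out = pool(torch.randn(1, 64, 32, 32))  # filt is now a valid blur kernel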
13 changes: 13 additions & 0 deletions timm/layers/lambda_layer.py
@@ -156,3 +156,16 @@ def forward(self, x):
out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W
out = self.pool(out)
return out

def init_non_persistent_buffers(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
"""Initialize non-persistent buffers."""
if self.rel_pos_indices is None:
return
device = device or self.qkv.weight.device
# Compute feat_size from pos_emb shape: rel_size = 2 * feat_size - 1
feat_size = ((self.pos_emb.shape[0] + 1) // 2, (self.pos_emb.shape[1] + 1) // 2)
self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size, device=device), persistent=False)
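A quick sanity check of the feat_size recovery used above (plain arithmetic, independent of timm): the relative-position embedding spans 2*s - 1 entries per spatial axis, so (rel_size + 1) // 2 round-trips back to the original feature size.

feat_size = (7, 7)
rel_size = tuple(2 * s - 1 for s in feat_size)  # pos_emb is (13, 13, dim)
assert tuple((r + 1) // 2 for r in rel_size) == feat_size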
46 changes: 46 additions & 0 deletions timm/layers/pos_embed_rel.py
@@ -311,6 +311,20 @@ def get_bias(self) -> torch.Tensor:
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()

def init_non_persistent_buffers(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
"""Initialize non-persistent buffers."""
device = device or self.relative_position_bias_table.device
prefix_tokens = 1 if self.bias_shape[0] > self.window_area else 0
self.register_buffer(
'relative_position_index',
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0, device=device).view(-1),
persistent=False,
)


def gen_relative_log_coords(
win_size: Tuple[int, int],
@@ -367,6 +381,8 @@ def __init__(
self.prefix_tokens = prefix_tokens
self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
self.mode = mode
self.pretrained_window_size = pretrained_window_size
if mode == 'swin':
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
@@ -415,6 +431,25 @@ def get_bias(self) -> torch.Tensor:
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()

def init_non_persistent_buffers(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
"""Initialize non-persistent buffers."""
device = device or self.mlp.fc1.weight.device
dtype = dtype or self.mlp.fc1.weight.dtype
self.register_buffer(
'relative_position_index',
gen_relative_position_index(self.window_size, device=device).view(-1),
persistent=False,
)
self.register_buffer(
'rel_coords_log',
gen_relative_log_coords(self.window_size, self.pretrained_window_size, mode=self.mode, device=device, dtype=dtype),
persistent=False,
)


def generate_lookup_tensor(
length: int,
@@ -519,3 +554,14 @@ def get_bias(self) -> torch.Tensor:

def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()

def init_non_persistent_buffers(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
"""Initialize non-persistent buffers."""
device = device or self.relative_position_bias_table.device
dtype = dtype or self.relative_position_bias_table.dtype
self.register_buffer('height_lookup', generate_lookup_tensor(self.window_size[0], device=device, dtype=dtype), persistent=False)
self.register_buffer('width_lookup', generate_lookup_tensor(self.window_size[1], device=device, dtype=dtype), persistent=False)
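Taken together, the per-module hooks above invite a single pass over a materialized model; a hypothetical helper (not part of this diff) that rebuilds every non-persistent buffer after to_empty() might look like:

import torch
import torch.nn as nn

def reinit_non_persistent_buffers(
        model: nn.Module,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
) -> None:
    # Walk all submodules and rebuild buffers that load_state_dict cannot restore.
    for module in model.modules():
        hook = getattr(module, 'init_non_persistent_buffers', None)
        if callable(hook):
            hook(device=device, dtype=dtype)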