
Commit f12ac6a

Initial pass of an 'init_non_persistent_buffers' scheme. WIP... needs more testing and probably missed a few things.
1 parent 4e651da commit f12ac6a

16 files changed: +514 -104 lines

timm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -150,6 +150,7 @@
 from .trace_utils import _assert, _float_to_int
 from .typing import LayerType, PadType, disable_compiler
 from .weight_init import (
+    is_meta_device,
     trunc_normal_,
     trunc_normal_tf_,
     variance_scaling_,
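
The newly exported `is_meta_device` hints at the motivation for the scheme: buffers registered with `persistent=False` are excluded from a module's `state_dict`, so a model built on the meta device and later materialized from a checkpoint has no way to restore them and must recompute them. A minimal illustration of that PyTorch behavior (the module below is illustrative, not from this commit):

import torch
import torch.nn as nn

class WithCache(nn.Module):
    """Toy module with a derived, non-persistent buffer."""
    def __init__(self):
        super().__init__()
        self.register_buffer('cache', torch.ones(3), persistent=False)

# Non-persistent buffers are not serialized, so they cannot be restored
# via load_state_dict and must be recomputed after materialization.
assert 'cache' not in WithCache().state_dict()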

timm/layers/blur_pool.py

Lines changed: 19 additions & 9 deletions
@@ -46,22 +46,24 @@ def __init__(
         self.stride = stride
         self.pad_mode = pad_mode
         self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+        self.register_buffer(
+            'filt',
+            self._create_blur_filter(device=device, dtype=dtype),
+            persistent=False,
+        )

+    def _create_blur_filter(self, device=None, dtype=None) -> torch.Tensor:
+        """Create the blur filter tensor."""
         # (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
         coeffs = torch.tensor(
-            [comb(filt_size - 1, k) for k in range(filt_size)],
+            [comb(self.filt_size - 1, k) for k in range(self.filt_size)],
             device='cpu',
             dtype=torch.float32,
-        ) / (2 ** (filt_size - 1))  # normalise so coefficients sum to 1
+        ) / (2 ** (self.filt_size - 1))  # normalise so coefficients sum to 1
         blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
-        if channels is not None:
+        if self.channels is not None:
             blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
-
-        self.register_buffer(
-            'filt',
-            blur_filter.to(device=device, dtype=dtype),
-            persistent=False,
-        )
+        return blur_filter.to(device=device, dtype=dtype)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.pad(x, self.padding, mode=self.pad_mode)
@@ -73,6 +75,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         weight = self.filt
         return F.conv2d(x, weight, stride=self.stride, groups=channels)

+    def init_non_persistent_buffers(
+            self,
+            device: torch.device,
+            dtype: torch.dtype,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        self.register_buffer('filt', self._create_blur_filter(device=device, dtype=dtype), persistent=False)
+

 def _normalize_aa_layer(aa_layer: LayerType) -> Callable[..., nn.Module]:
     """Map string shorthands to callables (class or partial)."""

timm/layers/lambda_layer.py

Lines changed: 13 additions & 0 deletions
@@ -156,3 +156,16 @@ def forward(self, x):
         out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W)  # B, C (num_heads * V), H, W
         out = self.pool(out)
         return out
+
+    def init_non_persistent_buffers(
+            self,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        if self.rel_pos_indices is None:
+            return
+        device = device or self.qkv.weight.device
+        # Compute feat_size from pos_emb shape: rel_size = 2 * feat_size - 1
+        feat_size = ((self.pos_emb.shape[0] + 1) // 2, (self.pos_emb.shape[1] + 1) // 2)
+        self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size, device=device), persistent=False)
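
The `feat_size` recovery above inverts the relation stated in the comment: if `rel_size = 2 * feat_size - 1`, then `feat_size = (rel_size + 1) // 2`. A quick sanity check:

# Recovering feat_size from the relative embedding size is exact for any
# positive feat_size, since 2 * feat - 1 is always odd.
for feat in (1, 7, 14):
    rel = 2 * feat - 1
    assert (rel + 1) // 2 == feat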

timm/layers/pos_embed_rel.py

Lines changed: 46 additions & 0 deletions
@@ -311,6 +311,20 @@ def get_bias(self) -> torch.Tensor:
     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
         return attn + self.get_bias()

+    def init_non_persistent_buffers(
+            self,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        device = device or self.relative_position_bias_table.device
+        prefix_tokens = 1 if self.bias_shape[0] > self.window_area else 0
+        self.register_buffer(
+            'relative_position_index',
+            gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0, device=device).view(-1),
+            persistent=False,
+        )
+

 def gen_relative_log_coords(
         win_size: Tuple[int, int],
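
The `prefix_tokens` inference in this hunk relies on how `bias_shape` is constructed: when a class token is present, the bias table covers `window_area + 1` positions per attention axis, so `bias_shape[0]` exceeds `window_area` exactly when prefix tokens exist. Illustrative values (not from the diff):

window_area = 49            # e.g. a 7x7 attention window
bias_shape = (50, 50, 8)    # window_area + 1 per axis => one class/prefix token
prefix_tokens = 1 if bias_shape[0] > window_area else 0
assert prefix_tokens == 1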
@@ -367,6 +381,8 @@ def __init__(
         self.prefix_tokens = prefix_tokens
         self.num_heads = num_heads
         self.bias_shape = (self.window_area,) * 2 + (num_heads,)
+        self.mode = mode
+        self.pretrained_window_size = pretrained_window_size
         if mode == 'swin':
             self.bias_act = nn.Sigmoid()
             self.bias_gain = 16
@@ -415,6 +431,25 @@ def get_bias(self) -> torch.Tensor:
     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
         return attn + self.get_bias()

+    def init_non_persistent_buffers(
+            self,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        device = device or self.mlp.fc1.weight.device
+        dtype = dtype or self.mlp.fc1.weight.dtype
+        self.register_buffer(
+            'relative_position_index',
+            gen_relative_position_index(self.window_size, device=device).view(-1),
+            persistent=False,
+        )
+        self.register_buffer(
+            'rel_coords_log',
+            gen_relative_log_coords(self.window_size, self.pretrained_window_size, mode=self.mode, device=device, dtype=dtype),
+            persistent=False,
+        )
+

 def generate_lookup_tensor(
         length: int,
@@ -519,3 +554,14 @@ def get_bias(self) -> torch.Tensor:

     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
         return attn + self.get_bias()
+
+    def init_non_persistent_buffers(
+            self,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        device = device or self.relative_position_bias_table.device
+        dtype = dtype or self.relative_position_bias_table.dtype
+        self.register_buffer('height_lookup', generate_lookup_tensor(self.window_size[0], device=device, dtype=dtype), persistent=False)
+        self.register_buffer('width_lookup', generate_lookup_tensor(self.window_size[1], device=device, dtype=dtype), persistent=False)
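
Two patterns recur across these methods: each falls back to an existing parameter for device and dtype (e.g. `self.relative_position_bias_table.device`), and the second hunk stashes `mode` and `pretrained_window_size` on the module precisely so the log-coordinate buffer can be re-derived later. The driver that actually calls `init_non_persistent_buffers` is not among the files shown here (16 changed, 4 excerpted); a plausible sketch, with a hypothetical helper name:

import torch
import torch.nn as nn

def reinit_non_persistent_buffers(
        model: nn.Module,
        device: torch.device,
        dtype: torch.dtype,
) -> None:
    """Hypothetical walk: rebuild non-persistent buffers after materializing
    a meta-initialized model (helper name is an assumption, not from the commit)."""
    for module in model.modules():
        fn = getattr(module, 'init_non_persistent_buffers', None)
        if callable(fn):
            # All implementations in this commit accept device/dtype kwargs.
            fn(device=device, dtype=dtype)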
