Commit df4f3ff

Initial pass of an 'init_non_persistent_buffers' scheme, WIP.. needs more testing and probably missed a few things
1 parent: 4e651da

16 files changed: +512, -104 lines

timm/layers/__init__.py (1 addition, 0 deletions)

@@ -150,6 +150,7 @@
 from .trace_utils import _assert, _float_to_int
 from .typing import LayerType, PadType, disable_compiler
 from .weight_init import (
+    is_meta_device,
     trunc_normal_,
     trunc_normal_tf_,
     variance_scaling_,
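The newly exported is_meta_device comes from .weight_init; its body is not shown in this diff. A minimal sketch of the assumed behavior, a check for tensors that have shapes but no allocated storage (an illustration, not necessarily the commit's actual implementation):

import torch
import torch.nn as nn


def is_meta_device(module: nn.Module) -> bool:
    # Assumed behavior: True if any parameter or buffer is on the 'meta'
    # device, i.e. shapes exist but no storage has been allocated yet.
    return any(
        t.device.type == 'meta'
        for t in list(module.parameters()) + list(module.buffers())
    )


# Modules built under a meta-device context allocate no storage.
with torch.device('meta'):
    layer = nn.Linear(8, 8)
assert is_meta_device(layer)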

timm/layers/blur_pool.py (19 additions, 9 deletions)

@@ -46,22 +46,24 @@ def __init__(
         self.stride = stride
         self.pad_mode = pad_mode
         self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+        self.register_buffer(
+            'filt',
+            self._create_blur_filter(device=device, dtype=dtype),
+            persistent=False,
+        )

+    def _create_blur_filter(self, device=None, dtype=None) -> torch.Tensor:
+        """Create the blur filter tensor."""
         # (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
         coeffs = torch.tensor(
-            [comb(filt_size - 1, k) for k in range(filt_size)],
+            [comb(self.filt_size - 1, k) for k in range(self.filt_size)],
             device='cpu',
             dtype=torch.float32,
-        ) / (2 ** (filt_size - 1))  # normalise so coefficients sum to 1
+        ) / (2 ** (self.filt_size - 1))  # normalise so coefficients sum to 1
         blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
-        if channels is not None:
+        if self.channels is not None:
             blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
-
-        self.register_buffer(
-            'filt',
-            blur_filter.to(device=device, dtype=dtype),
-            persistent=False,
-        )
+        return blur_filter.to(device=device, dtype=dtype)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.pad(x, self.padding, mode=self.pad_mode)
@@ -73,6 +75,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         weight = self.filt
         return F.conv2d(x, weight, stride=self.stride, groups=channels)

+    def init_non_persistent_buffers(
+            self,
+            device: torch.device,
+            dtype: torch.dtype,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        self.filt.data = self._create_blur_filter(device=device, dtype=dtype)
+

 def _normalize_aa_layer(aa_layer: LayerType) -> Callable[..., nn.Module]:
     """Map string shorthands to callables (class or partial)."""

timm/layers/lambda_layer.py (13 additions, 0 deletions)

@@ -156,3 +156,16 @@ def forward(self, x):
         out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W)  # B, C (num_heads * V), H, W
         out = self.pool(out)
         return out
+
+    def init_non_persistent_buffers(
+            self,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        if self.rel_pos_indices is None:
+            return
+        device = device or self.qkv.weight.device
+        # Compute feat_size from pos_emb shape: rel_size = 2 * feat_size - 1
+        feat_size = ((self.pos_emb.shape[0] + 1) // 2, (self.pos_emb.shape[1] + 1) // 2)
+        self.rel_pos_indices.data = rel_pos_indices(feat_size, device=device)
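The feat_size recovery works because the relative-position embedding has side length rel_size = 2 * feat_size - 1 along each axis, which inverts exactly for any integer feat_size; e.g. a 15x15 pos_emb implies an 8x8 feature grid:

# rel_size = 2 * feat_size - 1  =>  feat_size = (rel_size + 1) // 2
for feat_size in (7, 8, 14):
    rel_size = 2 * feat_size - 1
    assert (rel_size + 1) // 2 == feat_size  # round-trips exactly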

timm/layers/pos_embed_rel.py (44 additions, 0 deletions)

@@ -311,6 +311,18 @@ def get_bias(self) -> torch.Tensor:
     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
         return attn + self.get_bias()

+    def init_non_persistent_buffers(
+            self,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        device = device or self.relative_position_bias_table.device
+        prefix_tokens = 1 if self.bias_shape[0] > self.window_area else 0
+        self.relative_position_index.data = gen_relative_position_index(
+            self.window_size, class_token=prefix_tokens > 0, device=device
+        ).view(-1)
+

 def gen_relative_log_coords(
         win_size: Tuple[int, int],
@@ -367,6 +379,8 @@ def __init__(
         self.prefix_tokens = prefix_tokens
         self.num_heads = num_heads
         self.bias_shape = (self.window_area,) * 2 + (num_heads,)
+        self.mode = mode
+        self.pretrained_window_size = pretrained_window_size
         if mode == 'swin':
             self.bias_act = nn.Sigmoid()
             self.bias_gain = 16
@@ -415,6 +429,25 @@ def get_bias(self) -> torch.Tensor:
     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
         return attn + self.get_bias()

+    def init_non_persistent_buffers(
+            self,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        device = device or self.mlp.fc1.weight.device
+        dtype = dtype or self.mlp.fc1.weight.dtype
+        self.relative_position_index.data = gen_relative_position_index(
+            self.window_size, device=device
+        ).view(-1)
+        self.rel_coords_log.data = gen_relative_log_coords(
+            self.window_size,
+            self.pretrained_window_size,
+            mode=self.mode,
+            device=device,
+            dtype=dtype,
+        )
+

 def generate_lookup_tensor(
         length: int,
@@ -519,3 +552,14 @@ def get_bias(self) -> torch.Tensor:

     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
         return attn + self.get_bias()
+
+    def init_non_persistent_buffers(
+            self,
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """Initialize non-persistent buffers."""
+        device = device or self.relative_position_bias_table.device
+        dtype = dtype or self.relative_position_bias_table.dtype
+        self.height_lookup.data = generate_lookup_tensor(self.window_size[0], device=device, dtype=dtype)
+        self.width_lookup.data = generate_lookup_tensor(self.window_size[1], device=device, dtype=dtype)
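What this diff does not yet include is the driver that applies the hook across a whole model, consistent with the 'WIP' commit message. One plausible top-level pattern, assuming the per-layer methods above are the full interface; the traversal helper here is hypothetical, not part of the commit:

import torch
import torch.nn as nn


def init_non_persistent_buffers(
        model: nn.Module,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
) -> None:
    # Visit every submodule; any layer defining the hook recomputes its
    # non-persistent buffers (blur filters, rel-pos indices, lookup tensors).
    for module in model.modules():
        fn = getattr(module, 'init_non_persistent_buffers', None)
        if callable(fn):
            fn(device=device, dtype=dtype)


# Assumed usage after a meta-device build:
#   with torch.device('meta'):
#       model = create_model(...)
#   model.to_empty(device='cuda')
#   model.load_state_dict(state_dict)                         # persistent tensors
#   init_non_persistent_buffers(model, torch.device('cuda'))  # everything else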
