@@ -311,6 +311,18 @@ def get_bias(self) -> torch.Tensor:
311311 def forward (self , attn , shared_rel_pos : Optional [torch .Tensor ] = None ):
312312 return attn + self .get_bias ()
313313
314+ def init_non_persistent_buffers (
315+ self ,
316+ device : Optional [torch .device ] = None ,
317+ dtype : Optional [torch .dtype ] = None ,
318+ ) -> None :
319+ """Initialize non-persistent buffers."""
320+ device = device or self .relative_position_bias_table .device
321+ prefix_tokens = 1 if self .bias_shape [0 ] > self .window_area else 0
322+ self .relative_position_index .data = gen_relative_position_index (
323+ self .window_size , class_token = prefix_tokens > 0 , device = device
324+ ).view (- 1 )
325+
314326
315327def gen_relative_log_coords (
316328 win_size : Tuple [int , int ],
@@ -367,6 +379,8 @@ def __init__(
367379 self .prefix_tokens = prefix_tokens
368380 self .num_heads = num_heads
369381 self .bias_shape = (self .window_area ,) * 2 + (num_heads ,)
382+ self .mode = mode
383+ self .pretrained_window_size = pretrained_window_size
370384 if mode == 'swin' :
371385 self .bias_act = nn .Sigmoid ()
372386 self .bias_gain = 16
@@ -415,6 +429,25 @@ def get_bias(self) -> torch.Tensor:
415429 def forward (self , attn , shared_rel_pos : Optional [torch .Tensor ] = None ):
416430 return attn + self .get_bias ()
417431
432+ def init_non_persistent_buffers (
433+ self ,
434+ device : Optional [torch .device ] = None ,
435+ dtype : Optional [torch .dtype ] = None ,
436+ ) -> None :
437+ """Initialize non-persistent buffers."""
438+ device = device or self .mlp .fc1 .weight .device
439+ dtype = dtype or self .mlp .fc1 .weight .dtype
440+ self .relative_position_index .data = gen_relative_position_index (
441+ self .window_size , device = device
442+ ).view (- 1 )
443+ self .rel_coords_log .data = gen_relative_log_coords (
444+ self .window_size ,
445+ self .pretrained_window_size ,
446+ mode = self .mode ,
447+ device = device ,
448+ dtype = dtype ,
449+ )
450+
418451
419452def generate_lookup_tensor (
420453 length : int ,
@@ -519,3 +552,14 @@ def get_bias(self) -> torch.Tensor:
519552
520553 def forward (self , attn , shared_rel_pos : Optional [torch .Tensor ] = None ):
521554 return attn + self .get_bias ()
555+
556+ def init_non_persistent_buffers (
557+ self ,
558+ device : Optional [torch .device ] = None ,
559+ dtype : Optional [torch .dtype ] = None ,
560+ ) -> None :
561+ """Initialize non-persistent buffers."""
562+ device = device or self .relative_position_bias_table .device
563+ dtype = dtype or self .relative_position_bias_table .dtype
564+ self .height_lookup .data = generate_lookup_tensor (self .window_size [0 ], device = device , dtype = dtype )
565+ self .width_lookup .data = generate_lookup_tensor (self .window_size [1 ], device = device , dtype = dtype )
0 commit comments