@@ -311,6 +311,20 @@ def get_bias(self) -> torch.Tensor:
311311 def forward (self , attn , shared_rel_pos : Optional [torch .Tensor ] = None ):
312312 return attn + self .get_bias ()
313313
314+ def init_non_persistent_buffers (
315+ self ,
316+ device : Optional [torch .device ] = None ,
317+ dtype : Optional [torch .dtype ] = None ,
318+ ) -> None :
319+ """Initialize non-persistent buffers."""
320+ device = device or self .relative_position_bias_table .device
321+ prefix_tokens = 1 if self .bias_shape [0 ] > self .window_area else 0
322+ self .register_buffer (
323+ 'relative_position_index' ,
324+ gen_relative_position_index (self .window_size , class_token = prefix_tokens > 0 , device = device ).view (- 1 ),
325+ persistent = False ,
326+ )
327+
314328
315329def gen_relative_log_coords (
316330 win_size : Tuple [int , int ],
@@ -367,6 +381,8 @@ def __init__(
367381 self .prefix_tokens = prefix_tokens
368382 self .num_heads = num_heads
369383 self .bias_shape = (self .window_area ,) * 2 + (num_heads ,)
384+ self .mode = mode
385+ self .pretrained_window_size = pretrained_window_size
370386 if mode == 'swin' :
371387 self .bias_act = nn .Sigmoid ()
372388 self .bias_gain = 16
@@ -415,6 +431,25 @@ def get_bias(self) -> torch.Tensor:
415431 def forward (self , attn , shared_rel_pos : Optional [torch .Tensor ] = None ):
416432 return attn + self .get_bias ()
417433
434+ def init_non_persistent_buffers (
435+ self ,
436+ device : Optional [torch .device ] = None ,
437+ dtype : Optional [torch .dtype ] = None ,
438+ ) -> None :
439+ """Initialize non-persistent buffers."""
440+ device = device or self .mlp .fc1 .weight .device
441+ dtype = dtype or self .mlp .fc1 .weight .dtype
442+ self .register_buffer (
443+ 'relative_position_index' ,
444+ gen_relative_position_index (self .window_size , device = device ).view (- 1 ),
445+ persistent = False ,
446+ )
447+ self .register_buffer (
448+ 'rel_coords_log' ,
449+ gen_relative_log_coords (self .window_size , self .pretrained_window_size , mode = self .mode , device = device , dtype = dtype ),
450+ persistent = False ,
451+ )
452+
418453
419454def generate_lookup_tensor (
420455 length : int ,
@@ -519,3 +554,14 @@ def get_bias(self) -> torch.Tensor:
519554
520555 def forward (self , attn , shared_rel_pos : Optional [torch .Tensor ] = None ):
521556 return attn + self .get_bias ()
557+
558+ def init_non_persistent_buffers (
559+ self ,
560+ device : Optional [torch .device ] = None ,
561+ dtype : Optional [torch .dtype ] = None ,
562+ ) -> None :
563+ """Initialize non-persistent buffers."""
564+ device = device or self .relative_position_bias_table .device
565+ dtype = dtype or self .relative_position_bias_table .dtype
566+ self .register_buffer ('height_lookup' , generate_lookup_tensor (self .window_size [0 ], device = device , dtype = dtype ), persistent = False )
567+ self .register_buffer ('width_lookup' , generate_lookup_tensor (self .window_size [1 ], device = device , dtype = dtype ), persistent = False )
0 commit comments