@@ -984,9 +984,6 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None,
984984 self .inner_model , self .conds , self .loaded_models = comfy .sampler_helpers .prepare_sampling (self .model_patcher , noise .shape , self .conds , self .model_options )
985985 device = self .model_patcher .load_device
986986
987- if denoise_mask is not None :
988- denoise_mask = comfy .sampler_helpers .prepare_mask (denoise_mask , noise .shape , device )
989-
990987 noise = noise .to (device )
991988 latent_image = latent_image .to (device )
992989 sigmas = sigmas .to (device )
@@ -1013,6 +1010,24 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba
10131010 else :
10141011 latent_shapes = [latent_image .shape ]
10151012
1013+ if denoise_mask is not None :
1014+ if denoise_mask .is_nested :
1015+ denoise_masks = denoise_mask .unbind ()
1016+ denoise_masks = denoise_masks [:len (latent_shapes )]
1017+ else :
1018+ denoise_masks = [denoise_mask ]
1019+
1020+ for i in range (len (denoise_masks ), len (latent_shapes )):
1021+ denoise_masks .append (torch .ones (latent_shapes [i ]))
1022+
1023+ for i in range (len (denoise_masks )):
1024+ denoise_masks [i ] = comfy .sampler_helpers .prepare_mask (denoise_masks [i ], latent_shapes [i ], self .model_patcher .load_device )
1025+
1026+ if len (denoise_masks ) > 1 :
1027+ denoise_mask , _ = comfy .utils .pack_latents (denoise_masks )
1028+ else :
1029+ denoise_mask = denoise_masks [0 ]
1030+
10161031 self .conds = {}
10171032 for k in self .original_conds :
10181033 self .conds [k ] = list (map (lambda a : a .copy (), self .original_conds [k ]))
0 commit comments