Skip to content

Commit 809ce68

Browse files
Support nested tensor denoise masks. (#11431)
1 parent cc4ddba commit 809ce68

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

comfy/samplers.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -984,9 +984,6 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None,
984984
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
985985
device = self.model_patcher.load_device
986986

987-
if denoise_mask is not None:
988-
denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)
989-
990987
noise = noise.to(device)
991988
latent_image = latent_image.to(device)
992989
sigmas = sigmas.to(device)
@@ -1013,6 +1010,24 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba
10131010
else:
10141011
latent_shapes = [latent_image.shape]
10151012

1013+
if denoise_mask is not None:
1014+
if denoise_mask.is_nested:
1015+
denoise_masks = denoise_mask.unbind()
1016+
denoise_masks = denoise_masks[:len(latent_shapes)]
1017+
else:
1018+
denoise_masks = [denoise_mask]
1019+
1020+
for i in range(len(denoise_masks), len(latent_shapes)):
1021+
denoise_masks.append(torch.ones(latent_shapes[i]))
1022+
1023+
for i in range(len(denoise_masks)):
1024+
denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device)
1025+
1026+
if len(denoise_masks) > 1:
1027+
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
1028+
else:
1029+
denoise_mask = denoise_masks[0]
1030+
10161031
self.conds = {}
10171032
for k in self.original_conds:
10181033
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

0 commit comments

Comments
 (0)