Coverage for watermark/robin/watermark_generator.py: 94.70% (151 statements)
from torch.utils.data import Dataset
import os
import numpy as np
from PIL import Image
from evaluation.dataset import BaseDataset
from typing import Optional, Union, List, Callable, Tuple
import torch
import math
from tqdm import tqdm
from torch.amp import GradScaler, autocast
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from accelerate import Accelerator
from transformers.models.clip.modeling_clip import CLIPTextModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers import DPMSolverMultistepScheduler
import logging
from utils.utils import set_random_seed
from utils.media_utils import *
import copy
from diffusers.utils import BaseOutput
import PIL
import time
logging.basicConfig(
    level=logging.INFO,  # set logger level
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',  # set logger format
    handlers=[
        logging.StreamHandler(),  # output to terminal
        # logging.FileHandler('logs/output.log', mode='a', encoding='utf-8')  # output to file
    ]
)

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
# class OptimizedDataset(Dataset):
#     def __init__(
#         self,
#         data_root,
#         custom_dataset: BaseDataset,
#         size=512,
#         repeats=10,
#         interpolation="bicubic",
#         set="train",
#         center_crop=False,
#     ):
#         self.data_root = data_root
#         self.size = size
#         self.center_crop = center_crop
#
#         file_list = os.listdir(self.data_root)
#         file_list.sort(key=lambda x: int(x.split('-')[-1].split('.')[0]))  # ori-lg7.5-xx.jpg
#         self.image_paths = [os.path.join(self.data_root, file_path) for file_path in file_list]
#         self.dataset = custom_dataset
#
#         self.num_images = len(self.image_paths)
#         self._length = self.num_images
#
#         if set == "train":
#             self._length = self.num_images * repeats
#
#         self.interpolation = {
#             "bilinear": Image.BILINEAR,
#             "bicubic": Image.BICUBIC,
#             "lanczos": Image.LANCZOS,
#         }[interpolation]
#
#     def __len__(self):
#         return self._length
#
#     def __getitem__(self, i):
#         example = {}
#         image = Image.open(self.image_paths[i % self.num_images])
#
#         if not image.mode == "RGB":
#             image = image.convert("RGB")
#
#         text = self.dataset[i % self.num_images]  # __getitem__ of BaseDataset: return prompt[idx]
#         example["prompt"] = text
#
#         # default to score-sde preprocessing
#         img = np.array(image).astype(np.uint8)
#
#         if self.center_crop:
#             crop = min(img.shape[0], img.shape[1])
#             h, w, = (
#                 img.shape[0],
#                 img.shape[1],
#             )
#             img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
#
#         image = Image.fromarray(img)
#         image = image.resize((self.size, self.size), resample=self.interpolation)
#
#         example["pixel_values"] = pil_to_torch(image, normalize=False)  # scale to [0, 1]
#
#         return example
def circle_mask(size=64, r_max=10, r_min=0, x_offset=0, y_offset=0):
    # reference: https://stackoverflow.com/questions/69687798/generating-a-soft-circluar-mask-using-numpy-python-3
    x0 = y0 = size // 2
    x0 += x_offset
    y0 += y_offset
    y, x = np.ogrid[:size, :size]
    y = y[::-1]

    return (((x - x0) ** 2 + (y - y0) ** 2) <= r_max ** 2) & (((x - x0) ** 2 + (y - y0) ** 2) > r_min ** 2)
def get_watermarking_mask(init_latents_w, args, device):
    watermarking_mask = torch.zeros(init_latents_w.shape, dtype=torch.bool).to(device)

    # Use dynamic size from input latents
    latent_size = init_latents_w.shape[-1]

    if args.w_mask_shape == 'circle':
        np_mask = circle_mask(latent_size, r_max=args.w_up_radius, r_min=args.w_low_radius)
        torch_mask = torch.tensor(np_mask).to(device)

        if args.w_channel == -1:
            # all channels
            watermarking_mask[:, :] = torch_mask
        else:
            watermarking_mask[:, args.w_channel] = torch_mask
    elif args.w_mask_shape == 'square':
        anchor_p = latent_size // 2
        if args.w_channel == -1:
            # all channels
            watermarking_mask[:, :, anchor_p - args.w_radius:anchor_p + args.w_radius, anchor_p - args.w_radius:anchor_p + args.w_radius] = True
        else:
            watermarking_mask[:, args.w_channel, anchor_p - args.w_radius:anchor_p + args.w_radius, anchor_p - args.w_radius:anchor_p + args.w_radius] = True
    elif args.w_mask_shape == 'no':
        pass
    else:
        raise NotImplementedError(f'w_mask_shape: {args.w_mask_shape}')

    return watermarking_mask
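

# Hedged usage sketch (illustrative, not part of the original module): building a boolean mask
# over the latent coefficients of a single channel. The `args` namespace is assumed to carry the
# same fields the function above reads (w_mask_shape, w_channel, w_up_radius, w_low_radius,
# w_radius); the concrete values here are example assumptions.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(w_mask_shape='circle', w_channel=3, w_up_radius=30, w_low_radius=5, w_radius=10)
#   init_latents = torch.randn(1, 4, 64, 64)
#   mask = get_watermarking_mask(init_latents, args, device='cpu')  # bool tensor, same shape as the latents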
# def get_watermarking_pattern(pipe, args, device, shape=None):
#     set_random_seed(args.w_seed)
#     # set_random_seed(10)  # test weak high freq watermark
#     if shape is not None:
#         gt_init = torch.randn(*shape, device=device)  # .type(torch.complex32)
#     else:
#         gt_init = get_random_latents(pipe=pipe)
#
#     if 'seed_ring' in args.w_pattern:  # spatial
#         gt_patch = gt_init
#
#         gt_patch_tmp = copy.deepcopy(gt_patch)
#         for i in range(args.w_up_radius, args.w_low_radius, -1):
#             tmp_mask = circle_mask(gt_init.shape[-1], r_max=args.w_up_radius, r_min=args.w_low_radius)
#             tmp_mask = torch.tensor(tmp_mask).to(device)
#
#             for j in range(gt_patch.shape[1]):
#                 gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
#     elif 'seed_zeros' in args.w_pattern:
#         gt_patch = gt_init * 0
#     elif 'seed_rand' in args.w_pattern:
#         gt_patch = gt_init
#     elif 'rand' in args.w_pattern:
#         gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
#         gt_patch[:] = gt_patch[0]
#     elif 'zeros' in args.w_pattern:
#         gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
#     elif 'const' in args.w_pattern:
#         gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
#         gt_patch += args.w_pattern_const
#     elif 'ring' in args.w_pattern:
#         gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
#
#         gt_patch_tmp = copy.deepcopy(gt_patch)
#         for i in range(args.w_up_radius, args.w_low_radius, -1):
#             tmp_mask = circle_mask(gt_init.shape[-1], r_max=i, r_min=args.w_low_radius)
#             tmp_mask = torch.tensor(tmp_mask).to(device)
#
#             for j in range(gt_patch.shape[1]):
#                 gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
#
#     return gt_patch
def inject_watermark(init_latents_w, watermarking_mask, gt_patch, args):
    init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(init_latents_w), dim=(-1, -2))
    gt_patch = gt_patch.to(init_latents_w_fft.dtype)
    if args.w_injection == 'complex':
        init_latents_w_fft[watermarking_mask] = gt_patch[watermarking_mask].clone()  # complexhalf = complexfloat
    elif args.w_injection == 'seed':
        # spatial-domain injection: overwrite the masked latents directly and return
        init_latents_w[watermarking_mask] = gt_patch[watermarking_mask].clone()
        return init_latents_w
    else:
        raise NotImplementedError(f'w_injection: {args.w_injection}')

    init_latents_w = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real

    return init_latents_w
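

# Hedged usage sketch (illustrative, not part of the original module): combining a mask and a
# Fourier-domain pattern before denoising, mirroring how ROBINWatermarkedImageGeneration below
# calls inject_watermark. Here `gt_patch` is just the shifted FFT of random latents as a
# stand-in for a real optimized pattern, and `args` is assumed to be the namespace from the
# sketch above.
#
#   init_latents = torch.randn(1, 4, 64, 64)
#   mask = get_watermarking_mask(init_latents, args, device='cpu')
#   gt_patch = torch.fft.fftshift(torch.fft.fft2(torch.randn_like(init_latents)), dim=(-1, -2))
#   args.w_injection = 'complex'
#   latents_w = inject_watermark(init_latents, mask, gt_patch, args)  # real-valued latents carrying the mark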
# def freeze_params(params):
#     for param in params:
#         param.requires_grad = False


# def to_ring(latent_fft, args):
#     # Calculate mean value for each ring
#     num_rings = args.w_up_radius - args.w_low_radius
#     r_max = args.w_up_radius
#     for i in range(num_rings):
#         # ring_mask = mask[..., (radii[i * 2] <= distances) & (distances < radii[i * 2 + 1])]
#         ring_mask = circle_mask(latent_fft.shape[-1], r_max=r_max, r_min=r_max - 1)
#         ring_mean = latent_fft[:, args.w_channel, ring_mask].real.mean().item()
#         # print(f'ring mean: {ring_mean}')
#         latent_fft[:, args.w_channel, ring_mask] = ring_mean
#         r_max = r_max - 1
#
#     return latent_fft
# def optimizer_wm_prompt(pipe: StableDiffusionPipeline,
#                         dataloader: OptimizedDataset,
#                         hyperparameters: dict,
#                         mask: torch.Tensor,
#                         opt_wm: torch.Tensor,
#                         save_path: str,
#                         args: dict,
#                         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
#                         eta: float = 0.0,) -> tuple[torch.Tensor, torch.Tensor]:
#     train_batch_size = hyperparameters["train_batch_size"]
#     gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
#     learning_rate = hyperparameters["learning_rate"]
#     max_train_steps = hyperparameters["max_train_steps"]
#     output_dir = hyperparameters["output_dir"]
#     gradient_checkpointing = hyperparameters["gradient_checkpointing"]
#     original_guidance_scale = hyperparameters["guidance_scale"]
#     optimized_guidance_scale = hyperparameters["optimized_guidance_scale"]
#
#     # Check if checkpoint exists
#     checkpoint_path = os.path.join(save_path, f"optimized_wm5-30_embedding-step-{max_train_steps}.pt")
#     # checkpoint_path = "/workspace/panleyi/gs/ROBIN/ckpts/optimized_wm5-30_embedding-step-2000.pt"
#     if os.path.exists(checkpoint_path):
#         logger.info(f"Loading checkpoint from {checkpoint_path}")
#         checkpoint = torch.load(checkpoint_path)
#         opt_wm = checkpoint['opt_wm'].to(pipe.device)
#         opt_wm_embedding = checkpoint['opt_acond'].to(pipe.device)
#         return opt_wm, opt_wm_embedding
#
#     text_encoder: CLIPTextModel = pipe.text_encoder
#     unet: UNet2DConditionModel = pipe.unet
#     vae: AutoencoderKL = pipe.vae
#     scheduler: DPMSolverMultistepScheduler = pipe.scheduler
#
#     freeze_params(vae.parameters())
#     freeze_params(unet.parameters())
#     freeze_params(text_encoder.parameters())
#
#     accelerator = Accelerator(
#         gradient_accumulation_steps=gradient_accumulation_steps,
#         mixed_precision=hyperparameters["mixed_precision"]
#     )
#
#     if gradient_checkpointing:
#         text_encoder.gradient_checkpointing_enable()
#         unet.enable_gradient_checkpointing()
#
#     if hyperparameters["scale_lr"]:
#         learning_rate = (
#             learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
#         )
#
#     tester_prompt = ''  # assume at the detection time, the original prompt is unknown
#     # null text, text_embedding.dtype = torch.float16
#     do_classifier_free_guidance = False  # guidance_scale = 1.0
#     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
#         prompt=tester_prompt,
#         device=pipe.device,
#         do_classifier_free_guidance=do_classifier_free_guidance,
#         num_images_per_prompt=1,
#     )
#
#     text_embeddings = prompt_embeds
#
#     extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
#
#     unet, text_encoder, dataloader, text_embeddings = accelerator.prepare(
#         unet, text_encoder, dataloader, text_embeddings
#     )
#
#     weight_dtype = torch.float32
#     if accelerator.mixed_precision == "fp16":
#         weight_dtype = torch.float16
#     elif accelerator.mixed_precision == "bf16":
#         weight_dtype = torch.bfloat16
#
#     # Move vae and unet to device
#     vae.to(accelerator.device, dtype=weight_dtype)
#     unet.to(accelerator.device, dtype=weight_dtype)
#
#     # Keep vae in eval mode as we don't train it
#     vae.eval()
#     # Keep unet in train mode to enable gradient checkpointing
#     unet.train()
#
#     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
#     num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
#     num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
#
#     # Train!
#     total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
#
#     logger.info("***** Running training *****")
#     logger.info(f" Num examples = {len(dataloader)}")
#     logger.info(f" Instantaneous batch size per device = {train_batch_size}")
#     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
#     logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
#     logger.info(f" Total optimization steps = {max_train_steps}")
#     # Only show the progress bar once on each machine.
#     progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
#     progress_bar.set_description("Steps")
#     global_step = 0
#
#     scaler = GradScaler(device=accelerator.device)
#     # pipe.scheduler.set_timesteps(1000)  # need for compute the next state
#
#     do_classifier_free_guidance = False  # guidance_scale = 1.0
#     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
#         prompt='',
#         device=pipe.device,
#         do_classifier_free_guidance=do_classifier_free_guidance,
#         num_images_per_prompt=1,
#     )
#
#     opt_wm_embedding = prompt_embeds
#     null_embedding = opt_wm_embedding.clone()
#     total_time = 0
#     with autocast(device_type=accelerator.device.type):
#         for epoch in range(num_train_epochs):
#             for step, batch in enumerate(dataloader):
#                 with accelerator.accumulate(unet):
#                     # Convert images to latent space
#                     gt_tensor = batch["pixel_values"]
#                     image = 2.0 * gt_tensor - 1.0
#                     latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample().detach()
#                     latents = latents * 0.18215
#                     # Sample noise that we'll add to the latents
#                     noise = torch.randn_like(latents)
#                     bsz = latents.shape[0]
#                     # Sample a random timestep for each image
#                     ori_timesteps = torch.randint(200, 300, (bsz,), device=latents.device).long()  # 35~40 steps
#                     timesteps = len(scheduler) - 1 - ori_timesteps
#
#                     # Add noise to the latents according to the noise magnitude at each timestep
#                     noisy_latents = scheduler.add_noise(latents, noise, timesteps)
#                     opt_wm = opt_wm.to(noisy_latents.device).to(torch.complex64)  # add wm to latents
#
#                     ### detailed the inject_watermark function for fft.grad
#                     init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(noisy_latents), dim=(-1, -2))
#                     init_latents_w_fft[mask] = opt_wm[mask].clone()
#                     init_latents_w_fft.requires_grad = True
#                     noisy_latents = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real
#                     ### Get the text embedding for conditioning CFG
#                     prompt = batch["prompt"]
#                     do_classifier_free_guidance = False  # guidance_scale = 1.0
#                     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
#                         prompt=prompt,
#                         device=pipe.device,
#                         do_classifier_free_guidance=do_classifier_free_guidance,
#                         num_images_per_prompt=1,
#                     )
#
#                     cond_embedding = prompt_embeds
#                     text_embeddings = torch.cat([opt_wm_embedding, cond_embedding, null_embedding])
#                     text_embeddings.requires_grad = True
#
#                     ### Predict the noise residual with CFG
#                     latent_model_input = torch.cat([noisy_latents] * 3)
#                     latent_model_input = scheduler.scale_model_input(latent_model_input, timesteps)
#                     noise_pred = unet(latent_model_input, ori_timesteps, encoder_hidden_states=text_embeddings).sample
#                     noise_pred_wm, noise_pred_text, noise_pred_null = noise_pred.chunk(3)
#                     noise_pred = noise_pred_null + original_guidance_scale * (noise_pred_text - noise_pred_null) + optimized_guidance_scale * (noise_pred_wm - noise_pred_null)  # different guidance scale
#
#                     ### get the predicted x0 tensor
#                     scheduler._init_step_index(timesteps)
#                     x0_latents = scheduler.convert_model_output(model_output=noise_pred, sample=noisy_latents)  # predict x0 in one-step
#                     x0_tensor = decode_media_latents(pipe=pipe, latents=x0_latents)
#
#                     loss_noise = F.mse_loss(x0_tensor.float(), gt_tensor.float(), reduction="mean")  # pixel alignment
#                     loss_wm = torch.mean(torch.abs(opt_wm[mask].real))
#                     loss_constrain = F.mse_loss(noise_pred_wm.float(), noise_pred_null.float(), reduction="mean")  # prompt constraint
#
#                     ### optimize wm pattern and uncond prompt alternately
#                     if (global_step // 500) % 2 == 0:
#                         loss = 10 * loss_noise + loss_constrain - 0.00001 * loss_wm  # opt wm pattern
#                         accelerator.backward(loss)
#                         with torch.no_grad():
#                             grads = init_latents_w_fft.grad
#                             init_latents_w_fft = init_latents_w_fft - 1.0 * grads  # update wm pattern
#                             init_latents_w_fft = to_ring(init_latents_w_fft, args)
#                             opt_wm = init_latents_w_fft.detach()
#                     else:
#                         loss = 10 * loss_noise + loss_constrain  # opt prompt
#                         accelerator.backward(loss)
#                         with torch.no_grad():
#                             grads = text_embeddings.grad
#                             text_embeddings = text_embeddings - 5e-04 * grads
#                             opt_wm_embedding = text_embeddings[0].unsqueeze(0).detach()  # update acond embedding
#
#                     print(f'global_step: {global_step}, loss_mse: {loss_noise}, loss_wm: {loss_wm}, loss_cons: {loss_constrain}, loss: {loss}')
#
#                 # Checks if the accelerator has performed an optimization step behind the scenes
#                 if accelerator.sync_gradients:
#                     progress_bar.update(1)
#                     global_step += 1
#                     if global_step % hyperparameters["save_steps"] == 0:
#                         path = os.path.join(save_path, f"optimized_wm5-30_embedding-step-{global_step}.pt")
#                         torch.save({'opt_acond': opt_wm_embedding, 'opt_wm': opt_wm.cpu()}, path)
#
#                 logs = {"loss": loss.detach().item()}
#                 progress_bar.set_postfix(**logs)
#
#                 if global_step >= max_train_steps:
#                     break
#
#     accelerator.wait_for_everyone()
#
#     return opt_wm, opt_wm_embedding
class ROBINStableDiffusionPipelineOutput(BaseOutput):
    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
    init_latents: Optional[torch.FloatTensor]
    latents: Optional[torch.FloatTensor]
    inner_latents: Optional[List[torch.FloatTensor]]
    # additional outputs populated by ROBINWatermarkedImageGeneration below
    gt_patch: Optional[torch.Tensor]
    opt_acond: Optional[torch.Tensor]
    time: Optional[float]
@torch.no_grad()
def ROBINWatermarkedImageGeneration(
    pipe: StableDiffusionPipeline,
    prompt: Union[str, List[str]],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 3.5,
    optimized_guidance_scale: float = 3.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    watermarking_mask: Optional[torch.BoolTensor] = None,
    watermarking_step: Optional[int] = None,
    args=None,
    gt_patch=None,
    opt_acond=None
):
462 r"""
463 Function invoked when calling the pipeline for generation.
465 Args:
466 prompt (`str` or `List[str]`):
467 The prompt or prompts to guide the image generation.
468 height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
469 The height in pixels of the generated image.
470 width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
471 The width in pixels of the generated image.
472 num_inference_steps (`int`, *optional*, defaults to 50):
473 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
474 expense of slower inference.
475 original_guidance_scale (`float`, *optional*, defaults to 3.5):
476 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
477 `original_guidance_scale` is defined as `w` of equation 2. of [Imagen
478 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `original_guidance_scale >
479 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
480 usually at the expense of lower image quality.
481 optimized_guidance_scale (`float`, *optional*, defaults to 3.5):
482 TODO: add description
483 negative_prompt (`str` or `List[str]`, *optional*):
484 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
485 if `original_guidance_scale` is less than `1`).
486 num_images_per_prompt (`int`, *optional*, defaults to 1):
487 The number of images to generate per prompt.
488 eta (`float`, *optional*, defaults to 0.0):
489 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
490 [`schedulers.DDIMScheduler`], will be ignored for others.
491 generator (`torch.Generator`, *optional*):
492 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
493 to make generation deterministic.
494 latents (`torch.FloatTensor`, *optional*):
495 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
496 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
497 tensor will ge generated by sampling using the supplied random `generator`.
498 output_type (`str`, *optional*, defaults to `"pil"`):
499 The output format of the generate image. Choose between
500 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
501 return_dict (`bool`, *optional*, defaults to `True`):
502 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
503 plain tuple.
504 callback (`Callable`, *optional*):
505 A function that will be called every `callback_steps` steps during inference. The function will be
506 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
507 callback_steps (`int`, *optional*, defaults to 1):
508 The frequency at which the `callback` function will be called. If not specified, the callback will be
509 called at every step.
511 Returns:
512 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
513 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
514 When returning a tuple, the first element is a list with the generated images, and the second element is a
515 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
516 (nsfw) content, according to the `safety_checker`.
517 """
    # print('got new version')
    inner_latents = []
    # 0. Default height and width to unet
    height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
    width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    pipe.check_inputs(prompt, height, width, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = pipe._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    # Use encode_prompt instead of _encode_prompt for compatibility with newer diffusers versions
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Concatenate for classifier free guidance
    if do_classifier_free_guidance:
        text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
    else:
        text_embeddings = prompt_embeds

    # 4. Prepare timesteps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = pipe.unet.config.in_channels
    latents = pipe.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )

    init_latents = copy.deepcopy(latents)

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)

    inner_latents.append(init_latents)

    # 7. Denoising loop
    max_train_steps = 1  # 100
    latents_wm = None
    text_embeddings_opt = None
    num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order

    start_time = time.time()
    with pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if (watermarking_step is not None) and (i >= watermarking_step):
                mask = watermarking_mask  # mask from outside
                if i == watermarking_step:
                    latents_wm = inject_watermark(latents, mask, gt_patch, args)  # inject latent watermark
                    inner_latents[-1] = latents_wm
                    if opt_acond is not None:
                        uncond, cond = text_embeddings.chunk(2)
                        opt_acond = opt_acond.to(cond.dtype)
                        text_embeddings_opt = torch.cat([uncond, opt_acond, cond])  # opt as another cond
                    else:
                        text_embeddings_opt = text_embeddings.clone()
                # if lguidance is not None:
                #     guidance_scale = lguidance

                latents_wm, _ = xn1_latents_3(
                    pipe, latents_wm, do_classifier_free_guidance, t,
                    text_embeddings_opt, guidance_scale, optimized_guidance_scale, **extra_step_kwargs
                )

            if (watermarking_step is None) or (watermarking_step is not None and i < watermarking_step):
                latents, _ = xn1_latents(
                    pipe, latents, do_classifier_free_guidance, t,
                    text_embeddings, guidance_scale, **extra_step_kwargs
                )

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

            if (watermarking_step is not None and i < watermarking_step) or (watermarking_step is None):
                inner_latents.append(latents)  # save for memory
            else:
                inner_latents.append(latents_wm)

    # watermarking_step == num_inference_steps (default 50): inject only after the final denoising step
    if watermarking_step is not None and watermarking_step == 50:
        latents_wm = inject_watermark(latents, watermarking_mask, gt_patch, args)  # inject latent watermark
        inner_latents[-1] = latents_wm

    end_time = time.time()
    execution_time = end_time - start_time
    # 8. Post-processing
    if latents_wm is not None:
        # Convert latents to the same dtype as VAE
        latents_wm = latents_wm.to(dtype=pipe.vae.dtype)
        image = pipe.decode_latents(latents_wm)
    else:
        # Convert latents to the same dtype as VAE
        latents = latents.to(dtype=pipe.vae.dtype)
        image = pipe.decode_latents(latents)

    # 9. Run safety checker
    image, has_nsfw_concept = pipe.run_safety_checker(image, device, text_embeddings.dtype)

    # 10. Convert to PIL
    if output_type == "pil":
        image = pipe.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)
    if text_embeddings_opt is not None:
        return ROBINStableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept, init_latents=init_latents,
            latents=latents, inner_latents=inner_latents, gt_patch=gt_patch,
            opt_acond=text_embeddings_opt[0], time=execution_time,
        )
    else:
        return ROBINStableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept, init_latents=init_latents,
            latents=latents, inner_latents=inner_latents, gt_patch=gt_patch, time=execution_time,
        )
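

# Hedged usage sketch (illustrative; the model id and argument values are assumptions, not from
# this file): calling the watermarked generation with an externally prepared mask, pattern, and
# optimized condition embedding.
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
#   mask = get_watermarking_mask(init_latents, args, device="cuda")
#   out = ROBINWatermarkedImageGeneration(
#       pipe, prompt="a photo of a cat", num_inference_steps=50,
#       watermarking_mask=mask, watermarking_step=35,
#       gt_patch=gt_patch, opt_acond=opt_acond, args=args,
#   )
#   out.images[0].save("watermarked.png")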
def xn1_latents_3(pipe, latents, do_classifier_free_guidance, t,
                  text_embeddings, original_guidance_scale, optimized_guidance_scale, **extra_step_kwargs):
    latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
    latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
    noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    if do_classifier_free_guidance:
        # three-branch classifier-free guidance: one unconditional and two conditional predictions,
        # each weighted by its own guidance scale
        noise_pred_uncond, noise_pred_text1, noise_pred_text2 = noise_pred.chunk(3)
        noise_pred = noise_pred_uncond + original_guidance_scale * (noise_pred_text1 - noise_pred_uncond) + optimized_guidance_scale * (noise_pred_text2 - noise_pred_uncond)
    latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

    return latents, noise_pred
def xn1_latents(pipe, latents, do_classifier_free_guidance, t,
                text_embeddings, guidance_scale, **extra_step_kwargs):
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
    noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
    return latents, noise_pred  # Make sure to return both values