Coverage for watermark/robin/watermark_generator.py: 94.70%

151 statements  

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

from torch.utils.data import Dataset
import os
import numpy as np
from PIL import Image
from evaluation.dataset import BaseDataset
from typing import Optional, Union, List, Callable, Tuple
import torch
import math
from tqdm import tqdm
from torch.amp import GradScaler, autocast
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from accelerate import Accelerator
from transformers.models.clip.modeling_clip import CLIPTextModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers import DPMSolverMultistepScheduler
import logging
from utils.utils import set_random_seed
from utils.media_utils import *
import copy
from diffusers.utils import BaseOutput
import PIL
import time

logging.basicConfig(
    level=logging.INFO,  # set logger level
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',  # set logger format
    handlers=[
        logging.StreamHandler(),  # output to terminal
        # logging.FileHandler('logs/output.log', mode='a', encoding='utf-8')  # output to file
    ]
)

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# class OptimizedDataset(Dataset):
#     def __init__(
#         self,
#         data_root,
#         custom_dataset: BaseDataset,
#         size=512,
#         repeats=10,
#         interpolation="bicubic",
#         set="train",
#         center_crop=False,
#     ):

#         self.data_root = data_root
#         self.size = size
#         self.center_crop = center_crop

#         file_list = os.listdir(self.data_root)
#         file_list.sort(key=lambda x: int(x.split('-')[-1].split('.')[0]))  # ori-lg7.5-xx.jpg
#         self.image_paths = [os.path.join(self.data_root, file_path) for file_path in file_list]
#         self.dataset = custom_dataset

#         self.num_images = len(self.image_paths)
#         self._length = self.num_images

#         if set == "train":
#             self._length = self.num_images * repeats

#         self.interpolation = {
#             "bilinear": Image.BILINEAR,
#             "bicubic": Image.BICUBIC,
#             "lanczos": Image.LANCZOS,
#         }[interpolation]

#     def __len__(self):
#         return self._length

#     def __getitem__(self, i):
#         example = {}
#         image = Image.open(self.image_paths[i % self.num_images])

#         if not image.mode == "RGB":
#             image = image.convert("RGB")

#         text = self.dataset[i % self.num_images]  # __getitem__ of BaseDataset: return prompt[idx]
#         example["prompt"] = text

#         # default to score-sde preprocessing
#         img = np.array(image).astype(np.uint8)

#         if self.center_crop:
#             crop = min(img.shape[0], img.shape[1])
#             h, w, = (
#                 img.shape[0],
#                 img.shape[1],
#             )
#             img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

#         image = Image.fromarray(img)
#         image = image.resize((self.size, self.size), resample=self.interpolation)

#         example["pixel_values"] = pil_to_torch(image, normalize=False)  # scale to [0, 1]

#         return example

def circle_mask(size=64, r_max=10, r_min=0, x_offset=0, y_offset=0):
    # reference: https://stackoverflow.com/questions/69687798/generating-a-soft-circluar-mask-using-numpy-python-3
    x0 = y0 = size // 2
    x0 += x_offset
    y0 += y_offset
    y, x = np.ogrid[:size, :size]
    y = y[::-1]

    return (((x - x0) ** 2 + (y - y0) ** 2) <= r_max ** 2) & (((x - x0) ** 2 + (y - y0) ** 2) > r_min ** 2)
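
# Illustrative usage (example radii, not the configured defaults): a boolean ring
# mask over a 64x64 latent grid that keeps entries whose radius lies in (5, 30].
# ring = circle_mask(size=64, r_max=30, r_min=5)   # (64, 64) boolean np.ndarray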

def get_watermarking_mask(init_latents_w, args, device):
    watermarking_mask = torch.zeros(init_latents_w.shape, dtype=torch.bool).to(device)

    # Use dynamic size from input latents
    latent_size = init_latents_w.shape[-1]

    if args.w_mask_shape == 'circle':
        np_mask = circle_mask(latent_size, r_max=args.w_up_radius, r_min=args.w_low_radius)
        torch_mask = torch.tensor(np_mask).to(device)

        if args.w_channel == -1:
            # all channels
            watermarking_mask[:, :] = torch_mask
        else:
            watermarking_mask[:, args.w_channel] = torch_mask
    elif args.w_mask_shape == 'square':
        anchor_p = latent_size // 2
        if args.w_channel == -1:
            # all channels
            watermarking_mask[:, :, anchor_p - args.w_radius:anchor_p + args.w_radius, anchor_p - args.w_radius:anchor_p + args.w_radius] = True
        else:
            watermarking_mask[:, args.w_channel, anchor_p - args.w_radius:anchor_p + args.w_radius, anchor_p - args.w_radius:anchor_p + args.w_radius] = True
    elif args.w_mask_shape == 'no':
        pass
    else:
        raise NotImplementedError(f'w_mask_shape: {args.w_mask_shape}')

    return watermarking_mask
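
# Minimal sketch of building a mask (the argument object and its values are
# illustrative only; any object exposing these attributes works):
# from types import SimpleNamespace
# example_args = SimpleNamespace(w_mask_shape='circle', w_channel=3,
#                                w_up_radius=10, w_low_radius=3)
# init_latents = torch.randn(1, 4, 64, 64)
# mask = get_watermarking_mask(init_latents, example_args, 'cpu')  # bool tensor, same shape as latents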

# def get_watermarking_pattern(pipe, args, device, shape=None):
#     set_random_seed(args.w_seed)
#     # set_random_seed(10) # test weak high freq watermark
#     if shape is not None:
#         gt_init = torch.randn(*shape, device=device)  # .type(torch.complex32)
#     else:
#         gt_init = get_random_latents(pipe=pipe)

#     if 'seed_ring' in args.w_pattern:  # spacial
#         gt_patch = gt_init

#         gt_patch_tmp = copy.deepcopy(gt_patch)
#         for i in range(args.w_up_radius, args.w_low_radius, -1):
#             tmp_mask = circle_mask(gt_init.shape[-1], r_max=args.w_up_radius, r_min=args.w_low_radius)
#             tmp_mask = torch.tensor(tmp_mask).to(device)

#             for j in range(gt_patch.shape[1]):
#                 gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
#     elif 'seed_zeros' in args.w_pattern:
#         gt_patch = gt_init * 0
#     elif 'seed_rand' in args.w_pattern:
#         gt_patch = gt_init
#     elif 'rand' in args.w_pattern:
#         gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
#         gt_patch[:] = gt_patch[0]
#     elif 'zeros' in args.w_pattern:
#         gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
#     elif 'const' in args.w_pattern:
#         gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
#         gt_patch += args.w_pattern_const
#     elif 'ring' in args.w_pattern:
#         gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))

#         gt_patch_tmp = copy.deepcopy(gt_patch)
#         for i in range(args.w_up_radius, args.w_low_radius, -1):
#             tmp_mask = circle_mask(gt_init.shape[-1], r_max=i, r_min=args.w_low_radius)
#             tmp_mask = torch.tensor(tmp_mask).to(device)

#             for j in range(gt_patch.shape[1]):
#                 gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()

#     return gt_patch

def inject_watermark(init_latents_w, watermarking_mask, gt_patch, args):
    init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(init_latents_w), dim=(-1, -2))
    gt_patch = gt_patch.to(init_latents_w_fft.dtype)
    if args.w_injection == 'complex':
        init_latents_w_fft[watermarking_mask] = gt_patch[watermarking_mask].clone()  # complexhalf = complexfloat
    elif args.w_injection == 'seed':
        init_latents_w[watermarking_mask] = gt_patch[watermarking_mask].clone()
        return init_latents_w
    else:
        raise NotImplementedError(f'w_injection: {args.w_injection}')

    init_latents_w = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real

    return init_latents_w
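
# Sketch of the FFT-domain injection (illustrative shapes, reusing the example mask
# above): with w_injection='complex', gt_patch values replace the shifted FFT of the
# latents wherever the mask is True, and the real part of the inverse FFT is returned.
# example_args = SimpleNamespace(w_injection='complex')  # SimpleNamespace as in the sketch above
# gt_patch = torch.fft.fftshift(torch.fft.fft2(torch.randn(1, 4, 64, 64)), dim=(-1, -2))
# latents_wm = inject_watermark(torch.randn(1, 4, 64, 64), mask, gt_patch, example_args)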

# def freeze_params(params):
#     for param in params:
#         param.requires_grad = False

# def to_ring(latent_fft, args):
#     # Calculate mean value for each ring
#     num_rings = args.w_up_radius - args.w_low_radius
#     r_max = args.w_up_radius
#     for i in range(num_rings):
#         # ring_mask = mask[..., (radii[i * 2] <= distances) & (distances < radii[i * 2 + 1])]
#         ring_mask = circle_mask(latent_fft.shape[-1], r_max=r_max, r_min=r_max - 1)
#         ring_mean = latent_fft[:, args.w_channel, ring_mask].real.mean().item()
#         # print(f'ring mean: {ring_mean}')
#         latent_fft[:, args.w_channel, ring_mask] = ring_mean
#         r_max = r_max - 1

#     return latent_fft

# def optimizer_wm_prompt(pipe: StableDiffusionPipeline,
#                         dataloader: OptimizedDataset,
#                         hyperparameters: dict,
#                         mask: torch.Tensor,
#                         opt_wm: torch.Tensor,
#                         save_path: str,
#                         args: dict,
#                         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
#                         eta: float = 0.0,) -> tuple[torch.Tensor, torch.Tensor]:
#     train_batch_size = hyperparameters["train_batch_size"]
#     gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
#     learning_rate = hyperparameters["learning_rate"]
#     max_train_steps = hyperparameters["max_train_steps"]
#     output_dir = hyperparameters["output_dir"]
#     gradient_checkpointing = hyperparameters["gradient_checkpointing"]
#     original_guidance_scale = hyperparameters["guidance_scale"]
#     optimized_guidance_scale = hyperparameters["optimized_guidance_scale"]

#     # Check if checkpoint exists
#     checkpoint_path = os.path.join(save_path, f"optimized_wm5-30_embedding-step-{max_train_steps}.pt")
#     # checkpoint_path = "/workspace/panleyi/gs/ROBIN/ckpts/optimized_wm5-30_embedding-step-2000.pt"
#     if os.path.exists(checkpoint_path):
#         logger.info(f"Loading checkpoint from {checkpoint_path}")
#         checkpoint = torch.load(checkpoint_path)
#         opt_wm = checkpoint['opt_wm'].to(pipe.device)
#         opt_wm_embedding = checkpoint['opt_acond'].to(pipe.device)
#         return opt_wm, opt_wm_embedding

#     text_encoder: CLIPTextModel = pipe.text_encoder
#     unet: UNet2DConditionModel = pipe.unet
#     vae: AutoencoderKL = pipe.vae
#     scheduler: DPMSolverMultistepScheduler = pipe.scheduler

#     freeze_params(vae.parameters())
#     freeze_params(unet.parameters())
#     freeze_params(text_encoder.parameters())

#     accelerator = Accelerator(
#         gradient_accumulation_steps=gradient_accumulation_steps,
#         mixed_precision=hyperparameters["mixed_precision"]
#     )

#     if gradient_checkpointing:
#         text_encoder.gradient_checkpointing_enable()
#         unet.enable_gradient_checkpointing()

#     if hyperparameters["scale_lr"]:
#         learning_rate = (
#             learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
#         )

#     tester_prompt = ''  # assume at the detection time, the original prompt is unknown
#     # null text, text_embedding.dtype = torch.float16
#     do_classifier_free_guidance = False  # guidance_scale = 1.0
#     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
#         prompt=tester_prompt,
#         device=pipe.device,
#         do_classifier_free_guidance=do_classifier_free_guidance,
#         num_images_per_prompt=1,
#     )

#     text_embeddings = prompt_embeds

#     extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)

#     unet, text_encoder, dataloader, text_embeddings = accelerator.prepare(
#         unet, text_encoder, dataloader, text_embeddings
#     )

#     weight_dtype = torch.float32
#     if accelerator.mixed_precision == "fp16":
#         weight_dtype = torch.float16
#     elif accelerator.mixed_precision == "bf16":
#         weight_dtype = torch.bfloat16

#     # Move vae and unet to device
#     vae.to(accelerator.device, dtype=weight_dtype)
#     unet.to(accelerator.device, dtype=weight_dtype)

#     # Keep vae in eval mode as we don't train it
#     vae.eval()
#     # Keep unet in train mode to enable gradient checkpointing
#     unet.train()

#     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
#     num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
#     num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

#     # Train!
#     total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

#     logger.info("***** Running training *****")
#     logger.info(f"  Num examples = {len(dataloader)}")
#     logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
#     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
#     logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
#     logger.info(f"  Total optimization steps = {max_train_steps}")
#     # Only show the progress bar once on each machine.
#     progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
#     progress_bar.set_description("Steps")
#     global_step = 0

#     scaler = GradScaler(device=accelerator.device)
#     # pipe.scheduler.set_timesteps(1000) # need for compute the next state

#     do_classifier_free_guidance = False  # guidance_scale = 1.0
#     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
#         prompt='',
#         device=pipe.device,
#         do_classifier_free_guidance=do_classifier_free_guidance,
#         num_images_per_prompt=1,
#     )

#     opt_wm_embedding = prompt_embeds
#     null_embedding = opt_wm_embedding.clone()
#     total_time = 0
#     with autocast(device_type=accelerator.device.type):
#         for epoch in range(num_train_epochs):
#             for step, batch in enumerate(dataloader):
#                 with accelerator.accumulate(unet):
#                     # Convert images to latent space
#                     gt_tensor = batch["pixel_values"]
#                     image = 2.0 * gt_tensor - 1.0
#                     latents = vae.encode(image.to(dtype=weight_dtype)).latent_dist.sample().detach()
#                     latents = latents * 0.18215
#                     # Sample noise that we'll add to the latents
#                     noise = torch.randn_like(latents)
#                     bsz = latents.shape[0]
#                     # Sample a random timestep for each image
#                     ori_timesteps = torch.randint(200, 300, (bsz,), device=latents.device).long()  # 35~40 steps
#                     timesteps = len(scheduler) - 1 - ori_timesteps

#                     # Add noise to the latents according to the noise magnitude at each timestep
#                     noisy_latents = scheduler.add_noise(latents, noise, timesteps)
#                     opt_wm = opt_wm.to(noisy_latents.device).to(torch.complex64)  # add wm to latents

#                     ### detailed the inject_watermark function for fft.grad
#                     init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(noisy_latents), dim=(-1, -2))
#                     init_latents_w_fft[mask] = opt_wm[mask].clone()
#                     init_latents_w_fft.requires_grad = True
#                     noisy_latents = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real
#                     ### Get the text embedding for conditioning CFG
#                     prompt = batch["prompt"]
#                     do_classifier_free_guidance = False  # guidance_scale = 1.0
#                     prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
#                         prompt=prompt,
#                         device=pipe.device,
#                         do_classifier_free_guidance=do_classifier_free_guidance,
#                         num_images_per_prompt=1,
#                     )

#                     cond_embedding = prompt_embeds
#                     text_embeddings = torch.cat([opt_wm_embedding, cond_embedding, null_embedding])
#                     text_embeddings.requires_grad = True

#                     ### Predict the noise residual with CFG
#                     latent_model_input = torch.cat([noisy_latents] * 3)
#                     latent_model_input = scheduler.scale_model_input(latent_model_input, timesteps)
#                     noise_pred = unet(latent_model_input, ori_timesteps, encoder_hidden_states=text_embeddings).sample
#                     noise_pred_wm, noise_pred_text, noise_pred_null = noise_pred.chunk(3)
#                     noise_pred = noise_pred_null + original_guidance_scale * (noise_pred_text - noise_pred_null) + optimized_guidance_scale * (noise_pred_wm - noise_pred_null)  # different guidance scale

#                     ### get the predicted x0 tensor
#                     scheduler._init_step_index(timesteps)
#                     x0_latents = scheduler.convert_model_output(model_output=noise_pred, sample=noisy_latents)  # predict x0 in one step
#                     x0_tensor = decode_media_latents(pipe=pipe, latents=x0_latents)

#                     loss_noise = F.mse_loss(x0_tensor.float(), gt_tensor.float(), reduction="mean")  # pixel alignment
#                     loss_wm = torch.mean(torch.abs(opt_wm[mask].real))
#                     loss_constrain = F.mse_loss(noise_pred_wm.float(), noise_pred_null.float(), reduction="mean")  # prompt constraint

#                     ### optimize wm pattern and uncond prompt alternately
#                     if (global_step // 500) % 2 == 0:
#                         loss = 10 * loss_noise + loss_constrain - 0.00001 * loss_wm  # opt wm pattern
#                         accelerator.backward(loss)
#                         with torch.no_grad():
#                             grads = init_latents_w_fft.grad
#                             init_latents_w_fft = init_latents_w_fft - 1.0 * grads  # update wm pattern
#                             init_latents_w_fft = to_ring(init_latents_w_fft, args)
#                             opt_wm = init_latents_w_fft.detach()
#                     else:
#                         loss = 10 * loss_noise + loss_constrain  # opt prompt
#                         accelerator.backward(loss)
#                         with torch.no_grad():
#                             grads = text_embeddings.grad
#                             text_embeddings = text_embeddings - 5e-04 * grads
#                             opt_wm_embedding = text_embeddings[0].unsqueeze(0).detach()  # update acond embedding

#                     print(f'global_step: {global_step}, loss_mse: {loss_noise}, loss_wm: {loss_wm}, loss_cons: {loss_constrain}, loss: {loss}')

#                 # Checks if the accelerator has performed an optimization step behind the scenes
#                 if accelerator.sync_gradients:
#                     progress_bar.update(1)
#                     global_step += 1
#                     if global_step % hyperparameters["save_steps"] == 0:
#                         path = os.path.join(save_path, f"optimized_wm5-30_embedding-step-{global_step}.pt")
#                         torch.save({'opt_acond': opt_wm_embedding, 'opt_wm': opt_wm.cpu()}, path)

#                 logs = {"loss": loss.detach().item()}
#                 progress_bar.set_postfix(**logs)

#                 if global_step >= max_train_steps:
#                     break

#     accelerator.wait_for_everyone()

#     return opt_wm, opt_wm_embedding

class ROBINStableDiffusionPipelineOutput(BaseOutput):
    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
    init_latents: Optional[torch.FloatTensor]
    latents: Optional[torch.FloatTensor]
    inner_latents: Optional[List[torch.FloatTensor]]
    # Extra fields populated by ROBINWatermarkedImageGeneration below
    gt_patch: Optional[torch.FloatTensor]
    opt_acond: Optional[torch.FloatTensor]
    time: Optional[float]

@torch.no_grad()
def ROBINWatermarkedImageGeneration(
    pipe: StableDiffusionPipeline,
    prompt: Union[str, List[str]],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 3.5,
    optimized_guidance_scale: float = 3.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    watermarking_mask: Optional[torch.BoolTensor] = None,
    watermarking_step: Optional[int] = None,
    args=None,
    gt_patch=None,
    opt_acond=None
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 3.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        optimized_guidance_scale (`float`, *optional*, defaults to 3.5):
            Additional guidance scale used once the watermark has been injected and the optimized embedding
            `opt_acond` is added as a third conditioning branch; see `xn1_latents_3` for how the two scales
            are combined.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
            if `guidance_scale` is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.
    """

    # print('got new version')
    inner_latents = []
    # 0. Default height and width to unet
    height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
    width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    pipe.check_inputs(prompt, height, width, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = pipe._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    # Use encode_prompt instead of _encode_prompt for compatibility with newer diffusers versions
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Concatenate for classifier free guidance
    if do_classifier_free_guidance:
        text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
    else:
        text_embeddings = prompt_embeds

    # 4. Prepare timesteps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipe.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = pipe.unet.in_channels
    latents = pipe.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )

    init_latents = copy.deepcopy(latents)

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)

    inner_latents.append(init_latents)

    # 7. Denoising loop
    max_train_steps = 1  # 100
    latents_wm = None
    text_embeddings_opt = None
    num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order

    start_time = time.time()
    with pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if (watermarking_step is not None) and (i >= watermarking_step):
                mask = watermarking_mask  # mask from outside
                if i == watermarking_step:
                    latents_wm = inject_watermark(latents, mask, gt_patch, args)  # inject latent watermark
                    inner_latents[-1] = latents_wm
                    if opt_acond is not None:
                        uncond, cond = text_embeddings.chunk(2)
                        opt_acond = opt_acond.to(cond.dtype)
                        text_embeddings_opt = torch.cat([uncond, opt_acond, cond])  # opt as another cond
                    else:
                        text_embeddings_opt = text_embeddings.clone()
                # if lguidance is not None:
                #     guidance_scale = lguidance

                latents_wm, _ = xn1_latents_3(pipe, latents_wm, do_classifier_free_guidance, t,
                                              text_embeddings_opt, guidance_scale, optimized_guidance_scale, **extra_step_kwargs)

            if (watermarking_step is None) or (watermarking_step is not None and i < watermarking_step):
                latents, _ = xn1_latents(pipe, latents, do_classifier_free_guidance, t,
                                         text_embeddings, guidance_scale, **extra_step_kwargs)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

            if (watermarking_step is not None and i < watermarking_step) or (watermarking_step is None):
                inner_latents.append(latents)  # save for memory
            else:
                inner_latents.append(latents_wm)

    if watermarking_step is not None and watermarking_step == 50:
        latents_wm = inject_watermark(latents, watermarking_mask, gt_patch, args)  # inject latent watermark
        inner_latents[-1] = latents_wm

    end_time = time.time()
    execution_time = end_time - start_time
    # 8. Post-processing
    if latents_wm is not None:
        # Convert latents to the same dtype as VAE
        latents_wm = latents_wm.to(dtype=pipe.vae.dtype)
        image = pipe.decode_latents(latents_wm)
    else:
        # Convert latents to the same dtype as VAE
        latents = latents.to(dtype=pipe.vae.dtype)
        image = pipe.decode_latents(latents)

    # 9. Run safety checker
    image, has_nsfw_concept = pipe.run_safety_checker(image, device, text_embeddings.dtype)

    # 10. Convert to PIL
    if output_type == "pil":
        image = pipe.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)
    if text_embeddings_opt is not None:
        return ROBINStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, init_latents=init_latents, latents=latents, inner_latents=inner_latents, gt_patch=gt_patch, opt_acond=text_embeddings_opt[0], time=execution_time)
    else:
        return ROBINStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept, init_latents=init_latents, latents=latents, inner_latents=inner_latents, gt_patch=gt_patch, time=execution_time)
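
# Illustrative end-to-end call (all values are examples; `pipe`, `args`, `gt_patch`,
# and `opt_acond` are assumed to be prepared elsewhere, e.g. a loaded
# StableDiffusionPipeline, the watermark config, and the optimized pattern/embedding):
# mask = get_watermarking_mask(get_random_latents(pipe=pipe), args, pipe.device)
# out = ROBINWatermarkedImageGeneration(
#     pipe, prompt="a photo of a cat", num_inference_steps=50,
#     guidance_scale=3.5, optimized_guidance_scale=3.5,
#     watermarking_mask=mask, watermarking_step=35,
#     args=args, gt_patch=gt_patch, opt_acond=opt_acond,
# )
# images = out["images"]  # watermarked PIL images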

def xn1_latents_3(pipe, latents, do_classifier_free_guidance, t,
                  text_embeddings, original_guidance_scale, optimized_guidance_scale, **extra_step_kwargs):
    latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
    latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
    noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text1, noise_pred_text2 = noise_pred.chunk(3)
        noise_pred = noise_pred_uncond + original_guidance_scale * (noise_pred_text1 - noise_pred_uncond) + optimized_guidance_scale * (noise_pred_text2 - noise_pred_uncond)
    latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

    return latents, noise_pred
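
# For reference, xn1_latents_3 combines the three chunked noise predictions as
#     eps = eps_uncond + s1 * (eps_1 - eps_uncond) + s2 * (eps_2 - eps_uncond)
# i.e. classifier-free guidance with two conditional branches (the ordinary text
# condition and the optimized watermark condition); with one scale set to 0 it
# reduces to the two-branch form in xn1_latents below.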

def xn1_latents(pipe, latents, do_classifier_free_guidance, t,
                text_embeddings, guidance_scale, **extra_step_kwargs):
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
    noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
    return latents, noise_pred  # Make sure to return both values