# SEED-X-17B: src/models/detokenizer/adapter_modules.py
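"""Adapter modules that condition an SDXL UNet on visual embeddings.

SDXLAdapter feeds resampled image embeddings into SDXL in place of text
prompts; SDXLAdapterWithLatentImage additionally concatenates a source-image
latent to the UNet input for image editing.
"""
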
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import StableDiffusionXLPipeline
from PIL import Image

from .pipeline_stable_diffusion_xl_t2i_edit import StableDiffusionXLText2ImageAndEditPipeline


class SDXLAdapter(nn.Module):
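    """Conditions an SDXL UNet on image embeddings via a resampler.

    The resampler maps visual-encoder features into SDXL's prompt-embedding
    space, producing sequence embeddings plus a pooled embedding for the
    added-condition pathway.
    """
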
def __init__(self, unet, resampler, full_ft=False, vit_down=False) -> None:
super().__init__()
self.unet = unet
self.resampler = resampler
self.full_ft = full_ft
        self.vit_down = vit_down
        self.set_trainable_v2()

def set_trainable_v2(self):
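        """Unfreeze the resampler and collect UNet parameters for the optimizer.

        With full_ft the entire UNet is trained; otherwise only modules named
        `to_k`/`to_v` (the key/value projections of both self- and
        cross-attention blocks) are trained on top of the frozen UNet.
        """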
self.resampler.requires_grad_(True)
adapter_parameters = []
if self.full_ft:
self.unet.requires_grad_(True)
adapter_parameters.extend(self.unet.parameters())
        else:
            self.unet.requires_grad_(False)
            for name, module in self.unet.named_modules():
                if name.endswith('to_k') or name.endswith('to_v'):
                    if module is not None:
                        # Re-enable grads on the key/value projections so the
                        # optimizer can actually update them.
                        module.requires_grad_(True)
                        adapter_parameters.extend(module.parameters())
self.adapter_parameters = adapter_parameters

    def params_to_opt(self):
return itertools.chain(self.resampler.parameters(), self.adapter_parameters)

    def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids):
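        """Run one denoising training step and return the noise-prediction MSE loss.

        The resampled image embeddings replace SDXL's text prompt embeddings, and
        the pooled image embeddings fill the `text_embeds` added condition. The
        `text_embeds` argument is accepted for interface compatibility but unused.
        """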
image_embeds, pooled_image_embeds = self.resampler(image_embeds)
        unet_added_conditions = {'time_ids': time_ids, 'text_embeds': pooled_image_embeds}
noise_pred = self.unet(noisy_latents, timesteps, image_embeds, added_cond_kwargs=unet_added_conditions).sample
        if noise is not None:
            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
        else:
            # Allow inference-style calls that pass no target noise.
            loss = torch.tensor(0.0, device=noisy_latents.device)
        return {'total_loss': loss, 'noise_pred': noise_pred}

    def encode_image_embeds(self, image_embeds):
image_embeds, pooled_image_embeds = self.resampler(image_embeds)
return image_embeds, pooled_image_embeds

    @classmethod
def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs):
model = cls(unet=unet, resampler=resampler, **kwargs)
if pretrained_model_path is not None:
ckpt = torch.load(pretrained_model_path, map_location='cpu')
missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print(f'missing keys: {len(missing)}, unexpected keys: {len(unexpected)}')
return model

    def init_pipe(self,
vae,
scheduler,
visual_encoder,
image_transform,
discrete_model=None,
dtype=torch.float16,
device='cuda'):
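        """Assemble an SDXL pipeline for inference, reusing the adapter's UNet.

        Tokenizers and text encoders are passed as None because conditioning
        comes from image embeddings rather than text prompts.
        """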
self.device = device
self.dtype = dtype
sdxl_pipe = StableDiffusionXLPipeline(tokenizer=None,
tokenizer_2=None,
text_encoder=None,
text_encoder_2=None,
vae=vae,
unet=self.unet,
scheduler=scheduler)
        self.sdxl_pipe = sdxl_pipe
self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype)
if discrete_model is not None:
self.discrete_model = discrete_model.to(self.device, dtype=self.dtype)
else:
self.discrete_model = None
self.image_transform = image_transform

    @torch.inference_mode()
def get_image_embeds(self, image_pil=None, image_tensor=None, image_embeds=None, return_negative=True, image_size=448):
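        """Encode exactly one of image_pil / image_tensor / image_embeds.

        Returns (image_embeds, negative_embeds, pooled_embeds, negative_pooled_embeds).
        With return_negative=True a zero image is encoded as the classifier-free
        guidance negative; image_size is only used to shape that zero image when
        precomputed image_embeds are supplied.
        """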
        assert int(image_pil is not None) + int(image_tensor is not None) + int(image_embeds is not None) == 1, \
            'Provide exactly one of image_pil, image_tensor, or image_embeds'
if image_pil is not None:
image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype)
if image_tensor is not None:
if return_negative:
image_tensor_neg = torch.zeros_like(image_tensor)
image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0)
image_embeds = self.visual_encoder(image_tensor)
elif return_negative:
image_tensor_neg = torch.zeros(1, 3, image_size, image_size).to(image_embeds.device, dtype=image_embeds.dtype)
image_embeds_neg = self.visual_encoder(image_tensor_neg)
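            # Assumption: when vit_down is set, the caller-provided image_embeds
            # were already downsampled upstream, so only the freshly encoded
            # negative branch is average-pooled (4x along the sequence) to match.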
if self.vit_down:
image_embeds_neg = image_embeds_neg.permute(0, 2, 1) # NLD -> NDL
image_embeds_neg = F.avg_pool1d(image_embeds_neg, kernel_size=4, stride=4)
image_embeds_neg = image_embeds_neg.permute(0, 2, 1)
image_embeds = torch.cat([image_embeds, image_embeds_neg], dim=0)
if self.discrete_model is not None:
image_embeds = self.discrete_model.encode_image_embeds(image_embeds)
image_embeds, pooled_image_embeds = self.encode_image_embeds(image_embeds)
if return_negative:
image_embeds, image_embeds_neg = image_embeds.chunk(2)
pooled_image_embeds, pooled_image_embeds_neg = pooled_image_embeds.chunk(2)
else:
image_embeds_neg = None
pooled_image_embeds_neg = None
return image_embeds, image_embeds_neg, pooled_image_embeds, pooled_image_embeds_neg

    def generate(self,
image_pil=None,
image_tensor=None,
image_embeds=None,
seed=42,
height=1024,
width=1024,
guidance_scale=7.5,
num_inference_steps=30,
input_image_size=448,
**kwargs):
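        """Generate images from a single image source with classifier-free guidance.

        The positive/negative image embeddings from get_image_embeds stand in for
        SDXL's text prompt embeddings.
        """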
if image_pil is not None:
assert isinstance(image_pil, Image.Image)
image_prompt_embeds, uncond_image_prompt_embeds, pooled_image_prompt_embeds, pooled_uncond_image_prompt_embeds = self.get_image_embeds(
image_pil=image_pil,
image_tensor=image_tensor,
image_embeds=image_embeds,
return_negative=True,
image_size=input_image_size,
)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
images = self.sdxl_pipe(
prompt_embeds=image_prompt_embeds,
negative_prompt_embeds=uncond_image_prompt_embeds,
pooled_prompt_embeds=pooled_image_prompt_embeds,
negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
height=height,
width=width,
**kwargs,
).images
return images


class SDXLAdapterWithLatentImage(SDXLAdapter):
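    """SDXLAdapter variant whose UNet also receives a latent source image.

    conv_in is widened from 4 to 8 input channels so the source-image latent
    can be concatenated with the noisy latent along the channel dimension.
    """
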
def __init__(self, unet, resampler, full_ft=False, set_trainable_late=False, vit_down=False) -> None:
nn.Module.__init__(self)
self.unet = unet
self.resampler = resampler
self.full_ft = full_ft
        self.vit_down = vit_down
        if not set_trainable_late:
            self.set_trainable()

def set_trainable(self):
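        """Widen conv_in from 4 to 8 input channels and pick the trainable set.

        The new conv_in is zero-initialized and the pretrained weights are copied
        into its first 4 input channels, so the extra (edit) channels initially
        have no effect. Trains either the full UNet (full_ft) or conv_in plus
        the key/value projections.
        """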
self.resampler.requires_grad_(True)
adapter_parameters = []
in_channels = 8
out_channels = self.unet.conv_in.out_channels
self.unet.register_to_config(in_channels=in_channels)
self.unet.requires_grad_(False)
with torch.no_grad():
new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride,
self.unet.conv_in.padding)
new_conv_in.weight.zero_()
new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
self.unet.conv_in = new_conv_in
self.unet.conv_in.requires_grad_(True)
if self.full_ft:
self.unet.requires_grad_(True)
adapter_parameters.extend(self.unet.parameters())
        else:
            adapter_parameters.extend(self.unet.conv_in.parameters())
            for name, module in self.unet.named_modules():
                if name.endswith('to_k') or name.endswith('to_v'):
                    if module is not None:
                        # Re-enable grads on the key/value projections so the
                        # optimizer can actually update them.
                        module.requires_grad_(True)
                        adapter_parameters.extend(module.parameters())
self.adapter_parameters = adapter_parameters

    @classmethod
def from_pretrained(cls, unet, resampler, pretrained_model_path=None, set_trainable_late=False, **kwargs):
model = cls(unet=unet, resampler=resampler, set_trainable_late=set_trainable_late, **kwargs)
if pretrained_model_path is not None:
ckpt = torch.load(pretrained_model_path, map_location='cpu')
missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print(f'missing keys: {len(missing)}, unexpected keys: {len(unexpected)}')
if set_trainable_late:
model.set_trainable()
return model

    def init_pipe(self,
vae,
scheduler,
visual_encoder,
image_transform,
dtype=torch.float16,
device='cuda'):
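        """Assemble the text-to-image-and-edit SDXL pipeline around the widened UNet."""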
self.device = device
self.dtype = dtype
sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline(
tokenizer=None,
tokenizer_2=None,
text_encoder=None,
text_encoder_2=None,
vae=vae,
unet=self.unet,
scheduler=scheduler,
)
self.sdxl_pipe = sdxl_pipe
self.sdxl_pipe.to(device, dtype=dtype)
self.discrete_model = None
self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype)
self.image_transform = image_transform

    def generate(self,
image_pil=None,
image_tensor=None,
image_embeds=None,
latent_image=None,
seed=42,
height=1024,
width=1024,
guidance_scale=7.5,
num_inference_steps=30,
input_image_size=448,
**kwargs):
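        """Generate images from image embeddings, optionally editing a supplied latent image."""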
if image_pil is not None:
assert isinstance(image_pil, Image.Image)
image_prompt_embeds, uncond_image_prompt_embeds, pooled_image_prompt_embeds, pooled_uncond_image_prompt_embeds = self.get_image_embeds(
image_pil=image_pil,
image_tensor=image_tensor,
image_embeds=image_embeds,
return_negative=True,
image_size=input_image_size,
)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
images = self.sdxl_pipe(
image=latent_image,
prompt_embeds=image_prompt_embeds,
negative_prompt_embeds=uncond_image_prompt_embeds,
pooled_prompt_embeds=pooled_image_prompt_embeds,
negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
height=height,
width=width,
**kwargs,
).images
return images
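

# A minimal usage sketch (assumptions: the UNet/VAE/scheduler come from an SDXL
# checkpoint, and `resampler`, `visual_encoder`, and `image_transform` are the
# corresponding SEED-X components; names and paths below are illustrative only):
#
#   unet = UNet2DConditionModel.from_pretrained(sdxl_path, subfolder='unet')
#   adapter = SDXLAdapter.from_pretrained(unet=unet, resampler=resampler,
#                                         pretrained_model_path=adapter_ckpt)
#   adapter.init_pipe(vae=vae, scheduler=scheduler,
#                     visual_encoder=visual_encoder,
#                     image_transform=image_transform)
#   images = adapter.generate(image_pil=Image.open('input.jpg').convert('RGB'))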