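"""SEED visual tokenizer.

Wraps a BLIP-2 Q-Former quantizer for image-to-code encoding and a Stable
unCLIP img2img pipeline for code-to-image decoding, plus a LlamaTokenizer
subclass that exposes both alongside the text vocabulary.
"""
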
import os
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import LlamaTokenizer

from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline

WEIGHTS_NAME = 'seed_quantizer.pt'
DIFFUSION_NAME = 'diffusion_model'


class ImageTokenizer(nn.Module):

    def __init__(self,
                 model_path,
                 diffusion_model_path=None,
                 load_diffusion=False,
                 image_size=224,
                 device='cuda',
                 fp16=True,
                 **kwargs):
        super().__init__()
        # Imported lazily so the module can be used without the quantizer deps.
        from .seed_qformer.qformer_quantizer import Blip2QformerQuantizer

        model = Blip2QformerQuantizer.from_pretrained(pretrained_model_path=model_path,
                                                      vit_precision='fp16' if fp16 else 'fp32',
                                                      **kwargs).eval()
        if diffusion_model_path is not None and load_diffusion:
            diffusion_model = StableUnCLIPImg2ImgPipeline.from_pretrained(
                diffusion_model_path, torch_dtype=torch.float16 if fp16 else torch.float32)
            self.diffusion_model = diffusion_model.to(device)
        else:
            self.diffusion_model = None

        model = model.to(device)

        # CLIP-style preprocessing: bicubic resize to a square crop, then
        # normalize with the CLIP mean/std.
        processor = transforms.Compose([
            transforms.Resize((image_size, image_size),
                              interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711)),
        ])

        if fp16:
            model = model.half()

        # Fixed latents/noise (matching the model precision) make repeated
        # decodes of the same indices deterministic. They are created in both
        # precisions so decode() also works when fp16=False.
        dtype = torch.float16 if fp16 else torch.float32
        self.latents = torch.randn([1, 4, 96, 96], device=device, dtype=dtype)
        self.noise = torch.randn([1, 1024], device=device, dtype=dtype)

        self.model = model
        self.processor = processor
        self.device = device
        self.fp16 = fp16

    def __len__(self):
        # Size of the visual codebook.
        return self.model.n_embed

    def encode(self, image_torch):
        '''Convert a batch of preprocessed images to codebook indices.

        Args:
            image_torch: Tensor of shape [b, c, h, w] (or [c, h, w] for a
                single image), already normalized by `self.processor`.
        '''
        if len(image_torch.shape) == 3:
            image_torch = image_torch.unsqueeze(0)

        img = image_torch
        if self.fp16:
            img = img.half()
        with torch.no_grad():
            indices, _ = self.model.get_codebook_indices(img)
        return indices.view(img.shape[0], -1)

    def decode(self, indices, negative_indices=None, guidance_scale=10, num_inference_steps=20):
        '''Decode codebook indices back to a list of PIL images.'''
        assert self.diffusion_model is not None, \
            'Decoding requires the diffusion model; construct with load_diffusion=True.'
        image_embeds = self.model.get_codebook_entry(indices)
        if negative_indices is not None:
            assert indices.shape == negative_indices.shape, \
                'Negative indices must have the same shape as indices'
            negative_image_embeds = self.model.get_codebook_entry(negative_indices)
        else:
            negative_image_embeds = None

        image = self.diffusion_model(
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
            guidance_scale=guidance_scale,
            noise_level=0,
            num_inference_steps=num_inference_steps,
            latents=self.latents,
        ).images
        return image
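
# Usage sketch (paths are placeholders; a local SEED quantizer checkpoint and a
# Stable unCLIP checkpoint are assumed):
#
#   tokenizer = ImageTokenizer(model_path='path/to/seed_quantizer.pt',
#                              diffusion_model_path='path/to/stable-unclip',
#                              load_diffusion=True)
#   img = tokenizer.processor(Image.open('cat.jpg').convert('RGB')).to(tokenizer.device)
#   ids = tokenizer.encode(img)      # [1, n] codebook indices
#   images = tokenizer.decode(ids)   # list of PIL images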


class SeedLlamaTokenizer(LlamaTokenizer):

    def __init__(self,
                 vocab_file,
                 unk_token="<unk>",
                 bos_token="<s>",
                 eos_token="</s>",
                 pad_token=None,
                 sp_model_kwargs: Optional[Dict[str, Any]] = None,
                 add_bos_token=True,
                 add_eos_token=False,
                 clean_up_tokenization_spaces=False,
                 device='cuda',
                 fp16=True,
                 load_diffusion=False,
                 encoder_url=None,
                 diffusion_path=None,
                 **kwargs):
        super().__init__(vocab_file,
                         unk_token=unk_token,
                         bos_token=bos_token,
                         eos_token=eos_token,
                         pad_token=pad_token,
                         sp_model_kwargs=sp_model_kwargs,
                         add_bos_token=add_bos_token,
                         add_eos_token=add_eos_token,
                         clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                         **kwargs)
        self.device = device
        self.fp16 = fp16
        self.pad_token = self.unk_token
        self.load_diffusion = load_diffusion
        self.encoder_url = encoder_url
        self.diffusion_path = diffusion_path

        self.load_image_tokenizer()

    def load_image_tokenizer(self):
        # Build the image tokenizer once and cache it. Weights come from an
        # explicit encoder_url if given, otherwise from WEIGHTS_NAME inside the
        # pretrained directory.
        if not hasattr(self, '_image_tokenizer'):
            if self.encoder_url is not None:
                model_path = self.encoder_url
            else:
                assert hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path)
                model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
            self._image_tokenizer = ImageTokenizer(model_path=model_path,
                                                   diffusion_model_path=self.diffusion_path,
                                                   load_diffusion=self.load_diffusion,
                                                   device=self.device,
                                                   fp16=self.fp16)

    @property
    def image_tokenizer(self):
        # Accessed as an attribute by encode_image/decode_image, so this must
        # be a property.
        self.load_image_tokenizer()
        return self._image_tokenizer

    @property
    def num_image_tokens(self):
        # Hardcoded codebook size so callers can query it without loading the
        # quantizer weights.
        return 8192

    def to(self, device):
        self.device = device
        if hasattr(self, '_image_tokenizer'):
            self._image_tokenizer.to(device=device)

    def encode_image(
        self,
        image_path=None,
        image_pil=None,
        image_torch=None,
        image_size: int = 224,  # kept for API compatibility; the actual size is fixed at construction
    ):
        # Exactly one of image_path / image_pil / image_torch may be given.
        assert (image_path is None) + (image_pil is None) + (image_torch is None) == 2

        if image_path is not None:
            image_pil = Image.open(image_path).convert('RGB')

        if image_pil is not None:
            image_torch = self.image_tokenizer.processor(image_pil)

        image_torch = image_torch.to(self.device)
        return self.image_tokenizer.encode(image_torch)

    def decode_image(self, indices, negative_indices=None, guidance_scale=10):
        indices = indices.to(self.device)
        if negative_indices is not None:
            negative_indices = negative_indices.to(self.device)
        image = self.image_tokenizer.decode(
            indices,
            negative_indices=negative_indices,
            guidance_scale=guidance_scale,
        )
        return image
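

if __name__ == '__main__':
    # Minimal smoke test; a sketch only. The checkpoint path is a placeholder
    # and the repo layout (tokenizer files plus seed_quantizer.pt) is an
    # assumption; adjust to your local setup. Extra kwargs are forwarded by
    # from_pretrained to __init__ above.
    tokenizer = SeedLlamaTokenizer.from_pretrained('path/to/seed-llama-tokenizer',
                                                   fp16=True,
                                                   load_diffusion=False)
    indices = tokenizer.encode_image(image_path='path/to/image.jpg')
    print('image token ids:', indices.shape)  # [1, n] codebook indices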