from pathlib import Path
from typing import Tuple

import pyrallis
import torch
from accelerate import Accelerator
from torch import nn
from transformers import CLIPTokenizer

from models.neti_clip_text_encoder import NeTICLIPTextModel
from models.neti_mapper import NeTIMapper
from models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
from config import RunConfig


class CheckpointHandler:

    def __init__(self, cfg: RunConfig, placeholder_token_string: str, placeholder_token_id: int, save_root: Path):
        self.cfg = cfg
        self.placeholder_token_string = placeholder_token_string
        self.placeholder_token_id = placeholder_token_id
        self.save_root = save_root

    def save_model(self, text_encoder: NeTICLIPTextModel,
                   accelerator: Accelerator,
                   embeds_save_name: str,
                   mapper_save_name: str):
        self.save_learned_embeds(text_encoder, accelerator, embeds_save_name)
        self.save_mapper(text_encoder, mapper_save_name)

    def save_learned_embeds(self, text_encoder: NeTICLIPTextModel, accelerator: Accelerator, save_name: str):
        """
        Save learned embeddings. This embedding isn't really learned, but we'll add it to the tokenizer at inference
        to take the place of our placeholder token.
        """
        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[self.placeholder_token_id]
        learned_embeds = learned_embeds.detach().cpu()
        learned_embeds_dict = {self.placeholder_token_string: learned_embeds}
        torch.save(learned_embeds_dict, self.save_root / save_name)

    def save_mapper(self, text_encoder: NeTICLIPTextModel, save_name: str):
        """ Save the mapper and config to be used at inference. """
        cfg_ = RunConfig(**self.cfg.__dict__.copy())
        state_dict = {
            "state_dict": text_encoder.text_model.embeddings.mapper.state_dict(),
            "cfg": pyrallis.encode(cfg_),
            "encoder": text_encoder.text_model.embeddings.mapper.encoder
        }
        torch.save(state_dict, self.save_root / save_name)
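
    # Note (illustrative; the file names are assumptions, not enforced by this class): a training
    # run typically leaves two sibling files under `save_root`, e.g.
    #   mapper-steps-1000.pt           -> mapper weights, encoded RunConfig, and encoder module
    #   learned_embeds-steps-1000.bin  -> {placeholder_token_string: embedding tensor}
    # `load_mapper` and `load_learned_embed_in_clip` below consume these two files at inference.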

    @staticmethod
    def load_mapper(mapper_path: Path) -> Tuple[RunConfig, NeTIMapper]:
        """ Load the NeTI mapper and its run config from a checkpoint written by `save_mapper`. """
        mapper_ckpt = torch.load(mapper_path, map_location="cpu")
        cfg = pyrallis.decode(RunConfig, mapper_ckpt['cfg'])
        neti_mapper = NeTIMapper(output_dim=768,
                                 use_nested_dropout=cfg.model.use_nested_dropout,
                                 nested_dropout_prob=cfg.model.nested_dropout_prob,
                                 norm_scale=cfg.model.target_norm,
                                 use_positional_encoding=cfg.model.use_positional_encoding,
                                 num_pe_time_anchors=cfg.model.num_pe_time_anchors,
                                 pe_sigmas=cfg.model.pe_sigmas,
                                 output_bypass=cfg.model.output_bypass)
        neti_mapper.load_state_dict(mapper_ckpt['state_dict'], strict=True)
        # Restore the saved encoder and move its tensors to the GPU
        encoder = mapper_ckpt['encoder']
        if isinstance(encoder, NeTIPositionalEncoding):
            encoder.w = nn.Parameter(mapper_ckpt['encoder'].w.cuda())
        elif isinstance(encoder, BasicEncoder):
            encoder.normalized_timesteps = mapper_ckpt['encoder'].normalized_timesteps.cuda()
            encoder.normalized_unet_layers = mapper_ckpt['encoder'].normalized_unet_layers.cuda()
        neti_mapper.encoder = encoder.cuda()
        neti_mapper.cuda()
        neti_mapper.eval()
        return cfg, neti_mapper

    @staticmethod
    def load_learned_embed_in_clip(learned_embeds_path: Path,
                                   text_encoder: NeTICLIPTextModel,
                                   tokenizer: CLIPTokenizer) -> Tuple[str, int]:
        """ Add the saved placeholder token and its embedding to the tokenizer and text encoder. """
        loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

        # Separate the trained tokens from their embeddings
        trained_tokens = list(loaded_learned_embeds.keys())
        embeds = list(loaded_learned_embeds.values())

        # Cast the embeddings to the dtype of the text encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        embeds = [e.to(dtype) for e in embeds]

        # Add the tokens to the tokenizer
        num_added_tokens = tokenizer.add_tokens(trained_tokens)
        if num_added_tokens == 0:
            raise ValueError(f"The tokenizer already contains the token {trained_tokens[0]}. "
                             f"Please pass a different `token` that is not already in the tokenizer.")

        # Resize the token embeddings to account for the newly added tokens
        text_encoder.resize_token_embeddings(len(tokenizer))

        # Get the ids of the new tokens and assign the saved embeddings to them
        placeholder_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in trained_tokens]
        for token_id, embed in zip(placeholder_token_ids, embeds):
            text_encoder.get_input_embeddings().weight.data[token_id] = embed

        assert len(trained_tokens) == 1, "Only one placeholder token is supported"
        placeholder_token = trained_tokens[0]
        placeholder_token_id = placeholder_token_ids[0]
        return placeholder_token, placeholder_token_id
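

# --- Usage sketch (illustrative; `cfg`, `text_encoder`, `accelerator`, `tokenizer`, the
# placeholder token, and the file names below are assumptions, not part of this module) ---
#
#   handler = CheckpointHandler(cfg=cfg,
#                               placeholder_token_string="<concept>",
#                               placeholder_token_id=placeholder_token_id,
#                               save_root=Path("./outputs"))
#   handler.save_model(text_encoder, accelerator,
#                      embeds_save_name="learned_embeds-steps-1000.bin",
#                      mapper_save_name="mapper-steps-1000.pt")
#
#   # At inference, restore the two artifacts separately:
#   cfg, neti_mapper = CheckpointHandler.load_mapper(Path("./outputs/mapper-steps-1000.pt"))
#   token, token_id = CheckpointHandler.load_learned_embed_in_clip(
#       Path("./outputs/learned_embeds-steps-1000.bin"), text_encoder, tokenizer)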