from pathlib import Path
from typing import Tuple

import pyrallis
import torch
from accelerate import Accelerator
from torch import nn
from transformers import CLIPTokenizer

from models.neti_clip_text_encoder import NeTICLIPTextModel
from models.neti_mapper import NeTIMapper
from models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
from config import RunConfig


class CheckpointHandler:
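    """Saves the trained NeTI mapper and placeholder-token embedding, and restores them for inference."""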

    def __init__(self, cfg: RunConfig, placeholder_token_string: str, placeholder_token_id: int, save_root: Path):
        self.cfg = cfg
        self.placeholder_token_string = placeholder_token_string
        self.placeholder_token_id = placeholder_token_id
        self.save_root = save_root

    def save_model(self, text_encoder: NeTICLIPTextModel,
                   accelerator: Accelerator,
                   embeds_save_name: str,
                   mapper_save_name: str):
        self.save_learned_embeds(text_encoder, accelerator, embeds_save_name)
        self.save_mapper(text_encoder, mapper_save_name)

    def save_learned_embeds(self, text_encoder: NeTICLIPTextModel, accelerator: Accelerator, save_name: str):
        """
        Save learned embeddings. This embedding isn't really learned, but we'll add it to the tokenizer at inference
        to take the place of our placeholder token.
        """
        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[self.placeholder_token_id]
        learned_embeds = learned_embeds.detach().cpu()
        learned_embeds_dict = {self.placeholder_token_string: learned_embeds}
        torch.save(learned_embeds_dict, self.save_root / save_name)

    def save_mapper(self, text_encoder: NeTICLIPTextModel, save_name: str):
        """ Save the mapper and config to be used at inference. """
        cfg_ = RunConfig(**self.cfg.__dict__.copy())
        state_dict = {
            "state_dict": text_encoder.text_model.embeddings.mapper.state_dict(),
            "cfg": pyrallis.encode(cfg_),
            "encoder": text_encoder.text_model.embeddings.mapper.encoder
        }
        torch.save(state_dict, self.save_root / save_name)

    @staticmethod
    def load_mapper(mapper_path: Path) -> Tuple[RunConfig, NeTIMapper]:
        mapper_ckpt = torch.load(mapper_path, map_location="cpu")
        cfg = pyrallis.decode(RunConfig, mapper_ckpt['cfg'])
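        # Rebuild the mapper with the hyperparameters used during training; 768 is the text embedding
        # dimension of the CLIP text encoder used by Stable Diffusion 1.x.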
        neti_mapper = NeTIMapper(output_dim=768,
                                 use_nested_dropout=cfg.model.use_nested_dropout,
                                 nested_dropout_prob=cfg.model.nested_dropout_prob,
                                 norm_scale=cfg.model.target_norm,
                                 use_positional_encoding=cfg.model.use_positional_encoding,
                                 num_pe_time_anchors=cfg.model.num_pe_time_anchors,
                                 pe_sigmas=cfg.model.pe_sigmas,
                                 output_bypass=cfg.model.output_bypass)
        neti_mapper.load_state_dict(mapper_ckpt['state_dict'], strict=True)
        # Restore the (timestep, unet-layer) encoder that was pickled alongside the state dict and move its
        # parameters/buffers to the GPU.
        encoder = mapper_ckpt['encoder']
        if isinstance(encoder, NeTIPositionalEncoding):
            encoder.w = nn.Parameter(encoder.w.cuda())
        elif isinstance(encoder, BasicEncoder):
            encoder.normalized_timesteps = encoder.normalized_timesteps.cuda()
            encoder.normalized_unet_layers = encoder.normalized_unet_layers.cuda()
        neti_mapper.encoder = encoder.cuda()
        neti_mapper.cuda()
        neti_mapper.eval()
        return cfg, neti_mapper

    @staticmethod
    def load_learned_embed_in_clip(learned_embeds_path: Path,
                                   text_encoder: NeTICLIPTextModel,
                                   tokenizer: CLIPTokenizer) -> Tuple[str, int]:
        loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

        # separate the tokens and their embeddings
        trained_tokens = list(loaded_learned_embeds.keys())
        embeds = list(loaded_learned_embeds.values())

        # cast to dtype of text_encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        embeds = [e.to(dtype) for e in embeds]

        # add the tokens to the tokenizer
        num_added_tokens = tokenizer.add_tokens(trained_tokens)
        if num_added_tokens == 0:
            raise ValueError(f"The tokenizer already contains the token {trained_tokens[0]}. "
                             f"Please pass a different `token` that is not already in the tokenizer.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        placeholder_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in trained_tokens]

        for token_id, embed in zip(placeholder_token_ids, embeds):
            text_encoder.get_input_embeddings().weight.data[token_id] = embed

        assert len(trained_tokens) == 1, "Only one placeholder token is supported"
        placeholder_token = trained_tokens[0]
        placeholder_token_id = placeholder_token_ids[0]
        return placeholder_token, placeholder_token_id
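

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the checkpoint paths and the base-model
    # identifier below are hypothetical placeholders, and a CUDA device is assumed because `load_mapper`
    # moves the mapper to the GPU. Adjust them to your own saved files and pretrained weights.
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

    # Register the saved placeholder token and its embedding with the tokenizer and text encoder.
    token, token_id = CheckpointHandler.load_learned_embed_in_clip(
        learned_embeds_path=Path("outputs/learned_embeds-steps-1000.bin"),
        text_encoder=text_encoder,
        tokenizer=tokenizer)

    # Attach the trained NeTI mapper to the text encoder's embedding module (the same attribute that
    # `save_mapper` reads from).
    cfg, neti_mapper = CheckpointHandler.load_mapper(Path("outputs/mapper-steps-1000.pt"))
    text_encoder.text_model.embeddings.mapper = neti_mapper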