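"""Model loading utilities: lazily load the LLaMA base model and tokenizer,
apply and cache LoRA adapters, and unload everything again to free memory."""
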
import os
import sys
import gc

import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer

from .globals import Global


def get_device():
    """Pick the first available device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"

    try:
        if torch.backends.mps.is_available():
            return "mps"
    except:  # noqa: E722
        pass

    return "cpu"


device = get_device()


def get_base_model():
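    """Return the loaded base model, loading it first if necessary."""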
    load_base_model()
    return Global.loaded_base_model


def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
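    """Return the base model with the given LoRA weights applied, reusing a
    cached copy when one is available."""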
    Global.model_has_been_used = True

    if Global.cached_lora_models:
        model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
        if model_from_cache:
            return model_from_cache

    if device == "cuda":
        model = PeftModel.from_pretrained(
            get_base_model(),
            lora_weights_name_or_path,
            torch_dtype=torch.float16,
            device_map={'': 0},  # ? https://github.com/tloen/alpaca-lora/issues/21
        )
    elif device == "mps":
        model = PeftModel.from_pretrained(
            get_base_model(),
            lora_weights_name_or_path,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = PeftModel.from_pretrained(
            get_base_model(),
            lora_weights_name_or_path,
            device_map={"": device},
        )

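    # Pin the special token ids to LLaMA's layout: 0 (unk) doubles as pad so it
    # stays distinct from eos, 1 is bos and 2 is eos.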
    model.config.pad_token_id = get_tokenizer().pad_token_id = 0
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not Global.load_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()
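    # torch.compile requires PyTorch 2.x and is not supported on Windows.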
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    if Global.cached_lora_models:
        Global.cached_lora_models.set(lora_weights_name_or_path, model)

    return model


def get_tokenizer():
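    """Return the loaded tokenizer, loading it first if necessary."""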
    load_base_model()
    return Global.loaded_tokenizer


def load_base_model():
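    """Lazily load the tokenizer and base model into Global.

    Does nothing in UI dev mode, and skips anything that is already loaded.
    """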
    if Global.ui_dev_mode:
        return

    if Global.loaded_tokenizer is None:
        Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
            Global.base_model
        )
    if Global.loaded_base_model is None:
        if device == "cuda":
            Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
                Global.base_model,
                load_in_8bit=Global.load_8bit,
                torch_dtype=torch.float16,
                # device_map="auto",
                device_map={'': 0},  # ? https://github.com/tloen/alpaca-lora/issues/21
            )
        elif device == "mps":
            Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
                Global.base_model,
                device_map={"": device},
                torch_dtype=torch.float16,
            )
        else:
            Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
                Global.base_model, device_map={"": device}, low_cpu_mem_usage=True
            )

        Global.loaded_base_model.config.pad_token_id = get_tokenizer().pad_token_id = 0
        Global.loaded_base_model.config.bos_token_id = 1
        Global.loaded_base_model.config.eos_token_id = 2


def clear_cache():
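    """Run the garbage collector and release cached CUDA memory."""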
    gc.collect()

    # Safe to call even when not running on CUDA: empty_cache() is a no-op
    # unless the CUDA caching allocator has been initialized.
    with torch.no_grad():
        torch.cuda.empty_cache()


def unload_models():
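    """Release the base model, the tokenizer and all cached LoRA models."""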
    del Global.loaded_base_model
    Global.loaded_base_model = None

    del Global.loaded_tokenizer
    Global.loaded_tokenizer = None

    if Global.cached_lora_models:
        Global.cached_lora_models.clear()

    clear_cache()

    Global.model_has_been_used = False


def unload_models_if_already_used():
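    """Unload models only if they have been used since they were loaded."""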
    if Global.model_has_been_used:
        unload_models()
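

# Rough usage sketch (illustration only; the prompt text and generation
# arguments below are made up, not taken from this project):
#
#     model = get_model_with_lora("tloen/alpaca-lora-7b")
#     tokenizer = get_tokenizer()
#     inputs = tokenizer("Hello", return_tensors="pt").to(device)
#     output = model.generate(**inputs, max_new_tokens=16)
#     print(tokenizer.decode(output[0], skip_special_tokens=True))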