from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    BitsAndBytesConfig,
    GenerationConfig,
)
from Perceptrix.streamer import TextStreamer
from utils import setup_device
import torch
import os

# Allow the checkpoint to be overridden via the CHAT_MODEL environment variable.
model_name = os.environ.get('CHAT_MODEL')

model_path = "models/CRYSTAL-chat" if model_name is None else model_name
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

device = setup_device()

# 4-bit NF4 quantization with nested ("double") quantization keeps memory low;
# compute falls back to float32 on CPU, where bfloat16 kernels are generally
# unavailable.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32 if str(device) == "cpu" else torch.bfloat16,
)

# Load the model once at import time: quantized on GPU, plain float32 on CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float32 if str(device) == "cpu" else torch.bfloat16,
    config=config,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    offload_folder="offloads",
    quantization_config=bnb_config if str(device) != "cpu" else None,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

# Some checkpoints ship without a pad token; reuse EOS so generation can pad.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

# Left-padding keeps the prompt flush against the generated tokens.
tokenizer.padding_side = "left"
model.eval()

streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="reply.txt")
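# Note: this TextStreamer is the project's customized streamer; the save_file
# argument presumably mirrors the streamed reply to reply.txt (see
# Perceptrix/streamer.py for the actual behavior).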

def evaluate(
    prompt='',
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
    **kwargs,
):
    """Generate a reply for `prompt`, yielding the decoded text once finished."""
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )

    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
    sequence = generation_output.sequences[0]
    output = tokenizer.decode(sequence, skip_special_tokens=True)
    # Keep only the text after the final response marker, if one is present.
    yield output.split("### Response:")[-1].strip()


def predict(
    inputs,
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
):
    # evaluate() is a generator; drain it and keep the last (complete) reply.
    response = ""
    for chunk in evaluate(
        inputs, temperature, top_p, top_k, repetition_penalty,
        max_new_tokens, do_sample=True,
    ):
        print(chunk)
        response = chunk

    return response
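
# Minimal usage sketch (hypothetical prompt; assumes the checkpoint follows the
# "### Response:" convention that evaluate() splits on):
#
#   reply = predict("### Instruction:\nIntroduce yourself.\n\n### Response:\n")
#   print(reply)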


instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user."

def perceptrix(prompt):
    prompt = instructions + "\n" + prompt
    response = predict(
        inputs=prompt, temperature=0.2, top_p=0.9, max_new_tokens=512
    )
    # The decoded output echoes the prompt (minus special tokens, which the
    # tokenizer strips); slice it off so only the new reply is returned.
    spl_tokens = ["<|im_start|>", "<|im_end|>"]
    clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
    return response[len(clean_prompt):]


if __name__ == "__main__":
    # Simple REPL: wrap each turn in ChatML-style tags and grow the history.
    history = ""
    while True:
        user_input = input("User: ")
        user_input = "<|im_start|>User\n" + user_input + "<|im_end|>\n<|im_start|>CRYSTAL\n"
        result = perceptrix(history + user_input)
        history += user_input + result + "<|im_end|>\n"
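
# To run the chat loop (assumes a checkpoint at models/CRYSTAL-chat, or set
# CHAT_MODEL to a Hugging Face model id), e.g.:
#
#   CHAT_MODEL=your-org/your-model python perceptrix_chat.py  # filename hypothetical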