from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig, GenerationConfig
from Perceptrix.streamer import TextStreamer
from utils import setup_device
import torch
import os
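
# Resolve the chat model path: use the CHAT_MODEL environment variable when it
# is set, otherwise fall back to the bundled CRYSTAL-chat checkpoint.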
model_name = os.environ.get("CHAT_MODEL")
model_path = "models/CRYSTAL-chat" if model_name is None else model_name
config = AutoConfig.from_pretrained(
    model_path, trust_remote_code=True)
device = setup_device()
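
# 4-bit NF4 quantization with double quantization; compute in float32 on CPU
# and bfloat16 otherwise.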
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
)
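
# Load the causal LM. Quantization is only applied on non-CPU devices, layers
# are placed automatically via device_map="auto", and anything that does not
# fit in memory is offloaded to the "offloads" folder.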
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
    config=config,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    offload_folder="offloads",
    quantization_config=bnb_config if str(device) != "cpu" else None,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
# Fall back to the EOS token for padding and pad on the left for generation.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
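
# Stream tokens to stdout as they are generated; the custom Perceptrix
# streamer also saves the reply to reply.txt (per its save_file argument).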
streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="reply.txt")
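
# Generator that tokenizes the prompt, runs generation with the given sampling
# parameters, and yields the text after the final "### Response:" marker.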
def evaluate(
    prompt='',
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
    **kwargs,
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    yield output.split("### Response:")[-1].strip()
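
# Thin wrapper around evaluate(): enables sampling, prints each yielded chunk,
# and returns the last output.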
def predict(
    inputs,
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
):
    now_prompt = inputs
    response = evaluate(
        now_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, do_sample=True
    )
    for i in response:
        print(i)
        response = i
    return response
instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user."
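
# Prepend the system instructions, generate a reply, and strip the echoed
# prompt (with the <|im_start|>/<|im_end|> markers removed) from the output.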
def perceptrix(prompt):
    prompt = instructions + "\n" + prompt
    response = predict(
        inputs=prompt, temperature=0.2, top_p=0.9, max_new_tokens=512
    )
    spl_tokens = ["<|im_start|>", "<|im_end|>"]
    clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
    return response[len(clean_prompt):]
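
# Simple REPL: wrap each user turn in ChatML-style markers and keep the whole
# conversation history as the growing prompt.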
if __name__ == "__main__":
    history = ""
    while True:
        user_input = input("User: ")
        user_input = "<|im_start|>User\n" + user_input + "<|im_end|>\n<|im_start|>CRYSTAL\n"
        result = perceptrix(history + user_input)
        history += user_input + result + "<|im_end|>\n"