crystal-technologies's picture
Upload 1653 files
714d948
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig, GenerationConfig
from Perceptrix.streamer import TextStreamer
from utils import setup_device
import torch
import os
# Model location: the CHAT_MODEL environment variable, when set, overrides the
# bundled default checkpoint directory.
model_name = os.environ.get('CHAT_MODEL')
model_path = "models/CRYSTAL-chat" if model_name is None else model_name
config = AutoConfig.from_pretrained(
    model_path, trust_remote_code=True)
device = setup_device()
# NOTE(review): setup_device() is defined elsewhere and may return either a
# plain string or a torch.device. Comparing through str() gives the same answer
# for both, so every CPU check below agrees (previously some compared the raw
# value and one compared str(device)).
_on_cpu = str(device) == "cpu"
# 4-bit NF4 quantization with double quantization; only applied off-CPU
# (see quantization_config below).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float32 if _on_cpu else torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float32 if _on_cpu else torch.bfloat16,
    config=config,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    offload_folder="offloads",
    quantization_config=None if _on_cpu else bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)
# Fall back to EOS as the pad token when the checkpoint defines none, so that
# generate() always receives a usable pad_token_id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# NOTE(review): this replaces the value returned by setup_device() above, and
# evaluate() moves its inputs to this device. Kept unchanged to preserve
# behavior — confirm it matches where device_map="auto" placed the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# Project streamer (Perceptrix.streamer) handed to generate(); skip_prompt,
# skip_special_tokens and save_file are passed through as configured here.
streamer = TextStreamer(tokenizer, skip_prompt=True,
                        skip_special_tokens=True, save_file="reply.txt")
def evaluate(
    prompt='',
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
    **kwargs,
):
    """Generate a completion for ``prompt`` and yield the decoded reply.

    This is a generator that yields exactly one string, so callers can iterate
    over it. Tokens are also pushed to the module-level ``streamer`` while
    they are generated.

    Args:
        prompt: Text fed to the model.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            placed in the ``GenerationConfig``.
        max_new_tokens: Upper bound on the number of generated tokens.
        **kwargs: Extra fields forwarded to ``GenerationConfig``
            (e.g. ``do_sample=True``).

    Yields:
        The first returned sequence decoded with special tokens removed. For
        a decoder-only model this includes the prompt text. The text is split
        on ``"### Response:"``, the part after the last marker is kept (the
        whole text if there is no marker), and it is stripped of surrounding
        whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    # Pass the mask explicitly. pad_token may have been set to eos_token at
    # load time, and then generate() cannot infer the mask from input_ids.
    # Some custom tokenizers may not return a mask, so it is optional.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            # Needed for ``.sequences`` below. Per-step scores were never read,
            # so they are no longer requested (they only cost memory).
            return_dict_in_generate=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
    sequence = generation_output.sequences[0]
    output = tokenizer.decode(sequence, skip_special_tokens=True)
    yield output.split("### Response:")[-1].strip()
def predict(
    inputs,
    temperature=0.4,
    top_p=0.65,
    top_k=35,
    repetition_penalty=1.1,
    max_new_tokens=512,
):
    """Run sampled generation on ``inputs``, printing and returning the reply.

    Sampling is always enabled (``do_sample=True``). Every value yielded by
    ``evaluate`` is printed, and the last one is returned. Note: if
    ``evaluate`` yielded nothing, the exhausted generator object itself would
    be returned.
    """
    replies = evaluate(
        inputs,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
        do_sample=True,
    )
    last = replies
    for reply in replies:
        print(reply)
        last = reply
    return last
# System preamble prepended to every prompt sent through perceptrix().
instructions = "You are Comprehensive Robotics Yielding Sophisticated Technology And Logistics (CRYSTAL), an AI robot developed by Vatsal Dutt to be the most advanced robot in the world. You will be provided with prompts and other information to help the user."


def perceptrix(prompt):
    """Answer ``prompt`` as CRYSTAL and return the text after the prompt.

    ``instructions`` and a newline are prepended to ``prompt``. The result
    goes to ``predict`` with temperature 0.2, top_p 0.9 and 512 new tokens;
    the other sampling settings keep predict's defaults.

    The reply is returned with its first ``len(visible_prompt)`` characters
    removed. ``visible_prompt`` is the full prompt with the ChatML markers
    ``<|im_start|>`` and ``<|im_end|>`` deleted — presumably because decoding
    drops them as special tokens (NOTE(review): confirm for the tokenizer in
    use).
    """
    full_prompt = "\n".join((instructions, prompt))
    reply = predict(
        inputs=full_prompt, temperature=0.2, top_p=0.9, max_new_tokens=512
    )
    visible_prompt = full_prompt
    for marker in ("<|im_start|>", "<|im_end|>"):
        visible_prompt = visible_prompt.replace(marker, "")
    return reply[len(visible_prompt):]
if __name__ == "__main__":
    # Interactive chat loop. The accumulated ChatML transcript is sent with
    # each new turn so the model sees earlier exchanges.
    history = ""
    while True:
        try:
            user_input = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly instead
            # of crashing with a traceback.
            print()
            break
        user_input = "<|im_start|>User\n"+user_input+"<|im_end|>\n<|im_start|>CRYSTAL\n"
        result = perceptrix(history + user_input)
        history += user_input + result + "<|im_end|>\n"