"""Load a GPTQ-quantized LLaMA checkpoint and sample a completion from it.

Running this file builds the model via ``load_quant()``, moves it to ``DEV``,
tokenizes a fixed prompt, samples a continuation and prints the decoded text.
"""
import argparse
import torch
import torch.nn as nn
import quant
from gptq import GPTQ
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
import transformers
from transformers import AutoTokenizer

# Directory holding the model config / tokenizer files, and the quantized
# weights file inside it. Both are the defaults of ``load_quant``.
MODEL_DIR = "Wizard-Vicuna-13B-Uncensored-GPTQ"
CHECKPOINT_PATH = (
    "Wizard-Vicuna-13B-Uncensored-GPTQ/"
    "Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors"
)


def load_quant(model=MODEL_DIR, checkpoint=CHECKPOINT_PATH, wbits=4, groupsize=128, fused_mlp=True, eval=True, warmup_autotune=True):
    """Build a LLaMA model with quantized linear layers and load its weights.

    Args:
        model: Name or path handed to ``LlamaConfig.from_pretrained``.
        checkpoint: Path of the weights file. A name ending in
            ``.safetensors`` is read with ``safetensors``; anything else is
            read with ``torch.load``.
        wbits: Forwarded to ``quant.make_quant_linear``.
        groupsize: Forwarded to ``quant.make_quant_linear``.
        fused_mlp: When true (and ``eval`` is true), ``quant.make_fused_mlp``
            is applied and, if warming up, ``quant.autotune_warmup_fused``
            is run.
        eval: When true, the model is switched to eval mode and
            ``quant.make_quant_attn`` / ``quant.make_quant_norm`` are applied.
        warmup_autotune: When true, ``quant.autotune_warmup_linear`` is run
            after loading.

    Returns:
        The constructed model, with ``seqlen`` set to 2048.

    Note:
        This function replaces ``torch.nn.init.kaiming_uniform_``,
        ``uniform_`` and ``normal_`` with no-ops and sets
        ``transformers.modeling_utils._init_weights = False``. These
        process-wide changes are not undone.
    """
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Skip random weight initialisation; the weights are overwritten by the
    # checkpoint loaded below.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
    transformers.modeling_utils._init_weights = False

    # Construct the model with half precision as the default dtype, and always
    # put the default back to float, even if construction raises.
    torch.set_default_dtype(torch.half)
    try:
        llama = LlamaForCausalLM(config)
    finally:
        torch.set_default_dtype(torch.float)

    if eval:
        llama = llama.eval()

    # Every discovered layer except the output head is swapped for a
    # quantized linear layer.
    layers = find_layers(llama)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    quant.make_quant_linear(llama, layers, wbits, groupsize)
    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load

        state_dict = safe_load(checkpoint)
    else:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(checkpoint)
    # strict=False: missing or unexpected keys do not raise.
    llama.load_state_dict(state_dict, strict=False)

    if eval:
        quant.make_quant_attn(llama)
        quant.make_quant_norm(llama)
        if fused_mlp:
            quant.make_fused_mlp(llama)

    if warmup_autotune:
        quant.autotune_warmup_linear(llama, transpose=not (eval))
        if eval and fused_mlp:
            quant.autotune_warmup_fused(llama)

    llama.seqlen = 2048
    print('Done.')
    return llama


# Script body: kept at module top level so the names it defines (model,
# tokenizer, input_ids, generated_ids) remain module globals as before.
model = load_quant()
model.to(DEV)

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=False)
input_ids = tokenizer.encode("TEXT PROMPT GOES HERE", return_tensors="pt").to(DEV)

with torch.no_grad():
    generated_ids = model.generate(
        input_ids,
        do_sample=True,
        min_length=50,
        max_length=200,
        top_p=0.99,
        temperature=0.8,
    )

# Convert the first generated sequence to plain Python ints in one call.
print(tokenizer.decode(generated_ids[0].tolist()))