import torch
import torch.nn as nn
import quant
from gptq import GPTQ
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
import transformers
from transformers import AutoTokenizer


class ModelInference:

    def __init__(self, model_name, load=None, wbits=16, groupsize=-1):
        self.model_name = model_name
        self.load = load
        self.wbits = wbits
        self.groupsize = groupsize
        # Load a quantized checkpoint if one is given, otherwise the full-precision HF weights.
        if self.load:
            self.model = self.load_quant(self.model_name, self.load, self.wbits, self.groupsize)
        else:
            self.model = self.get_llama(self.model_name)
        self.model.eval()
        self.model.to(DEV)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)

    @staticmethod
    def get_llama(model):
        # Skip weight initialization: from_pretrained overwrites the weights anyway.
        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip
        from transformers import LlamaForCausalLM
        model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto')
        model.seqlen = 2048
        return model

    @staticmethod
    def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
        from transformers import LlamaConfig, LlamaForCausalLM
        config = LlamaConfig.from_pretrained(model)

        # Skip weight initialization: the checkpoint supplies all weights.
        def noop(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop

        # Build the model skeleton in half precision without initializing weights.
        torch.set_default_dtype(torch.half)
        transformers.modeling_utils._init_weights = False
        model = LlamaForCausalLM(config)
        torch.set_default_dtype(torch.float)
        if eval:
            model = model.eval()

        # Replace every linear layer except the LM head with a quantized linear layer.
        layers = find_layers(model)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        quant.make_quant_linear(model, layers, wbits, groupsize)
        del layers

        print('Loading model ...')
        if checkpoint.endswith('.safetensors'):
            from safetensors.torch import load_file as safe_load
            model.load_state_dict(safe_load(checkpoint), strict=False)
        else:
            model.load_state_dict(torch.load(checkpoint), strict=False)

        # Fuse attention, normalization and (optionally) the MLP into quantized modules.
        if eval:
            quant.make_quant_attn(model)
            quant.make_quant_norm(model)
            if fused_mlp:
                quant.make_fused_mlp(model)

        # Warm up the Triton autotuner so the first real forward pass is not slowed down.
        if warmup_autotune:
            quant.autotune_warmup_linear(model, transpose=not eval)
            if eval and fused_mlp:
                quant.autotune_warmup_fused(model)

        model.seqlen = 2048
        print('Done.')
        return model

    def generate_text(self, text, min_length=10, max_length=50, top_p=0.95, temperature=0.8):
        input_ids = self.tokenizer.encode(text, return_tensors="pt").to(DEV)
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids,
                do_sample=True,
                min_length=min_length,
                max_length=max_length,
                top_p=top_p,
                temperature=temperature,
            )
        return self.tokenizer.decode([el.item() for el in generated_ids[0]])
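

# Usage sketch. The model id and checkpoint filename below are placeholders
# (assumptions, not files shipped with this repo): pass `load=` to go through
# load_quant() with a GPTQ checkpoint, or omit it to fall back to the
# full-precision HF weights via get_llama().
if __name__ == '__main__':
    inference = ModelInference(
        'decapoda-research/llama-7b-hf',        # hypothetical HF model id
        load='llama7b-4bit-128g.safetensors',   # hypothetical quantized checkpoint
        wbits=4,
        groupsize=128,
    )
    print(inference.generate_text('The capital of France is', max_length=32))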