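"""Minimal inference script for a 4-bit GPTQ-quantized LLaMA model
(Wizard-Vicuna-13B-Uncensored-GPTQ, group size 128, no act-order).

It builds an empty fp16 LLaMA skeleton from the model config, swaps the
linear layers for quantized ones, loads the safetensors checkpoint, and
samples a completion from a hard-coded prompt. The `quant` and `utils`
modules are assumed to be the ones shipped with the GPTQ-for-LLaMa
repository this script is meant to run from.
"""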
import torch

import transformers
from transformers import AutoTokenizer

# Local helper modules (quantized kernels and utilities) from the GPTQ-for-LLaMa codebase.
import quant
from utils import find_layers, DEV


def load_quant(model="Wizard-Vicuna-13B-Uncensored-GPTQ",
               checkpoint="Wizard-Vicuna-13B-Uncensored-GPTQ/Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors",
               wbits=4,
               groupsize=128,
               fused_mlp=True, eval=True, warmup_autotune=True):
    from transformers import LlamaConfig, LlamaForCausalLM
    config = LlamaConfig.from_pretrained(model)

    # Skip the usual random weight initialization: every tensor is about to be
    # overwritten by the quantized checkpoint, so initializing it is wasted work.
    def noop(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    # Build an empty fp16 LLaMA skeleton from the config (no weights loaded yet).
    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    model = LlamaForCausalLM(config)
    torch.set_default_dtype(torch.float)
    if eval:
        model = model.eval()

    # Replace every nn.Linear except the output head with a quantized linear layer.
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    quant.make_quant_linear(model, layers, wbits, groupsize)

    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint), strict=False)
    else:
        model.load_state_dict(torch.load(checkpoint), strict=False)

    # Optional fused inference kernels, plus warmup of the kernel autotuner.
    if eval:
        quant.make_quant_attn(model)
        quant.make_quant_norm(model)
        if fused_mlp:
            quant.make_fused_mlp(model)
    if warmup_autotune:
        quant.autotune_warmup_linear(model, transpose=not eval)
        if eval and fused_mlp:
            quant.autotune_warmup_fused(model)
    model.seqlen = 2048
    print('Done.')

    return model
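
# --- Example usage -----------------------------------------------------------
# Load the quantized checkpoint onto DEV (the device defined in utils),
# tokenize a prompt, and sample a completion. max_length counts tokens
# including the prompt; top_p and temperature control nucleus sampling.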

model = load_quant()
model.to(DEV)
tokenizer = AutoTokenizer.from_pretrained("Wizard-Vicuna-13B-Uncensored-GPTQ", use_fast=False)
input_ids = tokenizer.encode("TEXT PROMPT GOES HERE", return_tensors="pt").to(DEV)

with torch.no_grad():
    generated_ids = model.generate(
        input_ids,
        do_sample=True,
        min_length=50,
        max_length=200,
        top_p=0.99,
        temperature=0.8,
    )
print(tokenizer.decode(generated_ids[0]))
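
# To print only the newly generated text (dropping the echoed prompt and any
# special tokens), slice off the prompt tokens before decoding, e.g.:
# print(tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True))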