# gptq_model / opt.py
# ssaroya's picture
# Upload 10 files
# e5bdf53
# raw
# history blame
# 17.5 kB
import time
import torch
import torch.nn as nn
import argparse
import transformers
from gptq import GPTQ
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
import quant
def get_opt(model):
    """Load a pretrained OPT causal language model.

    Before loading, ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and
    ``normal_`` are replaced by a function that does nothing. The replacement
    is applied to the ``torch.nn.init`` module itself and is not undone when
    this function returns.

    Args:
        model: Identifier handed to ``OPTForCausalLM.from_pretrained``.

    Returns:
        The loaded model with an added ``seqlen`` attribute equal to
        ``config.max_position_embeddings``.
    """

    def _noop(*args, **kwargs):
        pass

    # Patch the initialisers first, then import and build the model.
    for init_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_name, _noop)

    from transformers import OPTForCausalLM

    loaded = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
    loaded.seqlen = loaded.config.max_position_embeddings
    return loaded
@torch.no_grad()
def opt_sequential(model, dataloader, dev):
    """Quantize the OPT decoder layers one after another with GPTQ.

    The input reaching decoder layer 0 is recorded for every calibration
    batch. Each layer is then moved to ``dev`` in turn, run on those recorded
    activations while forward hooks pass every sub-layer's input and output to
    ``GPTQ.add_batch``, quantized through ``fasterquant``, run again, and moved
    back to the CPU. The second run's outputs become the next layer's inputs.

    Settings are read from the module-level ``args`` namespace (``nsamples``,
    ``wbits``, ``sym``, ``trits``, ``percdamp``, ``groupsize``, ``act_order``),
    so ``args`` must exist before this function is called.

    Args:
        model: OPT model exposing ``model.decoder``, ``config`` and ``seqlen``.
        dataloader: Iterable of calibration batches; element 0 of each batch
            is moved to ``dev`` and fed to the model.
        dev: Device on which the layers are processed.

    Returns:
        Dict mapping ``'model.decoder.layers.<i>.<name>'`` to the tuple
        ``(quantizer, scale, zero, g_idx)``, each entry moved to the CPU.
    """
    print('Starting ...')

    saved_use_cache = model.config.use_cache
    model.config.use_cache = False

    decoder = model.model.decoder
    layers = decoder.layers

    def relocate_outer_modules(move):
        # Embeddings are moved unconditionally; the projections only when the
        # attribute exists and is truthy.
        decoder.embed_tokens = move(decoder.embed_tokens)
        decoder.embed_positions = move(decoder.embed_positions)
        for attr in ('project_out', 'project_in'):
            if hasattr(decoder, attr) and getattr(decoder, attr):
                setattr(decoder, attr, move(getattr(decoder, attr)))

    relocate_outer_modules(lambda module: module.to(dev))
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size),
        dtype=dtype,
        device=dev,
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        """Temporary replacement for layer 0: stores its input, then aborts."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            slot = cache['i']
            inps[slot] = inp
            cache['i'] = slot + 1
            cache['attention_mask'] = kwargs['attention_mask']
            # Stop the forward pass; the caller below catches this.
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    relocate_outer_modules(lambda module: module.cpu())
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    print('Ready.')

    def make_hook(name):
        # Forward hook that hands the sub-layer's input/output to its GPTQ object.
        def hook(_, inp, out):
            gptq[name].add_batch(inp[0].data, out.data)

        return hook

    def propagate(block, src, dst):
        # Run every calibration sample through `block`, one sample at a time.
        for j in range(args.nsamples):
            dst[j] = block(src[j].unsqueeze(0), attention_mask=attention_mask)[0]

    quantizers = {}
    for i in range(len(layers)):
        layer = layers[i].to(dev)

        subset = find_layers(layer)
        gptq = {}
        for name in subset:
            gptq[name] = GPTQ(subset[name])
            gptq[name].quantizer = quant.Quantizer()
            gptq[name].quantizer.configure(
                args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits
            )

        handles = [subset[name].register_forward_hook(make_hook(name)) for name in subset]

        # First pass: hooks are active and collect statistics.
        propagate(layer, inps, outs)

        for handle in handles:
            handle.remove()

        for name in subset:
            print(f'Quantizing {name} in layer {i+1}/{len(layers)}...')
            scale, zero, g_idx, _ = gptq[name].fasterquant(
                percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order
            )
            key = 'model.decoder.layers.%d.%s' % (i, name)
            quantizers[key] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
            gptq[name].free()

        # Second pass: recompute the outputs after quantization.
        propagate(layer, inps, outs)

        layers[i] = layer.cpu()
        del layer
        del gptq
        torch.cuda.empty_cache()

        # This layer's outputs feed the next layer.
        inps, outs = outs, inps

    model.config.use_cache = saved_use_cache

    return quantizers
@torch.no_grad()
def opt_eval(model, testenc, dev):
    """Compute and print the model's perplexity on a tokenized test set.

    ``testenc.input_ids`` is cut into ``numel() // model.seqlen`` consecutive
    windows of ``model.seqlen`` tokens. The input to decoder layer 0 is
    recorded for each window, then the decoder layers are applied one at a
    time on ``dev``. When the module-level ``args.nearest`` is set, each
    layer's weights are first replaced by the output of ``quant.Quantizer``
    configured from ``args.wbits`` and ``args.sym``. Finally the final layer
    norm, ``project_out`` (when not ``None``) and ``lm_head`` produce logits,
    and the perplexity is printed.

    Args:
        model: OPT model exposing ``model.decoder``, ``lm_head``, ``config``
            and ``seqlen``.
        testenc: Object whose ``input_ids`` attribute holds the token tensor.
        dev: Device on which the evaluation is run.

    Returns:
        None. The layer indices and the final perplexity are printed.
    """
    print('Evaluating ...')

    testenc = testenc.input_ids
    seqlen = model.seqlen
    nsamples = testenc.numel() // seqlen

    saved_use_cache = model.config.use_cache
    model.config.use_cache = False

    decoder = model.model.decoder
    layers = decoder.layers

    def relocate_outer_modules(move):
        # Embeddings are moved unconditionally; the projections only when the
        # attribute exists and is truthy.
        decoder.embed_tokens = move(decoder.embed_tokens)
        decoder.embed_positions = move(decoder.embed_positions)
        for attr in ('project_out', 'project_in'):
            if hasattr(decoder, attr) and getattr(decoder, attr):
                setattr(decoder, attr, move(getattr(decoder, attr)))

    relocate_outer_modules(lambda module: module.to(dev))
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        """Temporary replacement for layer 0: stores its input, then aborts."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            slot = cache['i']
            inps[slot] = inp
            cache['i'] = slot + 1
            cache['attention_mask'] = kwargs['attention_mask']
            # Stop the forward pass; the caller below catches this.
            raise ValueError

    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        window = testenc[:, (i * seqlen):((i + 1) * seqlen)].to(dev)
        try:
            model(window)
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    relocate_outer_modules(lambda module: module.cpu())
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.nearest:
            subset = find_layers(layer)
            for name in subset:
                quantizer = quant.Quantizer()
                quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)
                W = subset[name].weight.data
                quantizer.find_params(W, weight=True)
                layer_dtype = next(iter(layer.parameters())).dtype
                subset[name].weight.data = quantizer.quantize(W).to(layer_dtype)

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        # This layer's outputs feed the next layer.
        inps, outs = outs, inps

    if decoder.final_layer_norm is not None:
        decoder.final_layer_norm = decoder.final_layer_norm.to(dev)
    if decoder.project_out is not None:
        decoder.project_out = decoder.project_out.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if decoder.final_layer_norm is not None:
            hidden_states = decoder.final_layer_norm(hidden_states)
        if decoder.project_out is not None:
            hidden_states = decoder.project_out(hidden_states)
        lm_logits = model.lm_head(hidden_states)

        # Position t predicts token t + 1: drop the last logit and first label.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[:, (i * seqlen):((i + 1) * seqlen)][:, 1:]
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        nlls.append(loss.float() * seqlen)

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
    print(ppl.item())

    model.config.use_cache = saved_use_cache
# TODO: perform packing on GPU
def opt_pack(model, quantizers, wbits, groupsize):
    """Swap the quantized layers for ``quant.QuantLinear`` and pack them.

    Args:
        model: Model whose layers named in ``quantizers`` are replaced.
        quantizers: Dict mapping layer names to
            ``(quantizer, scale, zero, g_idx)`` tuples. It is modified in
            place: for every packed layer the tuple is replaced by its first
            element.
        wbits: Bit width forwarded to ``quant.make_quant_linear``.
        groupsize: Group size forwarded to ``quant.make_quant_linear``.

    Returns:
        The same ``model`` object, now holding packed layers.
    """
    # Grab the original layers before make_quant_linear replaces them.
    all_layers = find_layers(model)
    source_layers = {layer_name: all_layers[layer_name] for layer_name in quantizers}

    quant.make_quant_linear(model, quantizers, wbits, groupsize)
    packed_layers = find_layers(model, [quant.QuantLinear])

    print('Packing ...')
    for name in packed_layers:
        print(name)
        quantizer, scale, zero, g_idx = quantizers[name]
        quantizers[name] = quantizer
        packed_layers[name].pack(source_layers[name], scale, zero, g_idx)
    print('Done.')
    return model
def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True):
    """Build an OPT model with quantized linear layers and load a checkpoint.

    An ``OPTForCausalLM`` is constructed from the configuration of ``model``
    with weight initialisation turned into no-ops and ``torch.half`` as the
    default dtype. Every layer reported by ``find_layers`` except
    ``model.decoder.project_out``, ``model.decoder.project_in`` and
    ``lm_head`` is handed to ``quant.make_quant_linear``, after which the
    state dict stored in ``checkpoint`` is loaded.

    Side effects: ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and
    ``normal_`` are replaced by no-ops and
    ``transformers.modeling_utils._init_weights`` is set to ``False``; neither
    is undone. The default dtype is set back to ``torch.float`` once the model
    has been constructed, including when construction fails.

    Args:
        model: Identifier handed to ``OPTConfig.from_pretrained``.
        checkpoint: Path of the checkpoint; a name ending in ``.safetensors``
            is read with safetensors, anything else with ``torch.load``.
        wbits: Bit width forwarded to ``quant.make_quant_linear``.
        groupsize: Group size forwarded to ``quant.make_quant_linear``.
        eval: Its negation is passed as ``transpose`` to
            ``quant.autotune_warmup_linear``. (The name shadows the builtin
            but is kept for backward compatibility.)
        warmup_autotune: Whether to call ``quant.autotune_warmup_linear``.

    Returns:
        The model in eval mode with ``seqlen`` set to
        ``config.max_position_embeddings``.
    """
    from transformers import OPTConfig, OPTForCausalLM

    config = OPTConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    transformers.modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    try:
        model = OPTForCausalLM(config)
    finally:
        # Never leave the process-wide default dtype at half precision, even
        # if constructing the model raises.
        torch.set_default_dtype(torch.float)
    model = model.eval()

    layers = find_layers(model)
    for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']:
        # These layers are excluded from the quantized replacement.
        layers.pop(name, None)
    quant.make_quant_linear(model, layers, wbits, groupsize)

    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        # NOTE(security): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        model.load_state_dict(torch.load(checkpoint))

    if warmup_autotune:
        quant.autotune_warmup_linear(model, transpose=not (eval))
    model.seqlen = model.config.max_position_embeddings
    print('Done.')

    return model
def opt_multigpu(model, gpus):
    """Distribute an OPT model's modules over the given devices, in place.

    The embeddings and ``project_in`` go to ``gpus[0]``; ``project_out``,
    ``final_layer_norm`` and a deep copy of ``lm_head`` go to ``gpus[-1]``.
    The decoder layers are split into consecutive groups of
    ``ceil(len(layers) / len(gpus))`` layers, one group per device, and each
    layer is wrapped so that its input and attention mask are moved to the
    layer's device before it runs.

    Args:
        model: OPT model exposing ``model.decoder`` and ``lm_head``.
        gpus: Non-empty sequence of devices.

    Returns:
        None. ``model`` is modified in place and gains a ``gpus`` attribute.
    """
    import copy
    # Imported locally, like `copy`, so the function does not depend on a
    # module-level `import math` (its absence made `math.ceil` below raise
    # NameError).
    import math

    decoder = model.model.decoder
    first_dev, last_dev = gpus[0], gpus[-1]

    decoder.embed_tokens = decoder.embed_tokens.to(first_dev)
    decoder.embed_positions = decoder.embed_positions.to(first_dev)
    if hasattr(decoder, 'project_in') and decoder.project_in:
        decoder.project_in = decoder.project_in.to(first_dev)
    if hasattr(decoder, 'project_out') and decoder.project_out:
        decoder.project_out = decoder.project_out.to(last_dev)
    if hasattr(decoder, 'final_layer_norm') and decoder.final_layer_norm:
        decoder.final_layer_norm = decoder.final_layer_norm.to(last_dev)
    # NOTE(review): presumably deep-copied so the head stops sharing its weight
    # tensor with embed_tokens, which lives on the first device -- confirm.
    model.lm_head = copy.deepcopy(model.lm_head).to(last_dev)

    cache = {'mask': None}

    class MoveModule(nn.Module):
        """Wrapper that moves a layer's input and mask to the layer's device."""

        def __init__(self, module):
            super().__init__()
            self.module = module
            self.dev = next(iter(self.module.parameters())).device

        def forward(self, *inp, **kwargs):
            inp = list(inp)
            if inp[0].device != self.dev:
                inp[0] = inp[0].to(self.dev)
            # The mask is shared by all wrappers and only re-read from kwargs
            # when the cached copy sits on another device.
            # NOTE(review): a new mask arriving while the cached one is already
            # on this device is ignored -- confirm this is intended.
            if cache['mask'] is None or cache['mask'].device != self.dev:
                cache['mask'] = kwargs['attention_mask'].to(self.dev)
            kwargs['attention_mask'] = cache['mask']
            tmp = self.module(*inp, **kwargs)
            return tmp

    layers = decoder.layers
    pergpu = math.ceil(len(layers) / len(gpus))
    for i in range(len(layers)):
        layers[i] = MoveModule(layers[i].to(gpus[i // pergpu]))

    model.gpus = gpus
def benchmark(model, input_ids, check=False):
    """Time token-by-token generation and print per-token and median latency.

    The tokens of ``input_ids`` are fed to the model one at a time together
    with the key/value cache returned by the previous step. While the model
    runs, a forward hook on every decoder layer sets that layer's entry in
    the cached list to ``None``. The hooks are removed again before the
    function returns, so repeated calls do not pile hooks up on the model.

    Args:
        model: OPT model; if it has a ``gpus`` attribute the input is placed
            on ``model.gpus[0]`` and all listed devices are synchronized,
            otherwise the module-level ``DEV`` is used.
        input_ids: Token tensor; column ``i`` is the input of step ``i``.
        check: If true, also accumulate the cross-entropy of each step's
            logits against the next token and print the resulting perplexity.

    Returns:
        None. Timings (and optionally the perplexity) are printed.
    """
    input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
    torch.cuda.synchronize()

    cache = {'past': None}

    def clear_past(i):
        # Hook factory: drops layer i's cached entry after that layer has run.
        def tmp(layer, inp, out):
            if cache['past']:
                cache['past'][i] = None

        return tmp

    # Keep the handles so the hooks can be detached afterwards.
    hook_handles = [
        layer.register_forward_hook(clear_past(i))
        for i, layer in enumerate(model.model.decoder.layers)
    ]

    try:
        print('Benchmarking ...')

        if check:
            loss = nn.CrossEntropyLoss()
            tot = 0.

        def sync():
            if hasattr(model, 'gpus'):
                for gpu in model.gpus:
                    torch.cuda.synchronize(gpu)
            else:
                torch.cuda.synchronize()

        with torch.no_grad():
            attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
            times = []
            for i in range(input_ids.numel()):
                tick = time.time()
                out = model(
                    input_ids[:, i].reshape(-1),
                    past_key_values=cache['past'],
                    attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)),
                )
                # Wait for the device(s) so the measured time covers the step.
                sync()
                times.append(time.time() - tick)
                print(i, times[-1])
                if check and i != input_ids.numel() - 1:
                    tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
                cache['past'] = list(out.past_key_values)
                del out
            sync()
            import numpy as np
            print('Median:', np.median(times))
            if check:
                print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
    finally:
        for handle in hook_handles:
            handle.remove()
if __name__ == '__main__':
    # Command-line entry point: load (or quantize) an OPT model, then
    # optionally benchmark, evaluate and save it.
    # `args` must stay a module-level name: opt_sequential and opt_eval read it.
    parser = argparse.ArgumentParser()

    parser.add_argument('model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
    parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
    parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
    parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.')
    parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
    parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.')
    parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
    parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
    parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.')
    parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.')
    parser.add_argument('--load', type=str, default='', help='Load quantized model.')
    parser.add_argument('--benchmark', type=int, default=0, help='Number of tokens to use for benchmarking.')
    parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
    parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
    parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic')
    parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')

    args = parser.parse_args()

    if not isinstance(args.load, str):
        args.load = args.load.as_posix()

    # Quantizers are only produced when this run performs GPTQ itself. Saving
    # needs them, so reject impossible combinations before any heavy work
    # (previously this surfaced as a NameError at save time).
    run_gptq = not args.load and args.wbits < 16 and not args.nearest
    if (args.save or args.save_safetensors) and not run_gptq:
        parser.error('--save/--save_safetensors need a GPTQ run: use --wbits < 16 without --load or --nearest.')

    if args.load:
        model = load_quant(args.model, args.load, args.wbits, args.groupsize)
    else:
        model = get_opt(args.model)
        model.eval()

    dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen)

    quantizers = None
    if run_gptq:
        tick = time.time()
        quantizers = opt_sequential(model, dataloader, DEV)
        print(time.time() - tick)

    if args.benchmark:
        gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
        if len(gpus) > 1:
            opt_multigpu(model, gpus)
        else:
            model = model.to(DEV)
        input_ids = next(iter(dataloader))[0][:, :args.benchmark]
        benchmark(model, input_ids, check=args.check)

    if args.eval:
        datasets = ['wikitext2', 'ptb', 'c4']
        if args.new_eval:
            datasets = ['wikitext2', 'ptb-new', 'c4-new']
        for dataset in datasets:
            dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
            print(dataset)
            opt_eval(model, testloader, DEV)

    # Pack exactly once: opt_pack rewrites the entries of `quantizers`, so a
    # second call (when both save flags are given) would fail.
    if args.save or args.save_safetensors:
        opt_pack(model, quantizers, args.wbits, args.groupsize)

    if args.save:
        torch.save(model.state_dict(), args.save)

    if args.save_safetensors:
        from safetensors.torch import save_file as safe_save
        state_dict = model.state_dict()
        state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
        safe_save(state_dict, args.save_safetensors)