import argparse
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
import transformers
from accelerate import cpu_offload_with_hook, load_checkpoint_in_model
from transformers import AutoTokenizer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig
from transformers.utils import logging as hf_logging

from gptq import GPTQ
from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
import quant

# Module logger; `Offload_LlamaModel.forward` uses `logger.warning_once` on the
# gradient-checkpointing path.
logger = hf_logging.get_logger(__name__)


class Offload_LlamaModel(LlamaModel):
    """``LlamaModel`` whose trailing decoder layers can be CPU-offloaded.

    ``load_quant`` installs this class in place of ``LlamaModel`` so that the
    model built there exposes ``cpu_offload``.
    """

    def cpu_offload(self, preload):
        """Attach chained accelerate CPU-offload hooks to ``self.layers[preload:]``.

        Each layer is wrapped with ``cpu_offload_with_hook`` targeting ``DEV``.
        The hook returned for one layer is passed as ``prev_module_hook`` for the
        next one. Per accelerate's documentation, this makes the previous layer
        offload right before the next layer runs.

        Args:
            preload: index of the first decoder layer to offload. Layers before
                it are expected to already be resident on ``DEV``.
        """
        hook = None
        for layer in self.layers[preload:]:
            _, hook = cpu_offload_with_hook(layer, DEV, prev_module_hook=hook)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Run the embedding, every decoder layer in ``self.layers``, and the final norm.

        NOTE(review): this looks like a copy of ``LlamaModel.forward`` from the
        transformers version this script targets (it relies on
        ``self._prepare_decoder_attention_mask``); keep it in sync on upgrades.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                Defaults to all ones over the past + current sequence length.

                [What are attention masks?](../glossary#attention-mask)
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence token in the position embeddings. Defaults to
                consecutive positions starting after the cached (past) length.

                [What are position IDs?](../glossary#position-ids)
            past_key_values (sequence with one entry per decoder layer, *optional*):
                Cached key/value states from earlier decoding steps. Entry ``i`` is passed to decoder layer ``i``
                as ``past_key_value``. The cached length is read from dimension 2 of the first tensor of the first
                entry. When provided, only the tokens not yet in the cache need to be passed in `input_ids`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
                representation. Mutually exclusive with `input_ids`.
            use_cache (`bool`, *optional*):
                If set to `True`, the per-layer caches are collected and returned as `past_key_values`. Forced to
                `False` when gradient checkpointing is active in training mode.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers (including the embedding output and the
                final normed output).
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is true. Otherwise, a tuple of the non-None values
            among ``(last_hidden_state, next_cache, all_hidden_states, all_self_attns)``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # The trailing None in `inputs` is passed positionally after position_ids;
                        # presumably it fills past_key_value (verify against the installed
                        # LlamaDecoderLayer signature). The extra None here is passed as use_cache.
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The layer output carries attentions at index 1 only when requested, so the cache shifts.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1], )

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


def load_quant(model, checkpoint, wbits, groupsize, pre_layer, fused_mlp=True, warmup_autotune=True, eval=True):
    """Build a quantized LLaMA model, load ``checkpoint`` into it, and set up CPU offload.

    Side effects (global, not undone):
      * ``transformers.models.llama.modeling_llama.LlamaModel`` is replaced by ``Offload_LlamaModel``.
      * ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are replaced by no-ops.
      * ``transformers.modeling_utils._init_weights`` is set to ``False``.

    Args:
        model: config name/path passed to ``LlamaConfig.from_pretrained``.
        checkpoint: checkpoint path passed to accelerate's ``load_checkpoint_in_model``.
        wbits: quantization bit width, forwarded to ``quant.make_quant_linear``.
        groupsize: quantization group size, forwarded to ``quant.make_quant_linear``.
        pre_layer: number of leading decoder layers kept resident on ``DEV``. The remaining
            layers get chained CPU-offload hooks. Values above the layer count keep every
            layer on ``DEV``; negative values are treated as 0.
        fused_mlp: apply ``quant.make_fused_mlp`` (requires ``eval``) and its autotune warmup.
        warmup_autotune: run the quant autotune warmups after loading.
        eval: apply ``quant.make_quant_attn`` / ``quant.make_quant_norm`` (and the fused MLP).
            The name is kept because the function body has always referred to ``eval``.

    Returns:
        The loaded ``LlamaForCausalLM`` in eval mode, with ``seqlen`` set to 2048.
    """
    # Make LlamaForCausalLM build Offload_LlamaModel so `model.model.cpu_offload` exists below.
    transformers.models.llama.modeling_llama.LlamaModel = Offload_LlamaModel
    from transformers import LlamaConfig, LlamaForCausalLM
    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Weights are overwritten by the checkpoint, so skip random initialisation.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
    transformers.modeling_utils._init_weights = False

    # Build parameters in fp16, then restore the caller's default dtype even on failure.
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.half)
    try:
        model = LlamaForCausalLM(config)
    finally:
        torch.set_default_dtype(previous_dtype)
    model = model.eval()

    # Quantize every layer find_layers reports except the output head.
    layers = find_layers(model)
    if 'lm_head' in layers:
        del layers['lm_head']
    quant.make_quant_linear(model, layers, wbits, groupsize)

    print('Loading model ...')
    load_checkpoint_in_model(model, checkpoint, dtype='float16')
    model.seqlen = 2048

    if eval:
        quant.make_quant_attn(model)
        quant.make_quant_norm(model)
        if fused_mlp:
            quant.make_fused_mlp(model)

    if warmup_autotune:
        quant.autotune_warmup_linear(model)
        # Only warm up the fused MLP kernels if they were actually installed above.
        if eval and fused_mlp:
            quant.autotune_warmup_fused(model)

    # Slicing (instead of indexing range(pre_layer)) tolerates pre_layer > number of layers.
    n_preload = max(0, pre_layer)
    for layer in model.model.layers[:n_preload]:
        layer.to(DEV)
    model.model.embed_tokens.to(DEV)
    model.model.norm.to(DEV)
    model.lm_head.to(DEV)
    model.model.cpu_offload(n_preload)
    print('Done.')
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('model', type=str, help='llama model to load')
    parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8], help='#bits to use for quantization')
    parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
    parser.add_argument('--load', type=str, default='', help='Load quantized model.')
    # Required: generation below cannot run without a prompt, and failing here avoids loading the model first.
    parser.add_argument('--text', type=str, required=True, help='input text')
    parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.')
    parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.')
    parser.add_argument('--top_p',
                        type=float,
                        default=0.95,
                        help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.')
    parser.add_argument('--temperature', type=float, default=0.8, help='The value used to module the next token probabilities.')
    parser.add_argument('--pre_layer', type=int, default=50, help='The number of layers to preload')

    args = parser.parse_args()

    # Accept path-like values (e.g. pathlib.Path) if args.load is ever set programmatically.
    if not isinstance(args.load, str):
        args.load = args.load.as_posix()

    model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.pre_layer)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            do_sample=True,
            min_length=args.min_length,
            max_length=args.max_length,
            top_p=args.top_p,
            temperature=args.temperature,
        )
    print(tokenizer.decode(generated_ids[0].tolist()))