|
from dataclasses import dataclass, field |
|
from pathlib import Path |
|
import sys |
|
|
|
import torch |
|
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils |
|
|
|
|
|
@dataclass |
|
class AWQConfig: |
|
ckpt: str = field( |
|
default=None, |
|
metadata={ |
|
"help": "Load quantized model. The path to the local AWQ checkpoint." |
|
}, |
|
) |
|
wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) |
|
groupsize: int = field( |
|
default=-1, |
|
metadata={"help": "Groupsize to use for quantization; default uses full row."}, |
|
) |
|
|
|
|
|
def load_awq_quantized(model_name, awq_config: AWQConfig, device): |
|
print("Loading AWQ quantized model...") |
|
|
|
try: |
|
from tinychat.utils import load_quant |
|
from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp |
|
except ImportError as e: |
|
print(f"Error: Failed to import tinychat. {e}") |
|
print("Please double check if you have successfully installed AWQ") |
|
print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") |
|
sys.exit(-1) |
|
|
|
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
model_name, use_fast=False, trust_remote_code=True |
|
) |
|
|
|
def skip(*args, **kwargs): |
|
pass |
|
|
|
torch.nn.init.kaiming_uniform_ = skip |
|
torch.nn.init.kaiming_normal_ = skip |
|
torch.nn.init.uniform_ = skip |
|
torch.nn.init.normal_ = skip |
|
modeling_utils._init_weights = False |
|
|
|
torch.set_default_dtype(torch.half) |
|
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) |
|
|
|
if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): |
|
model = load_quant.load_awq_llama_fast( |
|
model, |
|
find_awq_ckpt(awq_config), |
|
awq_config.wbits, |
|
awq_config.groupsize, |
|
device, |
|
) |
|
make_quant_attn(model, device) |
|
make_quant_norm(model) |
|
make_fused_mlp(model) |
|
else: |
|
model = load_quant.load_awq_model( |
|
model, |
|
find_awq_ckpt(awq_config), |
|
awq_config.wbits, |
|
awq_config.groupsize, |
|
device, |
|
) |
|
return model, tokenizer |
|
|
|
|
|
def find_awq_ckpt(awq_config: AWQConfig): |
|
if Path(awq_config.ckpt).is_file(): |
|
return awq_config.ckpt |
|
|
|
for ext in ["*.pt", "*.safetensors"]: |
|
matched_result = sorted(Path(awq_config.ckpt).glob(ext)) |
|
if len(matched_result) > 0: |
|
return str(matched_result[-1]) |
|
|
|
print("Error: AWQ checkpoint not found") |
|
sys.exit(1) |
|
|