---
license: apache-2.0
---
|
|
|
### Quantization config |
|
```json
{
  "zero_point": true,
  "q_group_size": 128,
  "w_bit": 4,
  "version": "GEMM"
}
```
|
|
|
### Script for AWQ quantization
|
``` |
|
from awq import AutoAWQForCausalLM |
|
from transformers import AutoTokenizer |
|
|
|
# Location of the source checkpoint and the directory the quantized
# model and tokenizer are written to.
model_path = 'PATH_TO Poro-34B'
quant_path = 'Poro-34B-AWQ'

# Settings handed to `model.quantize` below.
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
}

# Load the source model and its tokenizer from `model_path`.
model = AutoAWQForCausalLM.from_pretrained(
    model_path,
    safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

# Run quantization with the settings defined above.
model.quantize(tokenizer, quant_config=quant_config)

# Write the quantized model and the tokenizer side by side in `quant_path`.
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
|
``` |
|
|
|
### Generate |
|
``` |
|
from awq import AutoAWQForCausalLM |
|
from transformers import AutoTokenizer |
|
from transformers import GenerationConfig |
|
|
|
|
|
# Identifier of the quantized checkpoint to load.
model_path = "gradjitta/Poro-34B-AWQ"

# Load the quantized model and its tokenizer from the same location,
# both with `trust_remote_code` disabled.
model = AutoAWQForCausalLM.from_quantized(
    model_path,
    fuse_layers=True,
    trust_remote_code=False,
    safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=False,
)
|
|
|
|
|
|
|
def generate(prompt, max_new_tokens=256):
    """Sample a continuation of ``prompt`` and print every decoded sequence.

    Uses the module-level ``model`` and ``tokenizer``. Nothing is returned;
    the decoded text is written to stdout.

    Args:
        prompt: Text that is tokenized and passed to ``model.generate``.
        max_new_tokens: Value forwarded to ``model.generate`` as
            ``max_new_tokens``. Defaults to 256, the value that used to be
            hard-coded.
    """
    # Only the token ids are forwarded to the model, moved with `.cuda()`.
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()

    sampling_config = GenerationConfig(
        pad_token_id=tokenizer.pad_token_id,
        temperature=1.0,
        top_p=0.99,
        top_k=50,
        num_beams=1,
        do_sample=True,
    )

    # `output_scores` is intentionally not requested: nothing below reads the
    # per-step scores, so asking for them only costs memory.
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=sampling_config,
        return_dict_in_generate=True,  # required for the `.sequences` access
        max_new_tokens=max_new_tokens,
    )

    # `decode` is called with its defaults, so the printed text is whatever
    # the tokenizer produces for the full sequence (prompt included).
    for seq in generation_output.sequences:
        print(tokenizer.decode(seq))
|
|
|
|
|
# Example run with a Finnish prompt ("A Finnish poem about life:").
generate("Suomalainen runo elämästä:")
|
``` |
|
|
|
### Output
|
``` |
|
Suomalainen runo elämästä: |
|
- se alkaa |
|
- sitten ei enää mikään riitä |
|
- se päättyy ja se alkaa</s> |
|
``` |
|
|
|
|
|
### Work supported by https://datacrunch.io/ |
|
#### Quantized by: gradjitta |