Better output in INT8
#149
by aney
I tried whisper-large-v3 in INT8 and surprisingly the output was better. It transcribed things that FP16 and FP32 missed.
Try it out yourself:
pip install --upgrade transformers datasets[audio] accelerate bitsandbytes torch flash-attn soundfile
huggingface-cli login
mkdir -p ~/whisper
huggingface-cli download openai/whisper-large-v3 --local-dir ~/whisper --local-dir-use-symlinks False
import gc
import os
import time

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, BitsAndBytesConfig, pipeline
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_dir = os.path.expanduser("~/whisper")
audio_file = "audio.mp3"
# Build an ASR pipeline for the given model and time a single transcription.
# For the int8 model, accelerate has already placed the weights on the GPU,
# so the pipeline must not be given an explicit device (use_device=False).
def transcribe(model, processor, dtype, device, audio_file, use_device=True):
    pipe_kwargs = dict(
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=dtype,
    )
    if use_device:
        pipe_kwargs["device"] = device
    pipe = pipeline("automatic-speech-recognition", **pipe_kwargs)
    start_time = time.time()
    result = pipe(audio_file)
    end_time = time.time()
    return result["text"], end_time - start_time
# Load and transcribe with fp16
print("Loading model in fp16 precision...")
model_fp16 = AutoModelForSpeechSeq2Seq.from_pretrained(
model_dir, torch_dtype=torch.float16, low_cpu_mem_usage=True, use_safetensors=True
)
model_fp16.to(device)
processor = AutoProcessor.from_pretrained(model_dir)
print("Transcribing in fp16 precision...")
text_fp16, time_fp16 = transcribe(model_fp16, processor, torch.float16, device, audio_file)
print(f"Transcription (fp16): {text_fp16}")
print(f"Time taken (fp16): {time_fp16:.2f} seconds")
# Unload the fp16 model
del model_fp16
gc.collect()
torch.cuda.empty_cache()
# Load and transcribe with int8 (weights quantized by bitsandbytes; accelerate
# dispatches the model to the GPU itself, so no .to(device) call here)
print("Loading model in int8 precision...")
model_int8 = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_dir,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
print("Transcribing in int8 precision...")
text_int8, time_int8 = transcribe(model_int8, processor, torch.float16, device, audio_file, use_device=False)
print(f"Transcription (int8): {text_int8}")
print(f"Time taken (int8): {time_int8:.2f} seconds")
# Compare results
print(f"fp16 vs int8 time difference: {time_fp16 - time_int8:.2f} seconds")
You should evaluate on a larger set of audio samples (and make sure temperature=0 so decoding is deterministic) before making a claim like this!
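For example, here is a rough sketch of what that evaluation could look like, assuming jiwer is installed (pip install jiwer) and a pipe built the same way as in the script above; the dummy LibriSpeech split is just a stand-in for a real test set:
from datasets import load_dataset
from jiwer import wer

# Small public ASR test split with "audio" and "text" columns;
# replace with your own audio/reference pairs for a meaningful comparison.
dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)

refs, hyps = [], []
for sample in dataset:
    out = pipe(
        sample["audio"],
        # temperature=0.0 keeps decoding greedy (no sampling fallback),
        # so repeated runs are deterministic
        generate_kwargs={"temperature": 0.0},
    )
    # Crude normalization; a real comparison should use a proper text
    # normalizer (e.g. lowercasing plus punctuation stripping).
    refs.append(sample["text"].lower())
    hyps.append(out["text"].lower().strip())

print(f"WER: {wer(refs, hyps):.3f}")  # run once per precision and compare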