# Qwen2-Audio-rkllm / run_rknn.py
import os
import signal
import time

import cv2
import librosa
import numpy as np
from rkllm_binding import *
from rknnlite.api.rknn_lite import RKNNLite
from transformers import WhisperFeatureExtractor

MODEL_PATH = "qwen.rkllm"
AUDIO_ENCODER_PATH = "audio_encoder.rknn"
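# Also expected in the working directory: the Whisper preprocessor config loaded
# below and the test clip glass-breaking.mp3.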
handle = None
img_size = 448
# Exit cleanly on Ctrl-C
def signal_handler(signal, frame):
    print("Ctrl-C pressed, exiting...")
    global handle
    if handle:
        abort(handle)
        destroy(handle)
    exit(0)

signal.signal(signal.SIGINT, signal_handler)
# export RKLLM_LOG_LEVEL=1
os.environ["RKLLM_LOG_LEVEL"] = "1"
inference_count = 0
inference_start_time = 0
# Streaming callback: prints generated tokens as they arrive and reports time to first token.
def result_callback(result, userdata, state):
    global inference_start_time
    global inference_count
    if state == LLMCallState.RKLLM_RUN_NORMAL:
        if inference_count == 0:
            first_token_time = time.time()
            print(f"Time to first token: {first_token_time - inference_start_time:.2f} seconds")
        inference_count += 1
        print(result.contents.text.decode(), end="", flush=True)
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        print("\n\n(finished)")
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        print("\nError occurred during LLM call")
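
# WhisperFeatureExtractor.from_pretrained(".") reads preprocessor_config.json from
# the current directory, so run this script from the model folder.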
feature_extractor = WhisperFeatureExtractor.from_pretrained(".")
# Initialize audio encoder
audio_encoder = RKNNLite(verbose=True)
model_size = os.path.getsize(AUDIO_ENCODER_PATH)
print(f"Start loading audio encoder model (size: {model_size / 1024 / 1024:.2f} MB)")
start_time = time.time()
audio_encoder.load_rknn(AUDIO_ENCODER_PATH)
end_time = time.time()
print(f"Audio encoder loaded in {end_time - start_time:.2f} seconds (speed: {model_size / (end_time - start_time) / 1024 / 1024:.2f} MB/s)")
audio_encoder.init_runtime()
# Initialize RKLLM
param = create_default_param()
param.model_path = MODEL_PATH.encode()
param.img_start = "<|audio_bos|>".encode()
param.img_end = "<|audio_eos|>".encode()
param.img_content = "<|AUDIO|>".encode()
param.max_context_len = 1024
extend_param = RKLLMExtendParam()
extend_param.base_domain_id = 1  # the audio encoder (RKNN) occupies iommu domain 0, so the LLM uses domain 1
param.extend_param = extend_param
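# Keeping the two NPU users in separate iommu domains presumably avoids
# memory-mapping conflicts when the RKNN encoder and RKLLM run in one process.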
model_size = os.path.getsize(MODEL_PATH)
print(f"Start loading language model (size: {model_size / 1024 / 1024:.2f} MB)")
start_time = time.time()
handle = init(param, result_callback)
end_time = time.time()
print(f"Language model loaded in {end_time - start_time:.2f} seconds (speed: {model_size / (end_time - start_time) / 1024 / 1024:.2f} MB/s)")
# audio embedding
audio_path = "glass-breaking.mp3"
print("Start inference...")
audio, _ = librosa.load(audio_path, sr=feature_extractor.sampling_rate)
feature_extractor_output = feature_extractor(
    audio,
    sampling_rate=feature_extractor.sampling_rate,
    return_attention_mask=True,
    return_tensors="np",  # return numpy arrays so the .shape / .astype calls below work
    padding="max_length"
)
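# With padding="max_length" the features are padded/truncated to Whisper's 30 s
# window, so input_features should come out as (1, num_mel_bins, 3000).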
print(feature_extractor_output.input_features.shape)
start_time = time.time()
audio_embeddings = audio_encoder.inference(inputs=[
    feature_extractor_output.input_features.astype(np.float32),
    feature_extractor_output.attention_mask.astype(np.float32)
], data_format="nhwc")[0].astype(np.float32)
end_time = time.time()
print(f"Audio encoder inference time: {end_time - start_time:.2f} seconds")
print(audio_embeddings.flags)
print(audio_embeddings.shape)
# Build the chat prompt. RKLLM hardcodes the <image> tag as the embedding placeholder,
# so it is reused here even though the embedding comes from audio.
# The user turn asks, in Chinese, "What is this sound?"
prompt = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Audio 1: <image>
这是什么声音? <|im_end|>
<|im_start|>assistant
"""
# Alternative: feed precomputed embeddings directly instead of the multimodal prompt (kept for reference).
# 2.56->3.25>2.41->10.2
# image_embeddings = np.load("image_embeddings_pth_orig.npy")
# image_embeddings = np.ascontiguousarray(image_embeddings, dtype=np.float32)
# print(f"Loaded embeddings shape: {image_embeddings.shape}")
# rkllm_input = create_rkllm_input(RKLLMInputType.RKLLM_INPUT_EMBED, embed=image_embeddings)
rkllm_input = create_rkllm_input(RKLLMInputType.RKLLM_INPUT_MULTIMODAL, prompt=prompt, image_embed=audio_embeddings)
# Create inference parameters
infer_param = RKLLMInferParam()
infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE.value
# Run RKLLM
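# run() blocks until generation completes; tokens are streamed to result_callback.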
inference_start_time = time.time()
run(handle, rkllm_input, infer_param, None)
# Clean up
destroy(handle)