Remove the flash_attn dependency for macOS / non-CUDA machines
#9
by
ursnation
- opened
I found a solution in this thread that removes the flash_attn dependency
on macOS (and other non-CUDA environments) and adapted it for this model.
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import os
from unittest.mock import patch
from transformers.dynamic_module_utils import get_imports
import torch
# Select the compute device: prefer a CUDA GPU, then the Apple-silicon GPU
# backend (MPS), and finally fall back to the CPU so the script still runs on
# machines that have neither (the original unconditionally chose "mps").
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports of a remote-code file, dropping flash_attn without CUDA.

    Used as a stand-in for ``transformers.dynamic_module_utils.get_imports``
    while the MiniCPM-V 2.6 remote code is loaded. When no CUDA device is
    available, ``flash_attn`` is removed from the reported import list so that
    loading does not require that package to be installed.

    Args:
        filename: Path of the remote-code module being inspected.

    Returns:
        The list produced by ``get_imports``, with ``"flash_attn"`` removed
        when CUDA is unavailable and it was present.
    """
    module_names = get_imports(filename)
    cuda_present = torch.cuda.is_available()
    # Nothing to strip: either CUDA is present (flash_attn is usable) or the
    # file never asked for flash_attn in the first place.
    if cuda_present or "flash_attn" not in module_names:
        return module_names
    module_names.remove("flash_attn")
    return module_names
model_name = 'openbmb/MiniCPM-V-2_6'

# Half precision is poorly supported for CPU matmuls, so only use float16 on
# a GPU backend (CUDA or MPS) and fall back to float32 on the CPU.
dtype = torch.float32 if device.type == 'cpu' else torch.float16

# Create the model while get_imports is patched so the remote code does not
# require flash_attn on machines without CUDA.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        # Follow the detected device instead of hard-coding 'mps', so CUDA
        # machines use their GPU and CPU-only machines still work.
        device_map=device.type,
        trust_remote_code=True,
    )
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# Use a context manager so the underlying file handle is closed; convert()
# returns a new, fully loaded image that stays usable afterwards.
with Image.open('xx.png') as img:
    image = img.convert('RGB')

question = 'Can you give me the text from this image into json format?'
# The image is passed inside the message content, so the separate `image`
# argument to model.chat is left as None.
msgs = [{'role': 'user', 'content': [image, question]}]

res = model.chat(
    image=None,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=True,
    stream=True,
)

# Print each streamed chunk as it arrives and collect the chunks so the full
# response is available afterwards (join avoids quadratic string building).
chunks = []
for new_text in res:
    chunks.append(new_text)
    print(new_text, flush=True, end='')
print()  # terminate the streamed output with a newline
generated_text = ''.join(chunks)
Cheers!
Thank you very much!
This comment has been hidden