# phi-3-5-Vision / app.py
# Author: Saad0KH
# Commit message: Update app.py
# Commit: c003dfa (verified)
import base64
import json
import logging
import os
import subprocess
import sys
import uuid
from io import BytesIO

# Third-party imports keep their original relative order (``spaces`` ahead of
# ``transformers``/``torch``) in case import order matters to any of them.
from flask import Flask, request, jsonify, send_file
from PIL import Image
import spaces
import requests
from loadimg import load_img
import numpy as np
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
# Install flash-attn exactly once, before any model that requests
# ``flash_attention_2`` is instantiated.
# * ``sys.executable -m pip`` targets the interpreter running this app and
#   avoids building a shell command string.
# * The parent environment is inherited (PATH, HOME, proxy settings, ...) and
#   only the skip-build flag is added on top of it.
_flash_attn_install = subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
if _flash_attn_install.returncode != 0:
    # Best effort, as before: do not abort start-up here; loading the model
    # below will fail loudly if flash-attn is really unavailable.
    logging.warning(
        "pip install flash-attn exited with code %s",
        _flash_attn_install.returncode,
    )

app = Flask(__name__)

# Not referenced by the model loading below; kept because it is a module-level name.
kwargs = {'torch_dtype': torch.bfloat16}

# The tables must contain the id that ``process_image`` uses as its default
# ``model_id``; otherwise every request fails with a KeyError on lookup.
_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

# model id -> model, moved to the GPU and switched to eval mode.
models = {
    _MODEL_ID: AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation="flash_attention_2",
    ).cuda().eval()
}

# model id -> processor matching the model registered under the same id.
processors = {
    _MODEL_ID: AutoProcessor.from_pretrained(_MODEL_ID, trust_remote_code=True)
}

# Chat-template fragments used to assemble the prompt in ``process_image``.
user_prompt = '<|user|>\n'
assistant_prompt = '<|assistant|>\n'
prompt_suffix = "<|end|>\n"
def get_image_from_url(url, timeout=30):
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout. Without
            one, a server that never answers would block the caller forever.

    Returns:
        The image opened from the downloaded bytes (not converted to any
        particular mode).

    Raises:
        Exception: Any error raised while downloading or opening the image is
            logged and then re-raised unchanged.
    """
    # NOTE(review): ``url`` reaches this function from the request body via
    # ``get_image``, so the server will fetch arbitrary addresses on behalf of
    # clients (SSRF risk). Consider restricting allowed hosts.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Turn HTTP error statuses into exceptions
        img = Image.open(BytesIO(response.content))
        return img
    except Exception as e:
        logging.error(f"Error fetching image from URL: {e}")
        raise
# Function to decode a base64 image to PIL.Image.Image
def decode_image_from_base64(image_data):
    """Decode base64-encoded image data into an RGB PIL image.

    Args:
        image_data: Base64 text (or bytes). A string may optionally be a data
            URL such as ``data:image/png;base64,<payload>``; in that case only
            the payload after the first comma is decoded.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        Exception: Whatever ``base64.b64decode`` or ``Image.open`` raise for
            input that is not valid base64 or not a readable image.
    """
    if isinstance(image_data, str) and image_data.startswith("data:"):
        # Without this, the letters of the "data:image/...;base64," header
        # would be decoded as if they were payload and corrupt the image.
        _, _, image_data = image_data.partition(",")
    raw_bytes = base64.b64decode(image_data)
    image = Image.open(BytesIO(raw_bytes)).convert("RGB")
    return image
# Function to encode a PIL image to base64
def encode_image_to_base64(image):
    """Serialise ``image`` as PNG and return those bytes as base64 text.

    Args:
        image: Object exposing ``save(fileobj, format=...)``.

    Returns:
        The base64 encoding of the saved bytes, as a UTF-8 decoded ``str``.
    """
    with BytesIO() as png_buffer:
        # PNG is used for compatibility with RGBA images.
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')
def get_image(image_data):
    """Turn ``image_data`` into an image, accepting either a URL or base64 text.

    Args:
        image_data: A string. Values starting with ``http://`` or ``https://``
            are downloaded; everything else is decoded as base64.

    Returns:
        Whatever ``get_image_from_url`` or ``decode_image_from_base64``
        returns for the given input.
    """
    # Decide whether the value is a URL or base64 from its prefix.
    is_remote = image_data.startswith(('http://', 'https://'))
    if not is_remote:
        return decode_image_from_base64(image_data)  # Decode the base64 image
    return get_image_from_url(image_data)  # Download the image from the URL
@spaces.GPU
def process_image(image, text_input=None, model_id="microsoft/Phi-3.5-vision-instruct"):
    """Ask the vision model a question about ``image`` and return its answer.

    Args:
        image: Image object; it is converted to RGB before being processed.
        text_input: Question or instruction placed after the image tag in the
            prompt. ``None`` is treated as an empty instruction.
        model_id: Key into the module-level ``models`` / ``processors`` tables.

    Returns:
        The decoded text of the newly generated tokens for the first (only)
        sequence in the batch, with special tokens removed.

    Raises:
        KeyError: If ``model_id`` is not present in ``models`` / ``processors``.
    """
    try:
        model = models[model_id]
        processor = processors[model_id]
    except KeyError:
        raise KeyError(
            f"Model {model_id!r} is not loaded; available models: {sorted(models)}"
        ) from None

    # Interpolating None directly would put the literal text "None" in the prompt.
    instruction = "" if text_input is None else text_input
    prompt = f"{user_prompt}<|image_1|>\n{instruction}{prompt_suffix}{assistant_prompt}"

    image = image.convert("RGB")
    inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")

    generate_ids = model.generate(
        **inputs,
        max_new_tokens=1000,
        eos_token_id=processor.tokenizer.eos_token_id,
    )
    # Drop the prompt tokens so only the newly generated continuation is decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response
@app.route('/', methods=['GET'])
def welcome():
    """Landing endpoint: answer ``GET /`` with a plain greeting string."""
    greeting = "Welcome to Phi Vision API"
    return greeting
def _strip_code_fence(text):
    """Return ``text`` without a surrounding Markdown code fence, if it has one."""
    text = text.strip()
    if text.startswith("```json"):
        text = text[len("```json"):]
    elif text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


@app.route('/api/process', methods=['POST'])
def detect():
    """Run the vision model on a posted image and return its JSON answer.

    Expects a JSON object body with two fields:
        image:  an http(s) URL or base64 image data (see ``get_image``).
        prompt: the instruction given to the model.

    Returns:
        ``{'result': <parsed JSON>}`` on success;
        ``({'error': ...}, 400)`` when the body is not a JSON object or a
        required field is missing;
        ``({'error': ...}, 500)`` when the model output is not valid JSON or
        any other error occurs while handling the request.
    """
    # silent=True yields None instead of raising for a missing/invalid JSON
    # body, so malformed requests are reported as client errors (400).
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    missing = [key for key in ('image', 'prompt') if key not in payload]
    if missing:
        return jsonify({'error': f"Missing required field(s): {', '.join(missing)}"}), 400

    try:
        image = get_image(payload['image'])
        result = process_image(image, payload['prompt'])
        # The model may wrap its answer in ```json ... ``` markers, possibly
        # surrounded by whitespace; remove them before parsing.
        result = _strip_code_fence(result)
        # Convert the string result to a Python object
        try:
            result_dict = json.loads(result)
        except json.JSONDecodeError as e:
            logging.error(f"JSON decoding error: {e}")
            return jsonify({'error': 'Invalid JSON format in the response'}), 500
        return jsonify({'result': result_dict})
    except Exception as e:
        # Top-level boundary of the request: log with traceback, answer 500.
        logging.exception(f"Error occurred: {e}")
        return jsonify({'error': str(e)}), 500
if __name__ == "__main__":
    # The server listens on every interface (0.0.0.0), so debug mode must stay
    # off: the interactive debugger it enables would let any client that can
    # reach the port execute code on this machine.
    app.run(debug=False, host="0.0.0.0", port=7860)