vllmxd / app.py
Hjgugugjhuhjggg's picture
Update app.py
5f8cec1 verified
from flask import Flask, request, jsonify, render_template_string
from vllm import LLM, SamplingParams
from langchain_community.cache import GPTCache
import torch
app = Flask(__name__)
# Verificar si hay una GPU disponible, si no usar la CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inicializar los modelos con el dispositivo adecuado (GPU o CPU)
try:
modelos = {
"facebook/opt-125m": LLM(model="facebook/opt-125m", device=device),
"llama-3.2-1B": LLM(model="Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device=device),
"gpt2": LLM(model="gpt2", device=device)
}
except KeyError as e:
print(f"Error al inicializar el modelo con {device}: {e}")
modelos = {}
# Verificar si los modelos fueron correctamente inicializados
if not modelos:
print("Error: No se pudo inicializar ning煤n modelo.")
exit(1)
# Configuraci贸n de cach茅 para los modelos
caches = {
nombre: GPTCache(modelo, max_size=1000)
for nombre, modelo in modelos.items()
}
# Par谩metros de muestreo para la generaci贸n de texto
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# C贸digo HTML para la documentaci贸n de la API
html_code_docs = """
<!DOCTYPE html>
<html>
<head>
<title>Documentaci贸n de la API</title>
</head>
<body>
<h1>API de Generaci贸n de Texto</h1>
<h2>Endpoints</h2>
<ul>
<li>
<h3>Generar texto</h3>
<p>M茅todo: POST</p>
<p>Ruta: /generate</p>
<p>Par谩metros:</p>
<ul>
<li>prompts: Lista de prompts para generar texto</li>
<li>modelo: Nombre del modelo a utilizar</li>
</ul>
<p>Ejemplo:</p>
<pre>curl -X POST -H "Content-Type: application/json" -d '{"prompts": ["Hola, c贸mo est谩s?"], "modelo": "facebook/opt-125m"}' http://localhost:5000/generate</pre>
</li>
<li>
<h3>Obtener lista de modelos</h3>
<p>M茅todo: GET</p>
<p>Ruta: /modelos</p>
<p>Ejemplo:</p>
<pre>curl -X GET http://localhost:5000/modelos</pre>
</li>
<li>
<h3>Chatbot</h3>
<p>M茅todo: POST</p>
<p>Ruta: /chatbot</p>
<p>Par谩metros:</p>
<ul>
<li>mensaje: Mensaje para el chatbot</li>
<li>modelo: Nombre del modelo a utilizar</li>
</ul>
<p>Ejemplo:</p>
<pre>curl -X POST -H "Content-Type: application/json" -d '{"mensaje": "Hola, c贸mo est谩s?", "modelo": "facebook/opt-125m"}' http://localhost:5000/chatbot</pre>
</li>
</ul>
</body>
</html>
"""
# C贸digo HTML para la interfaz del chatbot
html_code_chatbot = """
<!DOCTYPE html>
<html>
<head>
<title>Chatbot</title>
</head>
<body>
<h1>Chatbot</h1>
<form id="chat-form">
<input type="text" id="mensaje" placeholder="Escribe un mensaje">
<button type="submit">Enviar</button>
</form>
<div id="respuestas"></div>
<script>
const form = document.getElementById('chat-form');
const mensajeInput = document.getElementById('mensaje');
const respuestasDiv = document.getElementById('respuestas');
form.addEventListener('submit', (e) => {
e.preventDefault();
const mensaje = mensajeInput.value;
fetch('/chatbot', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ mensaje })
})
.then((res) => res.json())
.then((data) => {
const respuesta = data.respuesta;
const respuestaHTML = `<p>T煤: ${mensaje}</p><p>Chatbot: ${respuesta}</p>`;
respuestasDiv.innerHTML += respuestaHTML;
mensajeInput.value = '';
});
});
</script>
</body>
</html>
"""
@app.route('/generate', methods=['POST'])
def generate():
data = request.get_json()
prompts = data.get('prompts', [])
modelo_seleccionado = data.get('modelo', "facebook/opt-125m")
if modelo_seleccionado not in modelos:
return jsonify({"error": "Modelo no encontrado"}), 404
outputs = caches[modelo_seleccionado].generate(prompts, sampling_params)
results = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
results.append({
'prompt': prompt,
'generated_text': generated_text
})
return jsonify(results)
@app.route('/modelos', methods=['GET'])
def get_modelos():
return jsonify({"modelos": list(modelos.keys())})
@app.route('/docs', methods=['GET'])
def docs():
return render_template_string(html_code_docs)
@app.route('/chatbot', methods=['POST'])
def chatbot():
data = request.get_json()
mensaje = data.get('mensaje', '')
modelo_seleccionado = data.get('modelo', "facebook/opt-125m")
if modelo_seleccionado not in modelos:
return jsonify({"error": "Modelo no encontrado"}), 404
outputs = caches[modelo_seleccionado].generate([mensaje], sampling_params)
respuesta = outputs[0].outputs[0].text
return jsonify({"respuesta": respuesta})
@app.route('/chat', methods=['GET'])
def chat():
return render_template_string(html_code_chatbot)
if __name__ == '__main__':
# Asegurar que el servidor solo arranca si los modelos fueron inicializados correctamente
if modelos:
app.run(host='0.0.0.0', port=7860)
else:
print("Error: No se pudieron cargar los modelos. El servidor no se iniciar谩.")