from flask import Flask, request, jsonify, render_template_string from vllm import LLM, SamplingParams from langchain_community.cache import GPTCache import torch app = Flask(__name__) # Verificar si hay una GPU disponible, si no usar la CPU device = "cuda" if torch.cuda.is_available() else "cpu" # Inicializar los modelos con el dispositivo adecuado (GPU o CPU) try: modelos = { "facebook/opt-125m": LLM(model="facebook/opt-125m", device=device), "llama-3.2-1B": LLM(model="Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device=device), "gpt2": LLM(model="gpt2", device=device) } except KeyError as e: print(f"Error al inicializar el modelo con {device}: {e}") modelos = {} # Verificar si los modelos fueron correctamente inicializados if not modelos: print("Error: No se pudo inicializar ningún modelo.") exit(1) # Configuración de caché para los modelos caches = { nombre: GPTCache(modelo, max_size=1000) for nombre, modelo in modelos.items() } # Parámetros de muestreo para la generación de texto sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Código HTML para la documentación de la API html_code_docs = """ Documentación de la API

API de Generación de Texto

Endpoints

""" # Código HTML para la interfaz del chatbot html_code_chatbot = """ Chatbot

Chatbot

""" @app.route('/generate', methods=['POST']) def generate(): data = request.get_json() prompts = data.get('prompts', []) modelo_seleccionado = data.get('modelo', "facebook/opt-125m") if modelo_seleccionado not in modelos: return jsonify({"error": "Modelo no encontrado"}), 404 outputs = caches[modelo_seleccionado].generate(prompts, sampling_params) results = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text results.append({ 'prompt': prompt, 'generated_text': generated_text }) return jsonify(results) @app.route('/modelos', methods=['GET']) def get_modelos(): return jsonify({"modelos": list(modelos.keys())}) @app.route('/docs', methods=['GET']) def docs(): return render_template_string(html_code_docs) @app.route('/chatbot', methods=['POST']) def chatbot(): data = request.get_json() mensaje = data.get('mensaje', '') modelo_seleccionado = data.get('modelo', "facebook/opt-125m") if modelo_seleccionado not in modelos: return jsonify({"error": "Modelo no encontrado"}), 404 outputs = caches[modelo_seleccionado].generate([mensaje], sampling_params) respuesta = outputs[0].outputs[0].text return jsonify({"respuesta": respuesta}) @app.route('/chat', methods=['GET']) def chat(): return render_template_string(html_code_chatbot) if __name__ == '__main__': # Asegurar que el servidor solo arranca si los modelos fueron inicializados correctamente if modelos: app.run(host='0.0.0.0', port=7860) else: print("Error: No se pudieron cargar los modelos. El servidor no se iniciará.")