from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
from concurrent.futures import ThreadPoolExecutor, as_completed
import uvicorn
from dotenv import load_dotenv
from difflib import SequenceMatcher

load_dotenv()

app = FastAPI()

# Model configuration
models = [
    {"repo_id": "Ffftdtd5dtft/gpt2-xl-Q2_K-GGUF", "filename": "gpt2-xl-q2_k.gguf"},
    {"repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-8B-Instruct-Q2_K-GGUF", "filename": "meta-llama-3.1-8b-instruct-q2_k.gguf"},
    {"repo_id": "Ffftdtd5dtft/gemma-2-9b-it-Q2_K-GGUF", "filename": "gemma-2-9b-it-q2_k.gguf"},
    {"repo_id": "Ffftdtd5dtft/gemma-2-27b-Q2_K-GGUF", "filename": "gemma-2-27b-q2_k.gguf"},
]

# Load the models into memory only once, at startup
llms = [Llama.from_pretrained(repo_id=model['repo_id'], filename=model['filename']) for model in models]

class ChatRequest(BaseModel):
    message: str
    top_k: int = 50
    top_p: float = 0.95
    temperature: float = 0.7

def generate_chat_response(request, llm):
    # Normalize before the try block so user_input is always defined, even on error.
    user_input = normalize_input(request.message)
    try:
        response = llm.create_chat_completion(
            messages=[{"role": "user", "content": user_input}],
            top_k=request.top_k,
            top_p=request.top_p,
            temperature=request.temperature
        )
        reply = response['choices'][0]['message']['content']
        return {"response": reply, "literal": user_input}
    except Exception as e:
        return {"response": f"Error: {str(e)}", "literal": user_input}

def normalize_input(input_text):
    return input_text.strip()

def select_best_response(responses, request):
    coherent_responses = filter_by_coherence(responses, request)
    best_response = filter_by_similarity(coherent_responses)
    return best_response

def filter_by_coherence(responses, request):
    # Placeholder: implement a coherence filter here if needed.
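    # A minimal sketch of one possible filter (an assumption, not part of the original
    # code): drop responses with almost no overlap with the user's message, reusing
    # SequenceMatcher, which is already imported above. The 0.1 threshold is
    # illustrative only.
    #
    #     filtered = [r for r in responses
    #                 if SequenceMatcher(None, request.message, r).ratio() > 0.1]
    #     return filtered or responses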
    return responses

def filter_by_similarity(responses):
    # Start from the longest response; if a later one differs substantially from it
    # (similarity ratio below 0.9), prefer that response instead.
    responses.sort(key=len, reverse=True)
    best_response = responses[0]
    for candidate in responses[1:]:
        ratio = SequenceMatcher(None, best_response, candidate).ratio()
        if ratio < 0.9:
            best_response = candidate
            break
    return best_response

@app.post("/generate_chat")
async def generate_chat(request: ChatRequest):
    if not request.message.strip():
        raise HTTPException(status_code=400, detail="The message cannot be empty.")
    
    # Query every loaded model concurrently; max_workers=None lets the executor choose a default.
    with ThreadPoolExecutor(max_workers=None) as executor:
        futures = [executor.submit(generate_chat_response, request, llm) for llm in llms]
        responses = []
        for future in as_completed(futures):
            response = future.result()
            responses.append(response)
    
    # If any model returned an error, surface it as a 500 response.
    if any("Error" in response['response'] for response in responses):
        error_response = next(response for response in responses if "Error" in response['response'])
        raise HTTPException(status_code=500, detail=error_response['response'])
    
    # Collect the generated texts and the normalized (literal) inputs.
    response_texts = [resp['response'] for resp in responses]
    literal_inputs = [resp['literal'] for resp in responses]
    
    # Pick the best response across all models.
    best_response = select_best_response(response_texts, request)
    
    return {
        "best_response": best_response,
        "all_responses": response_texts,
        "literal_inputs": literal_inputs
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
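
# Example request (assumes the server is running locally on the port configured above):
#   curl -X POST http://localhost:7860/generate_chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!", "top_k": 50, "top_p": 0.95, "temperature": 0.7}'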