File size: 1,884 Bytes
458ae42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
import logging
import textwrap
import uuid

from ollama import Client

from modules.presets import i18n

from ..index_func import construct_index
from ..utils import count_token
from .base_model import BaseLLMModel


class OllamaClient(BaseLLMModel):
    """Chat model that forwards the conversation to an Ollama server.

    A new ``ollama.Client`` is created for every request, pointed at
    ``ollama_host``; ``backend_model`` is the model name passed to the server.
    """

    # (substring of the lower-cased backend model name, token limit) pairs.
    # They are checked in order and the first match wins, so more specific
    # names must stay ahead of the names they contain.
    _TOKEN_LIMITS_BY_NAME = (
        ("mistral", 8 * 1024),
        ("gemma", 8 * 1024),
        ("codellama", 4 * 1024),
        ("llama2-chinese", 4 * 1024),
        ("llama2", 4 * 1024),
        ("mixtral", 32 * 1024),
        ("llava", 4 * 1024),
    )

    def __init__(self, model_name, user_name="", ollama_host="", backend_model="") -> None:
        """Store the connection settings and pick a token limit.

        Args:
            model_name: Forwarded to ``BaseLLMModel.__init__`` as ``model_name``.
            user_name: Forwarded to ``BaseLLMModel.__init__`` as ``user``.
            ollama_host: Host passed to ``ollama.Client`` for every request.
            backend_model: Name of the Ollama model to chat with; an empty
                value means no model has been selected yet.
        """
        super().__init__(model_name=model_name, user=user_name)
        self.backend_model = backend_model
        self.ollama_host = ollama_host
        self.update_token_limit()

    def get_model_list(self):
        """Return the result of ``Client.list()`` for the configured host."""
        return Client(host=self.ollama_host).list()

    def update_token_limit(self):
        """Set ``self.token_upper_limit`` from the backend model name.

        The lower-cased name is matched against the known substrings in
        ``_TOKEN_LIMITS_BY_NAME``. When nothing matches (including an empty
        name) ``token_upper_limit`` is left untouched.
        """
        lower_model_name = (self.backend_model or "").lower()
        for name_fragment, limit in self._TOKEN_LIMITS_BY_NAME:
            if name_fragment in lower_model_name:
                self.token_upper_limit = limit
                break

    def get_answer_stream_iter(self):
        """Stream the model's answer, yielding the accumulated text so far.

        Yields:
            The answer text received up to that point, once per streamed
            chunk, followed by one extra yield of the complete text after
            ``self.all_token_counts[-1]`` has been set to its token count.
            If no backend model is selected, a single localized hint asking
            the user to choose one is yielded instead and no request is made.
        """
        if not self.backend_model:
            # This is a generator function: a plain ``return <message>`` would
            # end the iteration without the caller ever receiving the text,
            # so the hint has to be yielded before returning.
            yield i18n("请先选择Ollama后端模型\n\n")
            return
        client = Client(host=self.ollama_host)
        stream = client.chat(
            model=self.backend_model, messages=self.history, stream=True
        )
        partial_text = ""
        for chunk in stream:
            partial_text += chunk["message"]["content"]
            yield partial_text
        self.all_token_counts[-1] = count_token(partial_text)
        yield partial_text