from typing import Iterator, List, Optional
import ctypes
from ctypes import (
    c_char_p,
    c_int,
    c_int8,
    c_float,
    c_void_p,
    POINTER,
)
import pathlib
import os
import sys

# Load the library
def _load_shared_library(lib_base_name: str):
    # Construct the paths to the possible shared library names
    _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
    # Search the directory containing this file for the shared library
    # built by llama2.cu (default name "libllama2.so")
    _lib_paths: List[pathlib.Path] = []
    # Determine the file extension based on the platform
    if sys.platform.startswith("linux"):
        _lib_paths += [
            _base_path / f"lib{lib_base_name}.so",
        ]
    else:
        raise RuntimeError("Unsupported platform")

    if "LLAMA2_CU_LIB" in os.environ:
        lib_base_name = os.environ["LLAMA2_CU_LIB"]
        _lib = pathlib.Path(lib_base_name)
        _base_path = _lib.parent.resolve()
        _lib_paths = [_lib.resolve()]

    cdll_args = dict()  # type: ignore
    # (On Windows the library directory would need to be added to the DLL
    # search path; only Linux is supported above.)

    # Try to load the shared library, handling potential errors
    for _lib_path in _lib_paths:
        if _lib_path.exists():
            try:
                return ctypes.CDLL(str(_lib_path), **cdll_args)
            except Exception as e:
                raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )

# Specify the base name of the shared library to load
_lib_base_name = "llama2"

# Load the library
_lib = _load_shared_library(_lib_base_name)
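# The loader above can be redirected without editing this file: pointing the
# LLAMA2_CU_LIB environment variable at a .so file (handled inside
# _load_shared_library) takes precedence over the default search. An
# illustrative invocation, with a hypothetical library path:
#
#   LLAMA2_CU_LIB=/opt/llama2/libllama2.so python your_script.py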


def llama2_init(model_path: str, tokenizer_path: str) -> c_void_p:
    # Create a native generation context from model and tokenizer files.
    return _lib.llama2_init(model_path.encode('utf-8'), tokenizer_path.encode('utf-8'))

_lib.llama2_init.argtypes = [c_char_p, c_char_p]
_lib.llama2_init.restype = c_void_p

def llama2_free(ctx: c_void_p) -> None:
    # Release a context created by llama2_init.
    _lib.llama2_free(ctx)

_lib.llama2_free.argtypes = [c_void_p]
_lib.llama2_free.restype = None

def llama2_generate(ctx: c_void_p, prompt: str, max_tokens: int, temperature: float, top_p: float, seed: int) -> int:
    # Start generation; returns 0 on success, non-zero on failure.
    return _lib.llama2_generate(ctx, prompt.encode('utf-8'), max_tokens, temperature, top_p, seed)

_lib.llama2_generate.argtypes = [c_void_p, c_char_p, c_int, c_float, c_float, c_int]
_lib.llama2_generate.restype = c_int

def llama2_get_last(ctx: c_void_p) -> Optional[bytes]:
    # Returns the newest generated chunk as bytes, or None once generation ends.
    return _lib.llama2_get_last(ctx)

_lib.llama2_get_last.argtypes = [c_void_p]
_lib.llama2_get_last.restype = c_char_p

def llama2_tokenize(ctx: c_void_p, text: str, add_bos: bool, add_eos: bool) -> List[int]:
    encoded = text.encode('utf-8')
    # Worst case is one token per input byte, plus BOS/EOS and a spare slot;
    # sizing by encoded bytes (not characters) avoids a buffer overflow on
    # multi-byte UTF-8 input.
    tokens = (c_int * (len(encoded) + 3))()
    n_tokens = (c_int * 1)()
    _lib.llama2_tokenize(ctx, encoded, add_bos, add_eos, tokens, n_tokens)
    return tokens[:n_tokens[0]]

_lib.llama2_tokenize.argtypes = [c_void_p, c_char_p, c_int8, c_int8, POINTER(c_int), POINTER(c_int)]
_lib.llama2_tokenize.restype = None

class Llama2:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str = 'tokenizer.bin',
        n_ctx: int = 0,
        n_batch: int = 0,
    ) -> None:
        # n_ctx and n_batch are stored for API compatibility but are not
        # currently forwarded to the native library.
        self.n_ctx = n_ctx
        self.n_batch = n_batch
        self.llama2_ctx = llama2_init(model_path, tokenizer_path)
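
    def __del__(self) -> None:
        # Best-effort cleanup: nothing else releases what llama2_init
        # allocated, so free the native context when the wrapper is
        # garbage-collected (assumes llama2_free is safe to call once on a
        # context returned by llama2_init).
        if getattr(self, 'llama2_ctx', None):
            llama2_free(self.llama2_ctx)
            self.llama2_ctx = None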

    def tokenize(
        self, text: str, add_bos: bool = True, add_eos: bool = False
    ) -> List[int]:
        return llama2_tokenize(self.llama2_ctx, text, add_bos, add_eos)
    
    def __call__(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        logprobs: Optional[int] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        seed: Optional[int] = None,
    ) -> Iterator[str]:
        # Only prompt, max_tokens, temperature, top_p and seed are forwarded
        # to the native library; the remaining sampling parameters are
        # accepted for API compatibility and currently ignored. Output is
        # always streamed regardless of `stream`, and because this is a
        # generator, nothing (including error checking) runs until the first
        # value is requested.
        if seed is None:
            seed = 42  # fixed default for reproducibility
        ret = llama2_generate(self.llama2_ctx, prompt, max_tokens, temperature, top_p, seed)
        if ret != 0:
            raise RuntimeError(f"Failed to launch generation for prompt '{prompt}'")
        bytes_buffer = b''  # accumulate raw bytes until they decode cleanly
        while True:
            result = llama2_get_last(self.llama2_ctx)
            if result is None:
                break  # generation finished
            bytes_buffer += result
            try:
                string = bytes_buffer.decode('utf-8')
            except UnicodeDecodeError:
                pass  # incomplete multi-byte sequence; wait for more bytes
            else:
                bytes_buffer = b''
                yield string
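

# A minimal usage sketch of the bindings above. The model and tokenizer file
# names are placeholders; substitute files exported for llama2.cu.
if __name__ == '__main__':
    llm = Llama2('model.bin', tokenizer_path='tokenizer.bin')
    print(llm.tokenize('Once upon a time'))
    for piece in llm('Once upon a time', max_tokens=64, seed=1234):
        print(piece, end='', flush=True)
    print()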