from typing import Iterator, List, Optional
import ctypes
from ctypes import (
    c_char_p,
    c_float,
    c_int,
    c_int8,
    c_void_p,
    POINTER,
)
import pathlib
import os
import sys


# Locate and load the shared library built from llama2.cu.
def _load_shared_library(lib_base_name: str):
    # Construct the candidate paths for the shared library, starting from the
    # directory that contains this file.
    _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
    _lib_paths: List[pathlib.Path] = []
    # Only Linux is supported; the library is expected as "lib{base}.so".
    if sys.platform.startswith("linux"):
        _lib_paths += [
            _base_path / f"lib{lib_base_name}.so",
        ]
    else:
        raise RuntimeError(f"Unsupported platform: {sys.platform}")
    # The LLAMA2_CU_LIB environment variable overrides the search with an
    # explicit path to the shared library.
    if "LLAMA2_CU_LIB" in os.environ:
        lib_base_name = os.environ["LLAMA2_CU_LIB"]
        _lib = pathlib.Path(lib_base_name)
        _base_path = _lib.parent.resolve()
        _lib_paths = [_lib.resolve()]
    # Extra keyword arguments for ctypes.CDLL (none are needed on Linux).
    cdll_args = dict()  # type: ignore
    # Try to load the shared library, surfacing loader errors with the path.
    for _lib_path in _lib_paths:
        if _lib_path.exists():
            try:
                return ctypes.CDLL(str(_lib_path), **cdll_args)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load shared library '{_lib_path}': {e}"
                )
    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )
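
# Example (the paths shown are placeholders): point the loader at a custom
# build of the library via the environment variable, e.g.
#   LLAMA2_CU_LIB=/path/to/libllama2.so python your_script.py
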
# Specify the base name of the shared library to load
_lib_base_name = "llama2"
# Load the library
_lib = _load_shared_library(_lib_base_name)


def llama2_init(model_path: str, tokenizer_path: str) -> c_void_p:
    return _lib.llama2_init(model_path.encode('utf-8'), tokenizer_path.encode('utf-8'))

_lib.llama2_init.argtypes = [c_char_p, c_char_p]
_lib.llama2_init.restype = c_void_p


def llama2_free(ctx: c_void_p) -> None:
    _lib.llama2_free(ctx)

_lib.llama2_free.argtypes = [c_void_p]
_lib.llama2_free.restype = None


def llama2_generate(ctx: c_void_p, prompt: str, max_tokens: int,
                    temperature: float, top_p: float, seed: int) -> int:
    return _lib.llama2_generate(ctx, prompt.encode('utf-8'), max_tokens, temperature, top_p, seed)

_lib.llama2_generate.argtypes = [c_void_p, c_char_p, c_int, c_float, c_float, c_int]
_lib.llama2_generate.restype = c_int


def llama2_get_last(ctx: c_void_p) -> Optional[bytes]:
    # Returns the most recent chunk of generated bytes, or None when
    # generation has finished.
    return _lib.llama2_get_last(ctx)

_lib.llama2_get_last.argtypes = [c_void_p]
_lib.llama2_get_last.restype = c_char_p


def llama2_tokenize(ctx: c_void_p, text: str, add_bos: bool, add_eos: bool) -> List[int]:
    # Size the buffer on the UTF-8 byte length (not the character count) so
    # multi-byte text cannot overflow it; +3 leaves room for BOS/EOS markers.
    text_bytes = text.encode('utf-8')
    tokens = (c_int * (len(text_bytes) + 3))()
    n_tokens = (c_int * 1)()
    _lib.llama2_tokenize(ctx, text_bytes, add_bos, add_eos, tokens, n_tokens)
    return tokens[:n_tokens[0]]

_lib.llama2_tokenize.argtypes = [c_void_p, c_char_p, c_int8, c_int8, POINTER(c_int), POINTER(c_int)]
_lib.llama2_tokenize.restype = None
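
# Sketch of using the raw bindings directly, without the Llama2 wrapper below
# ('model.bin' and 'tokenizer.bin' are placeholder paths, not files shipped
# with this module):
#
#   ctx = llama2_init('model.bin', 'tokenizer.bin')
#   ids = llama2_tokenize(ctx, 'Hello world', add_bos=True, add_eos=False)
#   llama2_free(ctx)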


class Llama2:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str = 'tokenizer.bin',
        n_ctx: int = 0,
        n_batch: int = 0,
    ) -> None:
        self.n_ctx = n_ctx
        self.n_batch = n_batch
        self.llama2_ctx = llama2_init(model_path, tokenizer_path)
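
    # Convenience addition (a sketch, not in the original bindings): free the
    # native context when the wrapper is garbage-collected, so long-lived
    # programs do not leak the C-side allocation.
    def __del__(self) -> None:
        ctx = getattr(self, 'llama2_ctx', None)
        if ctx:
            llama2_free(ctx)
            self.llama2_ctx = None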

    def tokenize(
        self, text: str, add_bos: bool = True, add_eos: bool = False
    ) -> List[int]:
        return llama2_tokenize(self.llama2_ctx, text, add_bos, add_eos)

    def __call__(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        logprobs: Optional[int] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        seed: Optional[int] = None,
    ) -> Iterator[str]:
        # Note: min_p, typical_p, logprobs, frequency_penalty, presence_penalty,
        # repeat_penalty, top_k, and stream are accepted for interface
        # compatibility but are not forwarded to the C backend, which only
        # supports temperature, top_p, and seed; output is always yielded
        # incrementally.
        if seed is None:
            seed = 42
        ret = llama2_generate(self.llama2_ctx, prompt, max_tokens, temperature, top_p, seed)
        if ret != 0:
            raise RuntimeError(f"Failed to launch generation for prompt '{prompt}'")
        bytes_buffer = b''  # accumulate raw bytes until they decode cleanly (multi-byte characters)
        while True:
            result = llama2_get_last(self.llama2_ctx)
            if result is None:
                break
            bytes_buffer += result
            try:
                string = bytes_buffer.decode('utf-8')
            except UnicodeDecodeError:
                # Incomplete multi-byte sequence: keep buffering.
                pass
            else:
                bytes_buffer = b''
                yield string
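

# Minimal usage sketch (an addition, not part of the original bindings): run
# this module directly with paths to a converted model and tokenizer. Both
# paths are supplied by the user; no files are shipped with this module.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="llama2.cu smoke test")
    parser.add_argument("model_path")
    parser.add_argument("--tokenizer", default="tokenizer.bin")
    parser.add_argument("--prompt", default="Once upon a time")
    args = parser.parse_args()

    llm = Llama2(args.model_path, args.tokenizer)
    for piece in llm(args.prompt, max_tokens=64):
        print(piece, end="", flush=True)
    print()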