from typing import Iterator, List, Optional

import ctypes
from ctypes import (
    c_char_p,
    c_int,
    c_int8,
    c_float,
    c_void_p,
    POINTER,
)
import pathlib
import os
import sys

# Load the shared library
def _load_shared_library(lib_base_name: str):
    # Construct the paths to the possible shared library names
    _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
    # Search next to this file for "libllama2" (the default shared-library
    # name produced by llama2.cu)
    _lib_paths: List[pathlib.Path] = []
    # Determine the file extension based on the platform
    if sys.platform.startswith("linux"):
        _lib_paths += [
            _base_path / f"lib{lib_base_name}.so",
        ]
    else:
        raise RuntimeError("Unsupported platform")

    # Allow overriding the library location via the LLAMA2_CU_LIB env var
    if "LLAMA2_CU_LIB" in os.environ:
        lib_base_name = os.environ["LLAMA2_CU_LIB"]
        _lib = pathlib.Path(lib_base_name)
        _base_path = _lib.parent.resolve()
        _lib_paths = [_lib.resolve()]

    cdll_args = dict()  # type: ignore

    # Try to load the shared library, handling potential errors
    for _lib_path in _lib_paths:
        if _lib_path.exists():
            try:
                return ctypes.CDLL(str(_lib_path), **cdll_args)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load shared library '{_lib_path}': {e}"
                ) from e

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )

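# Example override (hypothetical path; adjust to where your build landed):
#   LLAMA2_CU_LIB=/path/to/libllama2.so python your_script.py
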
# Specify the base name of the shared library to load
_lib_base_name = "llama2"

# Load the library
_lib = _load_shared_library(_lib_base_name)

def llama2_init(model_path: str, tokenizer_path: str) -> c_void_p:
    return _lib.llama2_init(model_path.encode("utf-8"), tokenizer_path.encode("utf-8"))

_lib.llama2_init.argtypes = [c_char_p, c_char_p]
_lib.llama2_init.restype = c_void_p

def llama2_free(ctx: c_void_p) -> None:
    _lib.llama2_free(ctx)

_lib.llama2_free.argtypes = [c_void_p]
_lib.llama2_free.restype = None

def llama2_generate(ctx: c_void_p, prompt: str, max_tokens: int, temperature: float, top_p: float, seed: int) -> int:
    # Launches generation in the backend; returns 0 on success
    return _lib.llama2_generate(ctx, prompt.encode("utf-8"), max_tokens, temperature, top_p, seed)

_lib.llama2_generate.argtypes = [c_void_p, c_char_p, c_int, c_float, c_float, c_int]
_lib.llama2_generate.restype = c_int

def llama2_get_last(ctx: c_void_p) -> Optional[bytes]:
    # Returns the most recently generated bytes, or None once generation ends
    return _lib.llama2_get_last(ctx)

_lib.llama2_get_last.argtypes = [c_void_p]
_lib.llama2_get_last.restype = c_char_p

def llama2_tokenize(ctx: c_void_p, text: str, add_bos: bool, add_eos: bool) -> List[int]:
    text_bytes = text.encode("utf-8")
    # Size the buffer by encoded bytes (an upper bound on the token count for
    # multi-byte input, unlike the character count), plus room for BOS/EOS
    tokens = (c_int * (len(text_bytes) + 3))()
    n_tokens = (c_int * 1)()
    _lib.llama2_tokenize(ctx, text_bytes, add_bos, add_eos, tokens, n_tokens)
    return tokens[:n_tokens[0]]

_lib.llama2_tokenize.argtypes = [c_void_p, c_char_p, c_int8, c_int8, POINTER(c_int), POINTER(c_int)]
_lib.llama2_tokenize.restype = None

class Llama2:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str = "tokenizer.bin",
        n_ctx: int = 0,
        n_batch: int = 0,
    ) -> None:
        self.n_ctx = n_ctx
        self.n_batch = n_batch
        self.llama2_ctx = llama2_init(model_path, tokenizer_path)
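
    def __del__(self) -> None:
        # Best-effort cleanup sketch (not in the original): release the native
        # context on garbage collection, assuming llama2_free is safe to call
        # exactly once per context returned by llama2_init.
        if getattr(self, "llama2_ctx", None) is not None:
            llama2_free(self.llama2_ctx)
            self.llama2_ctx = None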
    def tokenize(
        self, text: str, add_bos: bool = True, add_eos: bool = False
    ) -> List[int]:
        return llama2_tokenize(self.llama2_ctx, text, add_bos, add_eos)
    def __call__(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        logprobs: Optional[int] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        seed: Optional[int] = None,
    ) -> Iterator[str]:
        # Note: min_p, typical_p, logprobs, frequency_penalty, presence_penalty,
        # repeat_penalty, top_k, and stream are accepted for interface
        # compatibility but are not forwarded to the backend.
        if seed is None:
            seed = 42
        ret = llama2_generate(self.llama2_ctx, prompt, max_tokens, temperature, top_p, seed)
        if ret != 0:
            raise RuntimeError(f"Failed to launch generation for prompt '{prompt}'")
        bytes_buffer = b""  # accumulate raw bytes until they decode cleanly (multi-byte UTF-8 characters may arrive split across chunks)
        while True:
            result = llama2_get_last(self.llama2_ctx)
            if result is None:
                break
            bytes_buffer += result
            try:
                string = bytes_buffer.decode("utf-8")
            except UnicodeDecodeError:
                # Incomplete multi-byte sequence; keep buffering
                pass
            else:
                bytes_buffer = b""
                yield string
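

# Minimal usage sketch. "model.bin" is a placeholder; point it at an actual
# llama2.cu checkpoint alongside its tokenizer file. Streams decoded text to
# stdout as it is produced.
if __name__ == "__main__":
    llm = Llama2("model.bin", tokenizer_path="tokenizer.bin")
    for chunk in llm("Once upon a time", max_tokens=64, seed=42):
        print(chunk, end="", flush=True)
    print()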