# Uploaded by myy97 via huggingface_hub (commit 140387c).
from typing import Callable, Generator, Iterator, List, Optional, Union
import ctypes
from ctypes import (
c_bool,
c_char_p,
c_int,
c_int8,
c_int32,
c_uint8,
c_uint32,
c_size_t,
c_float,
c_double,
c_void_p,
POINTER,
_Pointer, # type: ignore
Structure,
Array,
)
import pathlib
import os
import sys
# Load the library
def _load_shared_library(lib_base_name: str):
    """Locate and load the llama2 shared library with ctypes.

    By default the library is looked up next to this module as
    ``lib<lib_base_name>.so``; only Linux is supported for that lookup.
    Setting the ``LLAMA2_CU_LIB`` environment variable to a library path
    overrides the lookup (and the platform check) entirely.

    Args:
        lib_base_name: Library name without the ``lib`` prefix and the
            extension, e.g. ``"llama2"``.

    Returns:
        The loaded ``ctypes.CDLL`` handle.

    Raises:
        RuntimeError: If no override is given and the platform is not Linux,
            or if a candidate file exists but cannot be loaded (the original
            ``OSError`` is chained as ``__cause__``).
        FileNotFoundError: If no candidate library file exists.
    """
    lib_paths: List[pathlib.Path]
    if "LLAMA2_CU_LIB" in os.environ:
        # An explicit path wins; it is also what the not-found error reports.
        lib_base_name = os.environ["LLAMA2_CU_LIB"]
        lib_paths = [pathlib.Path(lib_base_name).resolve()]
    elif sys.platform.startswith("linux"):
        base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
        lib_paths = [base_path / f"lib{lib_base_name}.so"]
    else:
        raise RuntimeError("Unsupported platform")

    for lib_path in lib_paths:
        if lib_path.exists():
            try:
                return ctypes.CDLL(str(lib_path))
            except OSError as e:
                # Chain the loader error so the dlopen reason stays visible.
                raise RuntimeError(
                    f"Failed to load shared library '{lib_path}': {e}"
                ) from e
    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )
# Base name of the native library: resolved to lib<name>.so next to this
# module unless the LLAMA2_CU_LIB environment variable names a path.
_lib_base_name = "llama2"

# ctypes handle used by every wrapper below. It is loaded eagerly, so
# importing this module raises if the library cannot be found or loaded.
_lib = _load_shared_library(_lib_base_name)
def llama2_init(
    model_path: Union[str, os.PathLike],
    tokenizer_path: Union[str, os.PathLike],
) -> c_void_p:
    """Create a native llama2 context from a model file and a tokenizer file.

    Args:
        model_path: Path to the model file (``str`` or any ``os.PathLike``).
        tokenizer_path: Path to the tokenizer file (``str`` or ``os.PathLike``).

    Returns:
        The opaque context handle. Because ``restype`` is ``c_void_p``, ctypes
        returns a plain ``int`` address, or ``None`` if the native call
        returned NULL (presumably an initialization failure; confirm against
        the C side).
    """
    # os.fspath lets callers pass pathlib.Path objects as well as plain str.
    return _lib.llama2_init(
        os.fspath(model_path).encode('utf-8'),
        os.fspath(tokenizer_path).encode('utf-8'),
    )


_lib.llama2_init.argtypes = [c_char_p, c_char_p]
_lib.llama2_init.restype = c_void_p
def llama2_free(ctx: c_void_p) -> None:
    """Release a context previously returned by :func:`llama2_init`.

    The handle is presumably invalid after this call; do not reuse it.
    """
    native_free = _lib.llama2_free
    native_free(ctx)


_lib.llama2_free.restype = None
_lib.llama2_free.argtypes = [c_void_p]
def llama2_generate(
    ctx: c_void_p,
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    seed: int,
) -> int:
    """Launch text generation for ``prompt`` on a native context.

    Args:
        ctx: Handle returned by :func:`llama2_init`.
        prompt: Prompt text; sent to the library UTF-8 encoded.
        max_tokens: Passed through as a C ``int``.
        temperature: Passed through as a C ``float``.
        top_p: Passed through as a C ``float``.
        seed: Passed through as a C ``int``.

    Returns:
        The native ``int`` status code. Generated text is fetched afterwards
        with :func:`llama2_get_last`.
    """
    encoded_prompt = prompt.encode('utf-8')
    status = _lib.llama2_generate(
        ctx, encoded_prompt, max_tokens, temperature, top_p, seed
    )
    return status


_lib.llama2_generate.restype = c_int
_lib.llama2_generate.argtypes = [c_void_p, c_char_p, c_int, c_float, c_float, c_int]
def llama2_get_last(ctx: c_void_p) -> Optional[bytes]:
    """Fetch the next chunk of generated output from a native context.

    Returns:
        The next chunk as ``bytes``, or ``None`` when the native function
        returns NULL (``c_char_p`` restype maps NULL to ``None``); callers
        treat ``None`` as end of generation. A chunk may end in the middle
        of a multi-byte UTF-8 character, so decode incrementally.
    """
    # NOTE(review): ctypes copies the C string into a new bytes object; if the
    # C side allocates a fresh buffer per call, it may need freeing there.
    return _lib.llama2_get_last(ctx)


_lib.llama2_get_last.argtypes = [c_void_p]
_lib.llama2_get_last.restype = c_char_p
def llama2_tokenize(ctx: c_void_p, text: str, add_bos: bool, add_eos: bool) -> List[int]:
    """Tokenize ``text`` with the context's tokenizer.

    Args:
        ctx: Handle returned by :func:`llama2_init`.
        text: Text to tokenize; sent to the library UTF-8 encoded.
        add_bos: Forwarded to the native call as an 8-bit int flag.
        add_eos: Forwarded to the native call as an 8-bit int flag.

    Returns:
        The token ids the native call wrote, truncated to the count it
        reported.
    """
    encoded = text.encode('utf-8')
    # Size the output buffer from the UTF-8 *byte* length, not len(text):
    # a tokenizer with byte fallback (as in llama2.c) can emit one token per
    # byte, so non-ASCII text may need more slots than it has characters,
    # and a character-based size lets the native side write past the end.
    # The +3 leaves room for BOS/EOS plus one extra token (llama2.c's encode
    # also prepends a dummy-prefix token). NOTE(review): confirm vs C side.
    capacity = len(encoded) + 3
    tokens = (c_int * capacity)()
    n_tokens = c_int(0)
    _lib.llama2_tokenize(ctx, encoded, add_bos, add_eos, tokens, ctypes.byref(n_tokens))
    return tokens[:n_tokens.value]


_lib.llama2_tokenize.argtypes = [c_void_p, c_char_p, c_int8, c_int8, POINTER(c_int), POINTER(c_int)]
_lib.llama2_tokenize.restype = None
class Llama2:
def __init__(
self,
model_path: str,
tokenizer_path: str='tokenizer.bin',
n_ctx: int = 0,
n_batch: int = 0) -> None:
self.n_ctx = n_ctx
self.n_batch = n_batch
self.llama2_ctx = llama2_init(model_path, tokenizer_path)
def tokenize(
self, text: str, add_bos: bool = True, add_eos: bool = False
) -> List[int]:
return llama2_tokenize(self.llama2_ctx, text, add_bos, add_eos)
def __call__(
self,
prompt: str,
max_tokens: int = 128,
temperature: float = 0.8,
top_p: float = 0.95,
min_p: float = 0.05,
typical_p: float = 1.0,
logprobs: Optional[int] = None,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
seed: Optional[int] = None,
) -> Iterator[str]:
if seed is None:
seed = 42
ret = llama2_generate(self.llama2_ctx, prompt, max_tokens, temperature, top_p, seed)
if ret != 0:
raise RuntimeError(f"Failed to launch generation for prompt '{prompt}'")
bytes_buffer = b'' # store generated bytes until decoded (in case of multi-byte characters)
while True:
result = llama2_get_last(self.llama2_ctx)
if result is None:
break
bytes_buffer += result
try:
string = bytes_buffer.decode('utf-8')
except UnicodeDecodeError:
pass
else:
bytes_buffer = b''
yield string