import os
from pathlib import Path
from typing import Optional

import torch
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
class Tokenizer:
    """Tokenizer for LLaMA."""

    def __init__(self, model_path: Path) -> None:
        self.processor = SentencePieceProcessor(model_file=str(model_path))
        # Cache the special-token ids defined by the SentencePiece model.
        self.bos_id = self.processor.bos_id()
        self.eos_id = self.processor.eos_id()
        self.pad_id = self.processor.pad_id()

    def vocab_size(self) -> int:
        return self.processor.vocab_size()

    def encode(
        self,
        string: str,
        bos: bool = True,
        eos: bool = False,
        max_length: int = -1,
        pad: bool = False,
        device: Optional[torch.device] = None
    ) -> torch.Tensor:
        tokens = self.processor.encode(string)
        if bos:
            tokens = [self.bos_id] + tokens
        if eos:
            tokens = tokens + [self.eos_id]
        # Truncation runs after bos/eos are added, so an appended eos token
        # can be dropped when the sequence exceeds max_length.
        if max_length > 0:
            tokens = tokens[:max_length]
        # Padding only kicks in when an explicit positive max_length is given;
        # with the default of -1 the length check is never satisfied.
        if pad and len(tokens) < max_length:
            tokens += [self.pad_id] * (max_length - len(tokens))
        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tokens: torch.Tensor) -> str:
        return self.processor.decode(tokens.tolist())
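

# Illustrative helper, not part of the original file: batch-encode a list of
# strings to a fixed block size by combining the max_length and pad arguments
# of Tokenizer.encode above. Assumes the SentencePiece model defines a pad id
# (pad_id() returns -1 for models trained without one).
def encode_batch(tokenizer: Tokenizer, strings: list, block_size: int) -> torch.Tensor:
    # Every row is truncated or padded to exactly block_size tokens,
    # so the results can be stacked into a single 2-D tensor.
    return torch.stack(
        [tokenizer.encode(s, max_length=block_size, pad=True) for s in strings]
    )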


def train(input: str, destination: str, vocab_size: int = 32000) -> None:
    """Train a SentencePiece model on the `input` corpus and write
    `tokenizer.model` and `tokenizer.vocab` under `destination`."""
    model_prefix = os.path.join(destination, "tokenizer")
    SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
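

if __name__ == "__main__":
    # Minimal smoke test, illustrative only: train a throwaway tokenizer on a
    # tiny synthetic corpus, then round-trip a string through encode/decode.
    # The corpus text and the small vocab_size are assumptions for this demo;
    # a real run would use a large corpus and the default vocab_size of 32000.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        corpus = os.path.join(tmp, "corpus.txt")
        with open(corpus, "w") as f:
            f.write("The quick brown fox jumps over the lazy dog.\n" * 100)

        # vocab_size must stay small enough for the toy corpus to support it.
        train(corpus, tmp, vocab_size=60)

        tokenizer = Tokenizer(Path(tmp) / "tokenizer.model")
        ids = tokenizer.encode("the quick brown fox", eos=True)
        print(ids)                    # 1-D int tensor, bos id first
        print(tokenizer.decode(ids))  # control tokens are stripped on decode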