---
language:
- en
license: apache-2.0
tags:
- text-generation-inference
- transformers
- unsloth
- mistral
- trl
base_model: unsloth/mistral-7b-v0.3-bnb-4bit
---

# Uploaded model

- **Developed by:** jingwang
- **License:** apache-2.0
- **Finetuned from model:** unsloth/mistral-7b-v0.3-bnb-4bit

This Mistral model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.

# Install dependencies in Google Colab

Run these commands in a Colab notebook cell. The `!` prefix tells the notebook to execute them as shell commands.

```shell
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers "trl<0.9.0" peft accelerate bitsandbytes
```

# Inference

```python
from unsloth import FastLanguageModel  # import Unsloth first so it can patch transformers
from typing import Any, Dict

import torch


class FormatPrompt_context_QA:
    '''Format a context/question pair (plus an optional answer) into the prompt template.'''

    def __init__(self, eos_token: str = '</s>') -> None:
        self.inputs = ['context', 'question']  # required input fields
        self.outputs = ['answer']  # target field: supplied for training, generated at inference
        self.eos_token = eos_token

    def __call__(self, instance: Dict[str, Any]) -> str:
        '''
        Function call operator.

        Args:
            instance: dictionary with keys 'context', 'question' and, for training, 'answer'

        Returns:
            prompt: formatted prompt
        '''
        return self.formatting_prompt_func(instance)

    def formatting_prompt_func(self, instance: Dict[str, Any]) -> str:
        '''Format a prompt for domain-specific QA.

        Note: this template is for fine-tuning a pre-trained (base) model.
        If starting from an instruct-tuned model, use
        `tokenizer.apply_chat_template(messages)` instead.
        '''
        assert all(item in instance for item in self.inputs), f"instance must have {self.inputs}!"

        # Template wording is kept as-is; it should match the format used during fine-tuning.
        prompt = (
            f"<s> [INST] Answer following question based on Context: {str(instance['context'])}"
            f"Question: {str(instance['question'])} "
            "Answer: [/INST]"
        )

        # Append the target answer and EOS token only when an answer is supplied (training).
        if 'answer' in instance:
            prompt += str(instance['answer']) + self.eos_token
        return prompt
```
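
To see exactly what text the model receives, the template can be reproduced without loading the model. The sketch below uses a hypothetical standalone function, `build_prompt`, that mirrors `formatting_prompt_func`, assuming the template segments are joined with no extra whitespace. It shows that the answer and EOS token are appended only when an `answer` key is present, which is how one template serves both training and inference.

```python
from typing import Any, Dict


# Hypothetical standalone mirror of FormatPrompt_context_QA.formatting_prompt_func,
# assuming the template segments are joined with no extra whitespace.
def build_prompt(instance: Dict[str, Any], eos_token: str = "</s>") -> str:
    prompt = (
        f"<s> [INST] Answer following question based on Context: {str(instance['context'])}"
        f"Question: {str(instance['question'])} "
        "Answer: [/INST]"
    )
    # Training examples carry an 'answer'; inference examples do not.
    if "answer" in instance:
        prompt += str(instance["answer"]) + eos_token
    return prompt


inference_prompt = build_prompt(
    {"context": "Sky is blue.", "question": "What color is the sky?"}
)
training_prompt = build_prompt(
    {"context": "Sky is blue.", "question": "What color is the sky?", "answer": "Blue."}
)
print(inference_prompt)
print(training_prompt)
```

At inference time the prompt ends at `[/INST]`, so the model's continuation is the answer; during fine-tuning the answer and `</s>` follow directly, teaching the model where to stop.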

```python
formatting_func = FormatPrompt_context_QA()

# Load the fine-tuned model from the Hugging Face Hub
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="jingwang/mistral_context_qa",
    max_seq_length=2048,
    dtype=None,  # None lets Unsloth auto-detect the dtype
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # switch Unsloth into its faster inference mode

example = {
    'question': 'What does the graph compare in terms of cumulative total return?',
    'context': 'the following graph shows a comparison, from january 1, 2019 through december 31, 2023, of the cumulative total return on our common stock, the nasdaq composite index and a group of all public companies sharing the same sic code as us, which is sic code 3711, “ motor vehicles and passenger car bodies ” ( motor vehicles and passenger car bodies public company group ). such returns are based on historical results and are not intended to suggest future performance. data for the nasdaq composite index and the motor vehicles and passenger car bodies public company group assumes an investment of $ 100 on january 1, 2019 and reinvestment of dividends. we have never declared or paid cash dividends on our common stock nor do we anticipate paying any such cash dividends in the foreseeable future. 31',
    # Reference answer for comparison only; it is not included in the prompt
    'gold_answer': "The graph compares the cumulative total return from January 1, 2019, through December 31, 2023, of the company's common stock, the NASDAQ Composite Index, and a group of public companies with the same SIC code (3711 - Motor Vehicles and Passenger Car Bodies). The comparison assumes an initial investment of $100 on January 1, 2019, with reinvestment of dividends for the NASDAQ Composite Index and the Motor Vehicles and Passenger Car Bodies public company group.",
}

inputs = tokenizer([formatting_func(example)], return_tensors="pt", padding=False).to(model.device)
input_length = inputs.input_ids.shape[-1]

with torch.no_grad():
    output = model.generate(
        **inputs,
        do_sample=False,  # greedy decoding; temperature has no effect here
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,
    )

response = tokenizer.decode(
    output[0][input_length:],  # keep only the generated tokens, dropping the prompt
    skip_special_tokens=True,
)
print(response)
```

Output:

> The graph compares the cumulative total return on our common stock, the NASDAQ Composite Index, and a group of all public companies sharing the same SIC code as us, which is SIC code 3711, "Motor Vehicles and Passenger Car Bodies."
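
The example carries a `gold_answer` that is not used in the prompt, so it is available for scoring the response. One common way to compare a generated answer with a reference is SQuAD-style token-overlap F1. The sketch below is an assumption for illustration, not the evaluation used for this model, and the helper names `normalize` and `token_f1` are hypothetical.

```python
import re
import string
from collections import Counter


def normalize(text: str) -> list[str]:
    """Lowercase, drop ASCII punctuation and English articles, then split on whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return text.split()


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall, counting overlap as a multiset."""
    pred_tokens = normalize(prediction)
    ref_tokens = normalize(reference)
    common = Counter(pred_tokens) & Counter(ref_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# Usage with the objects from the inference snippet:
# print(token_f1(response, example["gold_answer"]))
```

Token F1 rewards lexical overlap only, so a correct paraphrase can still score low; with `max_new_tokens=64`, a truncated response will also lose recall against a longer gold answer.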