---
language:
- en
license: apache-2.0
tags:
- text-generation-inference
- transformers
- unsloth
- mistral
- trl
base_model: unsloth/mistral-7b-v0.3-bnb-4bit
---
# Uploaded model
- **Developed by:** jingwang
- **License:** apache-2.0
- **Finetuned from model:** unsloth/mistral-7b-v0.3-bnb-4bit

This Mistral model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's [TRL](https://github.com/huggingface/trl) library.
# Install dependencies in Google Colab
```shell
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers "trl<0.9.0" peft accelerate bitsandbytes
```
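If the install succeeded, the imports below should resolve and a CUDA GPU should be visible (Unsloth requires one). An optional sanity check:
```python
import torch
from unsloth import FastLanguageModel  # resolves only if the install worked

assert torch.cuda.is_available(), "Unsloth needs a CUDA GPU (e.g. a Colab T4)"
print(torch.__version__, torch.cuda.get_device_name(0))
```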
# Inference
```python
from unsloth import FastLanguageModel
from typing import Dict, Any
import torch


class FormatPrompt_context_QA():
    '''Formats a context/question instance into a Mistral [INST] prompt.'''

    def __init__(self, eos_token: str = '</s>') -> None:
        self.inputs = ['context', 'question']  # required input fields
        self.outputs = ['answer']  # present for training; produced by the model at inference
        self.eos_token = eos_token

    def __call__(self, instance: Dict[str, Any]) -> str:
        '''
        Function call operator.

        Args:
            instance: dictionary with keys 'context' and 'question' (plus 'answer' for training)
        Returns:
            prompt: formatted prompt
        '''
        return self.formatting_prompt_func(instance)

    def formatting_prompt_func(self, instance: Dict[str, Any]) -> str:
        '''Format a prompt for domain-specific QA.

        Note: this targets fine-tuning of a pre-trained base model;
        if starting from an instruction-tuned model, use
        `tokenizer.apply_chat_template(messages)` instead.
        '''
        assert all(item in instance for item in self.inputs), \
            f"instance must have {self.inputs}!"
        prompt = f"""<s> [INST] Answer following question based on Context: {str(instance["context"])}\
Question: {str(instance["question"])} \
Answer: [/INST]"""
        if 'answer' in instance:
            # training case: append the target answer plus the EOS token
            prompt += str(instance['answer']) + self.eos_token
        return prompt
```
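To see exactly what the formatter produces, here is a quick check on a toy instance (the strings are illustrative, not from the training data): without an `answer` key the prompt ends at `[/INST]` for inference; with one, the target answer plus the EOS token is appended for training.
```python
fmt = FormatPrompt_context_QA()
toy = {'context': 'Paris is the capital of France.',
       'question': 'What is the capital of France?'}

print(fmt(toy))                          # inference prompt, ends at [/INST]
print(fmt({**toy, 'answer': 'Paris.'}))  # training prompt, answer + </s> appended
```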
```python
formatting_func = FormatPrompt_context_QA()

# pull the model from the Hugging Face Hub
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "jingwang/mistral_context_qa",
    max_seq_length = 2048,
    dtype = None,  # None lets Unsloth auto-detect the dtype
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)  # switch on Unsloth's faster inference mode

# 'gold_answer' is kept for reference only; the formatter appends nothing
# unless the key is named 'answer', so this produces an inference prompt
example = {
    'question': 'What does the graph compare in terms of cumulative total return?',
    'context': 'the following graph shows a comparison, from january 1, 2019 through december 31, 2023, of the cumulative total return on our common stock, the nasdaq composite index and a group of all public companies sharing the same sic code as us, which is sic code 3711, “ motor vehicles and passenger car bodies ” ( motor vehicles and passenger car bodies public company group ). such returns are based on historical results and are not intended to suggest future performance. data for the nasdaq composite index and the motor vehicles and passenger car bodies public company group assumes an investment of $ 100 on january 1, 2019 and reinvestment of dividends. we have never declared or paid cash dividends on our common stock nor do we anticipate paying any such cash dividends in the foreseeable future. 31',
    'gold_answer': "The graph compares the cumulative total return from January 1, 2019, through December 31, 2023, of the company's common stock, the NASDAQ Composite Index, and a group of public companies with the same SIC code (3711 - Motor Vehicles and Passenger Car Bodies). The comparison assumes an initial investment of $100 on January 1, 2019, with reinvestment of dividends for the NASDAQ Composite Index and the Motor Vehicles and Passenger Car Bodies public company group.",
}

inputs = tokenizer([formatting_func(example)], return_tensors="pt").to(model.device)
input_length = inputs.input_ids.shape[-1]

with torch.no_grad():
    output = model.generate(
        **inputs,
        do_sample=False,  # greedy decoding, so no temperature is needed
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,
    )

response = tokenizer.decode(
    output[0][input_length:],  # keep only the generated tokens, dropping the prompt
    skip_special_tokens=True,
)
print(response)
```
> The graph compares the cumulative total return on our common stock, the NASDAQ Composite Index, and a group of all public companies sharing the same SIC code as us, which is SIC code 3711, "Motor Vehicles and Passenger Car Bodies."
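To score many context/question pairs, a simple loop over a `pandas` DataFrame works. `batch_qa` below is an illustrative helper (not part of the uploaded model) that reuses the greedy-decoding settings from above:
```python
import pandas as pd
import torch
from tqdm import tqdm

def batch_qa(df: pd.DataFrame, model, tokenizer, formatting_func,
             max_new_tokens: int = 64) -> pd.DataFrame:
    '''Greedy QA inference over a DataFrame with 'context' and 'question' columns.'''
    answers = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        inputs = tokenizer([formatting_func(row.to_dict())],
                           return_tensors="pt").to(model.device)
        input_length = inputs.input_ids.shape[-1]
        with torch.no_grad():
            output = model.generate(**inputs,
                                    do_sample=False,
                                    max_new_tokens=max_new_tokens,
                                    pad_token_id=tokenizer.eos_token_id)
        answers.append(tokenizer.decode(output[0][input_length:],
                                        skip_special_tokens=True))
    out = df.copy()
    out['model_answer'] = answers
    return out
```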