PEFT
English
medical

Model Card for Tonic/MistralMed

This is a medicine-focused fine-tune of Mistral: a LoRA adapter for mistralai/Mistral-7B-v0.1 trained on keivalya/MedQuad-MedicalQnADataset.

Model Details

Model Description

An attempt to improve Mistral's medical question answering.

Model Sources

Uses

This model can be used the same way as the base Mistral model, with the adapter loaded on top of it through PEFT.

Direct Use

The model is intended for medical question-and-answer scenarios, where it may perform better than the base model.

Downstream Use

This model is intended to be fine-tuned further.

Recommendations

  • Do not use as is
  • Fine-tune this model further
  • For educational purposes only
  • Benchmark the model on your use case
  • Evaluate the model before use

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

How to Get Started with the Model

Use the code below to get started with the model.

The example below comes from the pseudolab/MistralMED_Chat Space:

import textwrap

import torch
import gradio as gr
from transformers import AutoTokenizer, MistralForCausalLM
from peft import PeftModel

def wrap_text(text, width=90):
    # Wrap each line of a long response to `width` characters for display
    lines = text.split('\n')
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    return '\n'.join(wrapped_lines)

device = "cuda" if torch.cuda.is_available() else "cpu"

base_model_id = "mistralai/Mistral-7B-v0.1"
model_directory = "Tonic/mistralmed"

# Mistral's tokenizer has no pad token, so reuse EOS and pad on the left for generation
tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Load the base model, then attach the LoRA adapter from this repository.
# The adapter is public; if you ever need authentication, pass token=... read from
# an environment variable. Never hard-code an access token in source code.
base_model = MistralForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
)
peft_model = PeftModel.from_pretrained(base_model, model_directory)
peft_model.to(device)
peft_model.eval()

def multimodal_prompt(user_input, system_prompt="You are an expert medical analyst:", max_length=512):
    # Low-temperature sampled generation (text only, despite the name)
    formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"
    # The prompt already contains <s>, so don't let the tokenizer add a second BOS token
    model_inputs = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False).to(device)

    output = peft_model.generate(
        **model_inputs,
        max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.1,
        do_sample=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

class ChatBot:
    def __init__(self):
        self.history = []

    def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
        # Greedy generation, used by the Gradio interface below
        formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"
        user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt", add_special_tokens=False).to(device)

        response = peft_model.generate(input_ids=user_input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)

        return tokenizer.decode(response[0], skip_special_tokens=True)

bot = ChatBot()

title = "👋🏻Welcome to Tonic's MistralMed Chat🚀"
description = "You can use this Space to test out the current model [(Tonic/MistralMed)](https://huggingface.co/Tonic/MistralMed) or duplicate this Space and use it locally or on 🤗HuggingFace. [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
examples = [["[Question:] What is the proper treatment for buccal herpes?", "You are a medicine and public health expert, you will receive a question, answer the question, and complete the answer"]]

iface = gr.Interface(
    fn=bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=["text", "text"],
    outputs="text",
    theme="ParityError/Anime"
)

iface.launch()
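Both `predict` and `multimodal_prompt` decode the full generated sequence, so the returned text still contains the prompt. A minimal sketch of stripping it, assuming the `[/INST]` marker survives decoding as ordinary text (it is not a special token in the Mistral-7B-v0.1 tokenizer); `extract_answer` is a hypothetical helper, not part of transformers. Alternatively, slice off the prompt tokens before decoding, e.g. `response[0][user_input_ids.shape[1]:]`.

```python
def extract_answer(decoded: str, marker: str = "[/INST]") -> str:
    """Return the text generated after the first instruction marker.

    Hypothetical helper: assumes the prompt used the "[INST] ... [/INST]"
    template and that the marker was decoded back as plain text.
    """
    # partition splits on the FIRST occurrence, which is where the prompt ends
    _, sep, tail = decoded.partition(marker)
    # If the marker is missing, return the whole (stripped) text unchanged
    return tail.strip() if sep else decoded.strip()


if __name__ == "__main__":
    decoded = "[INST]You are an expert medical analyst: What is gout?[/INST] <generated answer>"
    print(extract_answer(decoded))
```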

Training Details

Training Data

MedQuad (keivalya/MedQuad-MedicalQnADataset)

Training Procedure

The dataset, as loaded with the datasets library, has 16,407 rows:

Dataset({
    features: ['qtype', 'Question', 'Answer'],
    num_rows: 16407
})
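The card does not say how the `Question` and `Answer` columns were turned into training text. One plausible layout, assuming the same `[INST]` template that the inference code uses (an assumption, not the documented training format), is sketched below; `format_example` is a hypothetical helper.

```python
# System prompt taken from the inference example in this card
SYSTEM_PROMPT = "You are an expert medical analyst:"


def format_example(row: dict, system_prompt: str = SYSTEM_PROMPT) -> str:
    # Hypothetical training format: the question goes inside Mistral's [INST]
    # template and the reference answer plus EOS is the completion to learn
    return f"<s>[INST]{system_prompt} {row['Question']}[/INST] {row['Answer']}</s>"


if __name__ == "__main__":
    row = {"qtype": "treatment", "Question": "What are the treatments for gout?", "Answer": "..."}
    print(format_example(row))
```

With the datasets library this could be applied to every row with `dataset.map(lambda row: {"text": format_example(row)})`, which adds a `text` column alongside the existing ones.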

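Preprocessing

The base model was loaded in 4-bit with bitsandbytes: all attention and MLP projections are Linear4bit layers, while embed_tokens and lm_head are not quantized: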

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
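The shapes in this dump are enough to estimate what 4-bit loading saves. As a back-of-the-envelope sketch, counting only the Linear4bit projection weights (and ignoring the per-block quantization constants and the unquantized embeddings, norms, and lm_head), those weights shrink from 13 GiB in bfloat16 to 3.25 GiB at 4 bits each:

```python
# (in_features, out_features) of every Linear4bit layer in one
# MistralDecoderLayer, read off the module dump above
LAYER_SHAPES = [
    (4096, 4096),   # q_proj
    (4096, 1024),   # k_proj
    (4096, 1024),   # v_proj
    (4096, 4096),   # o_proj
    (4096, 14336),  # gate_proj
    (4096, 14336),  # up_proj
    (14336, 4096),  # down_proj
]
NUM_LAYERS = 32
GIB = 2**30


def quantized_weight_count() -> int:
    # Total number of weights held in Linear4bit layers across all decoder layers
    return NUM_LAYERS * sum(i * o for i, o in LAYER_SHAPES)


def weight_bytes(n_weights: int, bits: int) -> int:
    # Storage for n_weights at the given bit width (quantization constants excluded)
    return n_weights * bits // 8


if __name__ == "__main__":
    n = quantized_weight_count()
    print(f"bf16: {weight_bytes(n, 16) / GIB:.2f} GiB, 4-bit: {weight_bytes(n, 4) / GIB:.2f} GiB")
```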

Training Hyperparameters

  • Training regime: LoRA on the 4-bit (NF4) quantized base model, with bfloat16 compute

from peft import LoraConfig

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)
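With this config, each targeted layer keeps its frozen weight W and learns a low-rank update, so its forward pass computes W x + (lora_alpha / r) · B A x, where A maps in_features to r and B maps r to out_features (the `lora_A` and `lora_B` modules in the architecture dump later in this card); dropout is applied to the input of the LoRA branch during training. A pure-Python sketch of that computation on toy sizes, using plain lists instead of tensors and omitting dropout as at inference time:

```python
def matvec(m, x):
    # m is a list of rows; returns the matrix-vector product m @ x
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(W, A, B, x, r=8, lora_alpha=16):
    """y = W x + (lora_alpha / r) * B (A x).

    W: out x in (frozen), A: r x in, B: out x r (both trainable).
    """
    scaling = lora_alpha / r  # 16 / 8 = 2.0 for this card's config
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + scaling * u for b, u in zip(base, update)]


if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0]]
    A = [[1.0, 1.0]]        # rank 1 for the toy example
    B = [[1.0], [0.5]]
    print(lora_forward(W, A, B, [3.0, 4.0], r=1, lora_alpha=2))
```

PEFT initializes `lora_B` to zeros by default, so before training the adapted layer returns exactly the base output; only the product B A (a few million parameters here) is updated.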

Speeds, Sizes, Times

  • trainable params: 21260288 || all params: 3773331456 || trainable%: 0.5634354746703705
  • TrainOutput(global_step=1000, training_loss=0.47226515007019043, metrics={'train_runtime': 3143.4141, 'train_samples_per_second': 2.545, 'train_steps_per_second': 0.318, 'total_flos': 1.75274075357184e+17, 'train_loss': 0.47226515007019043, 'epoch': 0.49})
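The trainable-parameter count can be reproduced from the layer shapes in the architecture dump: LoRA adds r · (in_features + out_features) parameters per targeted layer. The "all params" figure is lower than Mistral-7B's roughly 7.24 billion parameters because the count uses each tensor's element count, and bitsandbytes stores two 4-bit weights per element, so each quantized matrix counts at half its size. A sketch of both calculations under that packing assumption:

```python
# (in_features, out_features) of the LoRA-targeted projections in one
# decoder layer, read off the module dump in this card
LAYER_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
NUM_LAYERS = 32
HIDDEN, VOCAB = 4096, 32000
R = 8  # LoraConfig(r=8)


def lora_params(in_f: int, out_f: int, r: int = R) -> int:
    # lora_A maps in_f -> r and lora_B maps r -> out_f, both without bias
    return r * (in_f + out_f)


def trainable_params() -> int:
    per_layer = sum(lora_params(i, o) for i, o in LAYER_SHAPES.values())
    # 32 decoder layers plus the adapter on lm_head (hidden -> vocab)
    return per_layer * NUM_LAYERS + lora_params(HIDDEN, VOCAB)


def reported_total_params() -> int:
    # Assumption: each Linear4bit weight is stored packed, two 4-bit values
    # per element, so it contributes in*out/2 to the element count
    packed = NUM_LAYERS * sum(i * o for i, o in LAYER_SHAPES.values()) // 2
    embed_tokens = VOCAB * HIDDEN
    lm_head = HIDDEN * VOCAB                # regular Linear, not quantized
    norms = (2 * NUM_LAYERS + 1) * HIDDEN   # 2 RMSNorms per layer + final norm
    return packed + embed_tokens + lm_head + norms + trainable_params()


if __name__ == "__main__":
    t, a = trainable_params(), reported_total_params()
    print(f"trainable params: {t} || all params: {a} || trainable%: {100 * t / a}")
```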

Environmental Impact

Carbon emissions can be estimated using the Machine Learning Impact calculator presented in Lacoste et al. (2019).

  • Hardware Type: A100
  • Hours used: 1
  • Cloud Provider: Google
  • Compute Region: East1
  • Carbon Emitted: 0.09
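The calculator's estimate is essentially the energy used multiplied by the carbon intensity of the local power grid. A minimal sketch of that arithmetic; `estimate_emissions_kg` is a hypothetical helper, and the power draw, PUE, and grid-intensity values in the usage line are placeholders to replace, not the inputs behind the figure above (the card does not list them):

```python
def estimate_emissions_kg(power_kw: float, hours: float,
                          grid_kg_per_kwh: float, pue: float = 1.0) -> float:
    # Energy drawn (kWh) = hardware power x run time x data-centre overhead (PUE)
    energy_kwh = power_kw * hours * pue
    # Emissions (kg CO2eq) = energy x carbon intensity of the local grid
    return energy_kwh * grid_kg_per_kwh


if __name__ == "__main__":
    # Placeholder inputs: one GPU at 0.4 kW for the 1 hour reported above
    print(estimate_emissions_kg(power_kw=0.4, hours=1.0, grid_kg_per_kwh=0.5))
```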

Training Results

[1000/1000 52:20, Epoch 0/1]

Step  Training loss
  50  0.474200
 100  0.523300
 150  0.484500
 200  0.482800
 250  0.498800
 300  0.451800
 350  0.491800
 400  0.488000
 450  0.472800
 500  0.460400
 550  0.464700
 600  0.484800
 650  0.474600
 700  0.477900
 750  0.445300
 800  0.431300
 850  0.461500
 900  0.451200
 950  0.470800
1000  0.454900
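The summary numbers in the TrainOutput are consistent with this log. The reported training_loss of about 0.4723 matches the mean of the twenty logged losses, and dividing train_samples_per_second by train_steps_per_second gives about 8 samples per optimizer step, which over 1,000 steps covers roughly 8,000 of the 16,407 MedQuad rows, i.e. the reported 0.49 epochs. The batch size of 8 is inferred from those two rates, not stated in the card. A quick check:

```python
# (step, training loss) pairs from the table above
LOG = [
    (50, 0.4742), (100, 0.5233), (150, 0.4845), (200, 0.4828), (250, 0.4988),
    (300, 0.4518), (350, 0.4918), (400, 0.4880), (450, 0.4728), (500, 0.4604),
    (550, 0.4647), (600, 0.4848), (650, 0.4746), (700, 0.4779), (750, 0.4453),
    (800, 0.4313), (850, 0.4615), (900, 0.4512), (950, 0.4708), (1000, 0.4549),
]
# Values copied from the TrainOutput and dataset description in this card
REPORTED_TRAIN_LOSS = 0.47226515007019043
SAMPLES_PER_SECOND = 2.545
STEPS_PER_SECOND = 0.318
GLOBAL_STEPS = 1000
DATASET_ROWS = 16407


def mean_logged_loss(log) -> float:
    return sum(loss for _, loss in log) / len(log)


def inferred_batch_size() -> int:
    # samples/s divided by steps/s = samples per optimizer step (an inference)
    return round(SAMPLES_PER_SECOND / STEPS_PER_SECOND)


def epochs_covered() -> float:
    return GLOBAL_STEPS * inferred_batch_size() / DATASET_ROWS


if __name__ == "__main__":
    print(f"mean logged loss: {mean_logged_loss(LOG):.5f} (reported {REPORTED_TRAIN_LOSS:.5f})")
    print(f"batch size ~{inferred_batch_size()}, epochs ~{epochs_covered():.2f}")
```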

Model Architecture and Objective

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (k_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              )
              (v_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              )
              (o_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (rotary_emb): MistralRotaryEmbedding()
            )
            (mlp): MistralMLP(
              (gate_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
              )
              (up_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
              )
              (down_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=14336, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=14336, out_features=4096, bias=False)
              )
              (act_fn): SiLUActivation()
            )
            (input_layernorm): MistralRMSNorm()
            (post_attention_layernorm): MistralRMSNorm()
          )
        )
        (norm): MistralRMSNorm()
      )
      (lm_head): Linear(
        in_features=4096, out_features=32000, bias=False
        (lora_dropout): ModuleDict(
          (default): Dropout(p=0.05, inplace=False)
        )
        (lora_A): ModuleDict(
          (default): Linear(in_features=4096, out_features=8, bias=False)
        )
        (lora_B): ModuleDict(
          (default): Linear(in_features=8, out_features=32000, bias=False)
        )
        (lora_embedding_A): ParameterDict()
        (lora_embedding_B): ParameterDict()
      )
    )
  )
)

Hardware

A100

Model Card Authors

Tonic

Model Card Contact

Tonic

Training procedure

The following bitsandbytes quantization config was used during training:

  • quant_method: bitsandbytes
  • load_in_8bit: False
  • load_in_4bit: True
  • llm_int8_threshold: 6.0
  • llm_int8_skip_modules: None
  • llm_int8_enable_fp32_cpu_offload: False
  • llm_int8_has_fp16_weight: False
  • bnb_4bit_quant_type: nf4
  • bnb_4bit_use_double_quant: True
  • bnb_4bit_compute_dtype: bfloat16
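`bnb_4bit_use_double_quant: True` quantizes the per-block scaling constants themselves. Following the arithmetic in the QLoRA paper, and assuming blocks of 64 weights per 32-bit constant without double quantization, versus 8-bit constants plus one 32-bit constant per 256 of them with it, the overhead drops from 0.5 to about 0.127 bits per weight. A sketch applied to this model's quantized weights (the count comes from the layer shapes in the module dump; `overhead_bits_per_weight` is a hypothetical helper):

```python
# Sum of in_features * out_features over all Linear4bit layers (32 decoder layers)
QUANTIZED_WEIGHTS = 6_979_321_856


def overhead_bits_per_weight(double_quant: bool, block: int = 64, block2: int = 256) -> float:
    if not double_quant:
        # One fp32 absmax constant per block of `block` weights
        return 32 / block
    # One 8-bit quantized constant per block, plus one fp32 constant
    # for every `block2` of those 8-bit constants
    return 8 / block + 32 / (block * block2)


def bytes_saved(n_weights: int = QUANTIZED_WEIGHTS) -> float:
    saved_bits = overhead_bits_per_weight(False) - overhead_bits_per_weight(True)
    return saved_bits * n_weights / 8


if __name__ == "__main__":
    print(f"{overhead_bits_per_weight(False)} -> {overhead_bits_per_weight(True):.3f} bits/weight")
    print(f"saves about {bytes_saved() / 2**20:.0f} MiB")
```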

Framework versions

  • PEFT 0.6.0.dev0