license: apache-2.0

Fine-tuned Mistral Model for Multi-Document Summarization

This model is a fine-tuned version of mistralai/Mistral-7B-v0.1, trained on the multi_x_science_sum dataset.

Model description

Mistral-7B-multixscience-finetuned is fine-tuned on the multi_x_science_sum dataset to extend the original Mistral model's capabilities in multi-document summarization. It builds on the Mistral foundation model, adapting it to synthesize and summarize information from multiple documents, and is distributed as a PEFT adapter that is loaded on top of the base model.

Training and evaluation dataset

Multi-XScience (multi_x_science_sum) is a large-scale multi-document summarization dataset built from scientific articles. It poses a challenging multi-document summarization task: writing the related-work section of a paper based on its abstract and the articles it references.

The training and evaluation data were generated specifically for this fine-tuning, with a focus on producing related-work sections for scientific papers. A custom prompt-generation process turns each dataset record into a prompt that simulates the task: writing a related-work section from a paper's abstract and the abstracts of the papers it cites.

Dataset generation process

The process involves generating prompts that instruct the model to use the abstract of the current paper along with the abstracts of cited papers to generate a new related work section. This approach aims to mimic the real-world scenario where a researcher synthesizes information from multiple sources to draft the related work section of a paper.

  • Prompt Structure: Each data point is a single text that contains:

    • The abstract of the current paper.
    • The abstracts of the cited papers, each labeled with its citation identifier.
    • The target response: the paper's actual related-work section, which the model learns to generate.
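A minimal sketch of the record shape the prompt builder expects, assuming the field layout of the Hugging Face multi_x_science_sum dataset that the prompt generation code below relies on (`abstract`, `related_work`, and `ref_abstract` with parallel `cite_N` and `abstract` lists). The record text is a hypothetical toy example, not taken from the dataset, and `label_cited_abstracts` is an illustrative helper rather than part of the original pipeline.

```python
# Hypothetical toy record mirroring the multi_x_science_sum field layout
# (the text is invented for illustration; it is not taken from the dataset).
toy_record = {
    "abstract": "We study sparse attention for long documents.",
    "ref_abstract": {
        "cite_N": ["@cite_1", "@cite_2"],
        "abstract": [
            "A transformer with local windowed attention.",
            "A survey of efficient transformer variants.",
        ],
    },
    "related_work": "Windowed attention @cite_1 and related variants @cite_2 ...",
}


def label_cited_abstracts(record):
    """Pair each cited-paper identifier with its abstract.

    The identifiers and abstracts are parallel lists, so a length mismatch
    would make zip() silently drop entries; raise an error instead.
    """
    ids = record["ref_abstract"]["cite_N"]
    abstracts = record["ref_abstract"]["abstract"]
    if len(ids) != len(abstracts):
        raise ValueError(f"{len(ids)} cite ids but {len(abstracts)} abstracts")
    return ["- {}: {}".format(cid, text) for cid, text in zip(ids, abstracts)]


for line in label_cited_abstracts(toy_record):
    print(line)
```

The explicit length check is a design choice: records with missing cited abstracts would otherwise produce prompts with fewer sources than intended, without any warning.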

Prompt generation code

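def generate_related_work_prompt(data):
    # Instruction block, wrapped in [INST] / <<SYS>> markers
    prompt = "[INST] <<SYS>>\n"
    prompt += "Use the abstract of the current paper and the abstracts of the cited papers to generate new related work.\n"
    prompt += "<</SYS>>\n\n"
    # Abstract of the paper whose related-work section is being written
    prompt += "Input:\nCurrent Paper's Abstract:\n- {}\n\n".format(data['abstract'])
    # Cited abstracts, each labeled with its citation id from cite_N
    prompt += "Cited Papers' Abstracts:\n"
    for cite_id, cite_abstract in zip(data['ref_abstract']['cite_N'], data['ref_abstract']['abstract']):
        prompt += "- {}: {}\n".format(cite_id, cite_abstract)
    # Target: the paper's reference related-work section, placed after [/INST]
    prompt += "\n[/INST]\n\nGenerated Related Work:\n{}\n".format(data['related_work'])
    # Single-field record holding the full training text
    return {"text": prompt}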
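To see the exact layout the model is trained on, the prompt builder can be applied to a single record. The sketch below repeats the function so it runs on its own and uses a hypothetical toy record (invented text, same field layout as the dataset).

```python
def generate_related_work_prompt(data):
    # Same prompt builder as above, repeated so this snippet is self-contained.
    prompt = "[INST] <<SYS>>\n"
    prompt += "Use the abstract of the current paper and the abstracts of the cited papers to generate new related work.\n"
    prompt += "<</SYS>>\n\n"
    prompt += "Input:\nCurrent Paper's Abstract:\n- {}\n\n".format(data['abstract'])
    prompt += "Cited Papers' Abstracts:\n"
    for cite_id, cite_abstract in zip(data['ref_abstract']['cite_N'], data['ref_abstract']['abstract']):
        prompt += "- {}: {}\n".format(cite_id, cite_abstract)
    prompt += "\n[/INST]\n\nGenerated Related Work:\n{}\n".format(data['related_work'])
    return {"text": prompt}


# Hypothetical toy record (invented text, multi_x_science_sum field layout).
toy_record = {
    "abstract": "We study sparse attention for long documents.",
    "ref_abstract": {
        "cite_N": ["@cite_1", "@cite_2"],
        "abstract": ["Local windowed attention.", "A survey of efficient transformers."],
    },
    "related_work": "Windowed attention @cite_1 is surveyed in @cite_2.",
}

print(generate_related_work_prompt(toy_record)["text"])
```

Everything up to `[/INST]` is the instruction and its inputs; the text after `Generated Related Work:` is the reference section the model learns to continue with.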

The dataset generated through this process was used to train and evaluate the fine-tuned model, with the goal of teaching it to synthesize information from multiple sources into a cohesive related-work section.
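One simple sanity check on a generated related-work section is whether it cites the papers it was given and nothing else. The sketch below assumes citations appear as `@cite_<n>` markers, matching the `cite_N` identifiers in the prompts; `citation_coverage` is a hypothetical helper and is not part of the authors' evaluation.

```python
import re

# Assumes citation markers look like "@cite_12", matching the cite_N ids.
CITE_PATTERN = re.compile(r"@cite_\d+")


def citation_coverage(generated_text, provided_ids):
    """Compare the citation markers in generated text with the provided ids.

    Returns (fraction of provided ids that are cited, set of cited ids that
    were never provided, i.e. likely hallucinated references).
    """
    cited = set(CITE_PATTERN.findall(generated_text))
    provided = set(provided_ids)
    coverage = len(cited & provided) / len(provided) if provided else 0.0
    return coverage, cited - provided


coverage, unknown = citation_coverage(
    "Windowed attention @cite_1 extends earlier work @cite_7.",
    ["@cite_1", "@cite_2"],
)
print(coverage, sorted(unknown))
```

This only checks citation markers, not whether the text describes each cited paper correctly, so it complements rather than replaces summarization metrics.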

Training hyperparameters

The following hyperparameters were used during training:

learning_rate: 2e-5
train_batch_size: 4
eval_batch_size: 4
seed: 42
optimizer: adamw_8bit
num_epochs: 5
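The card does not include the training script. As a rough sketch, the values above can be written as keyword arguments for Hugging Face `transformers.TrainingArguments`; the name mapping (for example `train_batch_size` to `per_device_train_batch_size`) is an assumption. The helper estimates the number of optimizer steps for a given training-set size, assuming a single device and no gradient accumulation, neither of which is stated on this card.

```python
import math

# Hyperparameters from the list above, keyed by the (assumed) matching
# transformers.TrainingArguments parameter names.
training_kwargs = {
    "learning_rate": 2e-5,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "seed": 42,
    "optim": "adamw_8bit",
    "num_train_epochs": 5,
}


def total_optimizer_steps(num_examples, batch_size=4, epochs=5):
    """Rough step count: one optimizer step per batch (single device,
    no gradient accumulation assumed); the last partial batch counts."""
    return math.ceil(num_examples / batch_size) * epochs


# With transformers installed, the settings could be passed as
# TrainingArguments(output_dir="out", **training_kwargs).
print(total_optimizer_steps(1000))  # 250 batches per epoch * 5 epochs
```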

Usage

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel

base_model = "mistralai/Mistral-7B-v0.1"
adapter = "OctaSpace/Mistral7B-fintuned"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    add_bos_token=True,
    trust_remote_code=True,
    padding_side='left'
)

# Create peft model using base_model and finetuned adapter.
# 4-bit loading needs bitsandbytes and a CUDA GPU; device_map='auto' already
# places the weights, so the quantized model is not moved with .to() afterwards.
config = PeftConfig.from_pretrained(adapter)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
                                             quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                                             device_map='auto',
                                             torch_dtype='auto')
model = PeftModel.from_pretrained(model, adapter)
model.eval()

# Prompt content: the model was trained on raw text prompts (see the prompt
# generation code), so tokenize the prompt directly instead of applying a chat
# template, and end it right after "Generated Related Work:\n".
prompt = ""  # Put here your related work generation instruction

inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
with torch.no_grad():
    summary_ids = model.generate(**inputs, max_new_tokens=512, do_sample=True,
                                 pad_token_id=tokenizer.eos_token_id)

# Decode only the newly generated tokens (drop the echoed prompt)
new_tokens = summary_ids[:, inputs['input_ids'].shape[1]:]
summaries = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

# Model response:
print(summaries[0])
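The usage code leaves the related work generation instruction for the user to fill in. A minimal sketch for building it, assuming the model works best with the same layout it was trained on: reuse the training format but stop right after `Generated Related Work:` so the model writes the section itself. `build_inference_prompt` is a hypothetical helper, and the example inputs are invented.

```python
def build_inference_prompt(abstract, cited):
    """Build a prompt in the training format, without the target related work.

    `cited` maps citation identifiers (e.g. "@cite_1") to cited abstracts.
    """
    prompt = "[INST] <<SYS>>\n"
    prompt += "Use the abstract of the current paper and the abstracts of the cited papers to generate new related work.\n"
    prompt += "<</SYS>>\n\n"
    prompt += "Input:\nCurrent Paper's Abstract:\n- {}\n\n".format(abstract)
    prompt += "Cited Papers' Abstracts:\n"
    for cite_id, cite_abstract in cited.items():
        prompt += "- {}: {}\n".format(cite_id, cite_abstract)
    # End where the training text had the reference section, so generation
    # continues from "Generated Related Work:".
    prompt += "\n[/INST]\n\nGenerated Related Work:\n"
    return prompt


prompt = build_inference_prompt(
    "We study sparse attention for long documents.",
    {"@cite_1": "Local windowed attention.", "@cite_2": "A survey of efficient transformers."},
)
print(prompt)
```

The resulting string can be tokenized directly with the tokenizer and passed to `model.generate`; keeping the training-time wording unchanged avoids a mismatch between what the adapter saw during fine-tuning and what it sees at inference.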