---
language:
- en
license: apache-2.0
library_name: adapter-transformers
tags:
- generated_from_trainer
datasets:
- samsum
metrics:
- rouge
pipeline_tag: summarization
inference: true
base_model: braindao/flan-t5-cnn
model-index:
- name: flan-t5-base
  results:
  - task:
      type: summarization
      name: Summarization
    dataset:
      name: samsum
      type: samsum
      split: validation
    metrics:
    - type: rouge1
      value: 46.819522%
    - type: rouge2
      value: 20.898074%
    - type: rougeL
      value: 37.300937%
    - type: rougeLsum
      value: 37.271341%
---


# flan-t5-base-cnn-samsum-lora

This model is a fine-tuned version of [braindao/flan-t5-cnn](https://huggingface.co/braindao/flan-t5-cnn) on the [samsum](https://huggingface.co/datasets/samsum) dataset.

The base model [braindao/flan-t5-cnn](https://huggingface.co/braindao/flan-t5-cnn) is a fine-tuned version of [google/flan-t5-base](https://huggingface.co/google/flan-t5-base) on the cnn_dailymail 3.0.0 dataset.

## Model API Spaces

Please visit the Hugging Face Space [sooolee/summarize-transcripts-gradio](https://huggingface.co/spaces/sooolee/summarize-transcripts-gradio) for a Gradio API. The API takes a YouTube video ID ('Video_ID') as its input.


## Model description

* This model further fine-tunes [braindao/flan-t5-cnn](https://huggingface.co/braindao/flan-t5-cnn) on the more conversational samsum dataset.
* The Hugging Face [PEFT library](https://github.com/huggingface/peft) (LoRA, r = 16) and bitsandbytes int8 were used to speed up training and reduce the model size.
* Only 1.7M parameters were trained (0.71% of the roughly 250M parameters of the original flan-t5-base).
* The model checkpoint is only about 7 MB.
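The parameter and size figures above can be reproduced with simple arithmetic. The sketch below rests on assumptions not stated in this card: LoRA adapters on the query (`q`) and value (`v`) projections of every attention block (PEFT's default target modules for T5), flan-t5-base's hidden size of 768 with 12 encoder and 12 decoder layers, a base parameter count of 247,577,856, and fp32 storage for the adapter checkpoint.

```python
# Assumptions (not stated in this card): LoRA on the "q" and "v" projections of
# every attention block, flan-t5-base dimensions, and fp32 adapter weights.
D_MODEL = 768          # flan-t5-base hidden size
INNER_DIM = 768        # num_heads (12) * d_kv (64)
RANK = 16              # LoRA r from this card
N_ENC_LAYERS = 12
N_DEC_LAYERS = 12
BASE_PARAMS = 247_577_856  # assumed parameter count of google/flan-t5-base


def lora_params_per_projection(d_in: int, d_out: int, r: int) -> int:
    # LoRA trains A (r x d_in) and B (d_out x r) next to the frozen weight W0.
    return r * d_in + d_out * r


# Encoder layers have one self-attention block; decoder layers have
# self-attention plus cross-attention.
n_attention_blocks = N_ENC_LAYERS + 2 * N_DEC_LAYERS
per_block = 2 * lora_params_per_projection(D_MODEL, INNER_DIM, RANK)  # q and v
trainable = n_attention_blocks * per_block

# Percentage relative to base + adapter parameters, as PEFT reports it.
pct = 100 * trainable / (BASE_PARAMS + trainable)
size_mb = trainable * 4 / 1e6  # 4 bytes per fp32 weight

print(trainable, round(pct, 2), round(size_mb, 2))
```

Under these assumptions the script gives 1,769,472 trainable parameters (about 0.71%) and roughly 7.08 MB of fp32 adapter weights, consistent with the figures listed above.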

## Intended uses & limitations

The model is intended for summarizing transcripts, such as YouTube video transcripts.
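YouTube transcripts are often much longer than the short samsum dialogues this model was fine-tuned on. One common approach, assumed here and not described in this card, is to split a transcript into chunks and summarize each one with the generation code in the "How to use" section. The helper `chunk_words` and the values `max_words=300` and `overlap=20` are hypothetical choices, and word count is only a rough proxy for token count.

```python
from typing import List


def chunk_words(text: str, max_words: int = 300, overlap: int = 0) -> List[str]:
    """Split text into chunks of at most max_words whitespace-separated words.

    Consecutive chunks share `overlap` words so context is not cut abruptly.
    (Hypothetical helper for illustration; not part of this model's code.)
    """
    if max_words <= 0:
        raise ValueError("max_words must be positive")
    if not 0 <= overlap < max_words:
        raise ValueError("overlap must be in [0, max_words)")
    words = text.split()
    step = max_words - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + max_words]))
        if start + max_words >= len(words):
            break  # the last chunk already reaches the end of the text
    return chunks


transcript = "so today we are going to talk about " * 100  # placeholder text
chunks = chunk_words(transcript, max_words=300, overlap=20)
print(len(chunks))
```

Each chunk can then be tokenized and passed to `model.generate(...)`, and the partial summaries concatenated or summarized again.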

## Training and evaluation data
### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.001
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 5
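With `lr_scheduler_type: linear`, the learning rate decays linearly from 0.001 to zero over training. The sketch below mirrors the formula used by `get_linear_schedule_with_warmup` in transformers. It assumes no warmup (none is listed above), and the steps per epoch are hypothetical because the card does not give the batch size.

```python
def linear_lr(step: int, total_steps: int, base_lr: float = 1e-3, warmup_steps: int = 0) -> float:
    """Learning rate at `step` under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Hypothetical: the batch size (and hence steps per epoch) is not given in this card.
steps_per_epoch = 100
total_steps = 5 * steps_per_epoch  # num_epochs: 5

for step in (0, total_steps // 2, total_steps):
    print(step, linear_lr(step, total_steps))
```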

### Training results

- train_loss: 1.47

### How to use
Note that the example below uses `max_new_tokens=60` to control the length of the summary. By default, the FLAN-T5 model has a maximum generation length of 200 and a minimum generation length of 20.

```python
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the PEFT config, which records the base checkpoint the adapter was trained on
peft_model_id = "sooolee/flan-t5-base-cnn-samsum-lora"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model and tokenizer (optionally pass load_in_8bit=True; requires bitsandbytes)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Attach the LoRA adapter weights to the base model
model = PeftModel.from_pretrained(model, peft_model_id, device_map='auto')
model.eval()

# Tokenize the text inputs (add truncation=True for inputs longer than the model's maximum length)
texts = "<e.g. Part of YouTube Transcript>"
inputs = tokenizer(texts, return_tensors="pt", padding=True)

# Generate the summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
    output = model.generate(input_ids=inputs["input_ids"].to(device), max_new_tokens=60, do_sample=True, top_p=0.9)
    summary = tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=True)

print(summary)
```

### Framework versions

- Transformers 4.27.2
- Pytorch 1.13.1+cu116
- Datasets 2.9.0
- Tokenizers 0.13.3

## Other
Please check out the BART-Large-CNN-Samsum model fine-tuned for the same purpose: [sooolee/bart-large-cnn-finetuned-samsum-lora](https://huggingface.co/sooolee/bart-large-cnn-finetuned-samsum-lora)