NotXia committed
Commit 6f408cf
1 Parent(s): 2bde500
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "architectures": [
+     "LongformerSummarizer"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration.ExtSummConfig",
+     "AutoModel": "model.LongformerSummarizer"
+   },
+   "custom_pipelines": {
+     "summarization": {
+       "impl": "pipeline.ExtSummPipeline",
+       "pt": [
+         "AutoModel"
+       ],
+       "tf": []
+     }
+   },
+   "input_size": 4096,
+   "model_type": "longformer-bio-ext-summ",
+   "torch_dtype": "float32",
+   "transformers_version": "4.30.2"
+ }
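The "auto_map" and "custom_pipelines" entries wire the Python modules added in this commit into AutoConfig/AutoModel and the "summarization" pipeline task, which is why the checkpoint has to be loaded with trust_remote_code=True. A minimal loading sketch (the repository id below is an assumption, not stated in this commit):

    from transformers import AutoConfig, AutoModel

    checkpoint = "NotXia/longformer-bio-ext-summ"                              # assumed repository id
    config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)    # resolves to configuration.ExtSummConfig
    model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)      # resolves to model.LongformerSummarizer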
configuration.py ADDED
@@ -0,0 +1,13 @@
+ from transformers import PretrainedConfig
+
+
+ class ExtSummConfig(PretrainedConfig):
+     model_type = "longformer-bio-ext-summ"
+
+     def __init__(
+         self,
+         input_size: int = 4096,
+         **kwargs
+     ):
+         self.input_size = input_size
+         super().__init__(**kwargs)
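ExtSummConfig only adds input_size (the maximum number of tokens per chunk) on top of PretrainedConfig. A quick sketch of standalone use, illustrative and not part of the commit:

    from configuration import ExtSummConfig   # direct import, bypassing the auto_map machinery

    config = ExtSummConfig(input_size=4096)
    print(config.model_type)                   # "longformer-bio-ext-summ"
    print(config.input_size)                   # 4096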
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ from .transformerutils import TransformerInterEncoder
+ from transformers import PreTrainedModel, AutoModel, LongformerConfig
+ from .configuration import ExtSummConfig
+
+
+
+ class LongformerSummarizer(PreTrainedModel):
+     config_class = ExtSummConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.longformer = AutoModel.from_config(LongformerConfig.from_pretrained("allenai/longformer-base-4096"))
+         self.input_size = config.input_size
+         self.interSentenceEncoder = TransformerInterEncoder(self.longformer.config.hidden_size, max_len=4096)
+
+
+     def forward(self, batch):
+         document_ids = batch["ids"].to(self.longformer.device)
+         clss_mask = batch["clss_mask"].to(self.longformer.device)
+         attn_mask = batch["attn_mask"].to(self.longformer.device)
+         global_attn_mask = batch["global_attn_mask"].to(self.longformer.device)
+
+         tokens_out, _ = self.longformer(input_ids=document_ids, attention_mask=attn_mask, global_attention_mask=global_attn_mask, return_dict=False)
+         out = []
+         logits_out = []
+
+         for i in range(len(tokens_out)): # Batch handling
+             clss_out = tokens_out[i][clss_mask[i], :]
+             sentences_scores, logits = self.interSentenceEncoder(clss_out)
+             padding = torch.zeros(self.input_size - sentences_scores.shape[0]).to(sentences_scores.device)
+
+             out.append( torch.cat((sentences_scores, padding)) )
+             logits_out.append( torch.cat((logits, padding)) )
+
+         return torch.stack(out), torch.stack(logits_out)
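For reference, forward() expects the batch dictionary produced by ExtSummPipeline.preprocess (one row per chunk, each padded to input_size) and returns per-sentence scores and logits, zero-padded to input_size. A shape-only sketch with dummy tensors, using the same assumed repository id as above:

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained("NotXia/longformer-bio-ext-summ", trust_remote_code=True)  # assumed repo id
    size = model.config.input_size                                  # 4096

    batch = {
        "ids": torch.zeros((1, size), dtype=torch.long),            # dummy token ids for one chunk
        "clss_mask": torch.zeros((1, size), dtype=torch.bool),      # True at the <s> position of each sentence
        "attn_mask": torch.ones((1, size), dtype=torch.long),
        "global_attn_mask": torch.zeros((1, size), dtype=torch.long),
    }
    batch["clss_mask"][0, 0] = True                                 # pretend the chunk holds a single sentence
    batch["global_attn_mask"][0, 0] = 1                             # global attention on <s>, as in preprocess()

    scores, logits = model(batch)
    print(scores.shape)                                             # torch.Size([1, 4096]): one score per sentence, zero-padded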
pipeline.py ADDED
@@ -0,0 +1,101 @@
+ from transformers import Pipeline
+ import torch
+ from .utilities import padToSize
+ from .summary import select, splitDocument
+
+
+
+ class ExtSummPipeline(Pipeline):
+     """
+     Extractive summarization pipeline
+
+     Inputs
+     ------
+     inputs : dict
+         'sentences' : list[str]
+             Sentences of the document
+
+     strategy : str
+         Strategy used to summarize the document:
+         - 'length': summary with a maximum length (strategy_args is the maximum length).
+         - 'count': summary with the given number of sentences (strategy_args is the number of sentences).
+         - 'ratio': summary proportional to the length of the document (strategy_args is the ratio, in [0, 1]).
+         - 'threshold': summary containing only the sentences whose score is higher than a given value (strategy_args is the minimum score).
+
+     strategy_args : any
+         Parameters of the strategy.
+
+     Outputs
+     -------
+     selected_sents : list[str]
+         List of the selected sentences
+
+     selected_idxs : list[int]
+         List of the indexes of the selected sentences in the original input
+     """
+
+
+     def _sanitize_parameters(self, **kwargs):
+         postprocess_kwargs = {}
+
+         if ("strategy" in kwargs and "strategy_args" not in kwargs) or ("strategy" not in kwargs and "strategy_args" in kwargs):
+             raise ValueError("`strategy` and `strategy_args` must both be set")
+         if "strategy" in kwargs:
+             postprocess_kwargs["strategy"] = kwargs["strategy"]
+         if "strategy_args" in kwargs:
+             postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]
+
+         return {}, {}, postprocess_kwargs
+
+
+     def preprocess(self, inputs):
+         sentences = inputs["sentences"]
+
+         # Tokenization and chunking
+         doc_tokens = self.tokenizer.tokenize( f"{self.tokenizer.sep_token}{self.tokenizer.cls_token}".join(sentences) )
+         doc_tokens = [self.tokenizer.cls_token] + doc_tokens + [self.tokenizer.sep_token]
+         doc_chunks = splitDocument(doc_tokens, self.tokenizer.cls_token, self.tokenizer.sep_token, self.model.config.input_size)
+
+         # Batch preparation
+         batch = {
+             "ids": [],
+             "clss_mask": [],
+             "attn_mask": [],
+             "global_attn_mask": [],
+         }
+         for chunk_tokens in doc_chunks:
+             doc_ids = self.tokenizer.convert_tokens_to_ids(chunk_tokens)
+             clss_mask = [True if token == self.tokenizer.cls_token_id else False for token in doc_ids]
+             attn_mask = [1 for _ in range(len(doc_ids))]
+             global_attn_mask = [1 if token == self.tokenizer.cls_token_id else 0 for token in doc_ids]
+
+             batch["ids"].append( padToSize(doc_ids, self.model.config.input_size, self.tokenizer.pad_token_id) )
+             batch["clss_mask"].append( padToSize(clss_mask, self.model.config.input_size, False) )
+             batch["attn_mask"].append( padToSize(attn_mask, self.model.config.input_size, 0) )
+             batch["global_attn_mask"].append( padToSize(global_attn_mask, self.model.config.input_size, 0) )
+
+         batch["ids"] = torch.as_tensor(batch["ids"])
+         batch["clss_mask"] = torch.as_tensor(batch["clss_mask"])
+         batch["attn_mask"] = torch.as_tensor(batch["attn_mask"])
+         batch["global_attn_mask"] = torch.as_tensor(batch["global_attn_mask"])
+         return { "inputs": batch, "sentences": sentences }
+
+
+     def _forward(self, args):
+         batch = args["inputs"]
+         sentences = args["sentences"]
+         out_predictions = torch.as_tensor([]).to(self.device)
+
+         self.model.eval()
+         with torch.no_grad():
+             batch_preds, _ = self.model(batch)
+             for i, clss_mask in enumerate(batch["clss_mask"]):
+                 out_predictions = torch.cat((out_predictions, batch_preds[i][:torch.sum(clss_mask == True)]))
+
+         return { "predictions": out_predictions, "sentences": sentences }
+
+
+     def postprocess(self, args, strategy: str="count", strategy_args=3):
+         predictions = args["predictions"]
+         sentences = args["sentences"]
+         return select(sentences, predictions, strategy, strategy_args)
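Putting the pieces together, the pipeline takes pre-split sentences plus a selection strategy and returns the chosen sentences with their indexes. A hedged end-to-end sketch (same assumed repository id as above; the sentences are toy input, not real data):

    from transformers import AutoTokenizer, pipeline

    checkpoint = "NotXia/longformer-bio-ext-summ"                   # assumed repository id
    summarizer = pipeline("summarization", model=checkpoint, tokenizer=AutoTokenizer.from_pretrained(checkpoint), trust_remote_code=True)

    sentences = [
        "The study enrolled 120 patients with type 2 diabetes.",
        "Patients received either metformin or a placebo for 24 weeks.",
        "HbA1c decreased significantly in the treatment group.",
        "No serious adverse events were reported.",
    ]
    # 'count' strategy: keep the 2 highest-scoring sentences
    selected_sents, selected_idxs = summarizer({"sentences": sentences}, strategy="count", strategy_args=2)

    # Other strategies handled by postprocess()/summary.select():
    #   strategy="length",    strategy_args=300  -> summary of at most 300 characters
    #   strategy="ratio",     strategy_args=0.3  -> summary about 30% of the document length
    #   strategy="threshold", strategy_args=0.7  -> only sentences scoring at least 0.7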
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:19bfa03d5f00f4e0fa6b5af68e53ad26cfc4ee6ab11cee26560272f8df2500b4
+ size 651442429
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "cls_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
summary.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+
+
+ def _selectStrategyLength(sentences, predictions, max_length):
+     selected_sents = []
+     sents_priority = torch.argsort(predictions, descending=True)
+     summary_len = 0
+     i = 0
+
+     while (summary_len < max_length) and (i < len(sents_priority)):
+         if summary_len + len(sentences[sents_priority[i]]) < max_length:
+             selected_sents.append(sents_priority[i].item())
+             summary_len += len(sentences[sents_priority[i]])
+         i += 1
+
+     return sorted(selected_sents)
+
+
+ def _selectStrategyCount(sentences, predictions, num_sents):
+     selected_idxs = sorted(torch.topk(predictions, min(len(predictions), num_sents)).indices)
+     return [tensor.item() for tensor in selected_idxs]
+
+
+ def _selectStrategyRatio(sentences, predictions, ratio):
+     doc_length = sum([ len(sent) for sent in sentences ])
+     return _selectStrategyLength(sentences, predictions, doc_length*ratio)
+
+
+ def _selectStrategyThreshold(sentences, predictions, threshold):
+     return [i for i, score in enumerate(predictions) if score >= threshold]
+
+
+ def select(sentences, predictions, strategy, strategy_args):
+     selected_sents = []
+
+     if strategy == "length":
+         selected_sents = _selectStrategyLength(sentences, predictions, strategy_args)
+     elif strategy == "count":
+         selected_sents = _selectStrategyCount(sentences, predictions, strategy_args)
+     elif strategy == "ratio":
+         selected_sents = _selectStrategyRatio(sentences, predictions, strategy_args)
+     elif strategy == "threshold":
+         selected_sents = _selectStrategyThreshold(sentences, predictions, strategy_args)
+     else:
+         raise NotImplementedError(f"Unknown strategy {strategy}")
+
+     return [sentences[i] for i in selected_sents], selected_sents
+
+
+
+ """
+ Splits a document into chunks of at most a given size.
+
+ Parameters
+ ----------
+ doc_tokens : str[]
+     List of the tokens of the document.
+
+ bos_token : str
+     Beginning-of-sentence token.
+
+ eos_token : str
+     End-of-sentence token.
+
+ max_size : int
+     Maximum size of a chunk.
+ Returns
+ -------
+ chunks : str[][]
+     Split document.
+ """
+ def splitDocument(doc_tokens, bos_token, eos_token, max_size):
+     def _findNextBOSFrom(start_idx):
+         for i in range(start_idx, len(doc_tokens)):
+             if doc_tokens[i] == bos_token:
+                 return i
+         return -1
+
+     def _findPreviousEOSFrom(start_idx):
+         for i in range(start_idx, -1, -1):
+             if doc_tokens[i] == eos_token:
+                 return i
+         return -1
+
+     chunks = []
+
+     while len(doc_tokens) > max_size:
+         # Splits at the eos token
+         eos_idx = _findPreviousEOSFrom(max_size - 1)
+
+         if eos_idx == -1:
+             # The sentence is too long.
+             # Find the next bos in front of the current sentence (if it exists) and truncate the current sentence.
+             next_bos_idx = _findNextBOSFrom(max_size)
+             if next_bos_idx != -1:
+                 doc_tokens = doc_tokens[:max_size-1] + [eos_token] + doc_tokens[next_bos_idx:]
+             else:
+                 doc_tokens = doc_tokens[:max_size-1] + [eos_token]
+             eos_idx = max_size - 1
+
+         chunks.append(doc_tokens[:eos_idx+1])
+         doc_tokens = doc_tokens[eos_idx+1:]
+
+     if len(doc_tokens) > 0: chunks.append(doc_tokens) # Remaining part of the document
+
+     return chunks
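A small illustrative sketch of these helpers (toy tokens and hand-picked scores, not real tokenizer or model output; summary.py has no relative imports, so it can be imported directly from the repository directory):

    import torch
    from summary import select, splitDocument

    # select(): pick the 2 highest-scoring sentences, returned in document order
    sentences = ["Short sentence.", "A much longer sentence about the results.", "Another one."]
    scores = torch.tensor([0.2, 0.9, 0.6])
    select(sentences, scores, "count", 2)   # -> (["A much longer sentence about the results.", "Another one."], [1, 2])

    # splitDocument(): chunks always end at a sentence boundary and hold at most max_size tokens
    doc_tokens = ["<s>", "a", "b", "</s>", "<s>", "c", "d", "e", "</s>", "<s>", "f", "</s>"]
    splitDocument(doc_tokens, bos_token="<s>", eos_token="</s>", max_size=7)
    # -> [['<s>', 'a', 'b', '</s>'], ['<s>', 'c', 'd', 'e', '</s>'], ['<s>', 'f', '</s>']]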
tokenizer_config.json ADDED
@@ -0,0 +1,63 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": {
+     "__type": "AddedToken",
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": {
+     "__type": "AddedToken",
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "__type": "AddedToken",
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "errors": "replace",
+   "mask_token": {
+     "__type": "AddedToken",
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "model_max_length": 4096,
+   "pad_token": {
+     "__type": "AddedToken",
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "__type": "AddedToken",
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "tokenizer_class": "LongformerTokenizer",
+   "unk_token": {
+     "__type": "AddedToken",
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
transformerutils.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+
+
+ # Source: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super().__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+         pe = torch.zeros(max_len, d_model)
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         self.register_buffer("pe", pe)
+
+     def forward(self, x):
+         x = x + self.pe[:x.size(0)]
+         return self.dropout(x)
+
+
+ """
+ Same scheduler as in "Attention Is All You Need"
+ """
+ class NoamScheduler():
+     def __init__(self, optimizer, warmup, model_size):
+         self.epoch = 0
+         self.optimizer = optimizer
+         self.warmup = warmup
+         self.model_size = model_size
+
+     def step(self):
+         self.epoch += 1
+         new_lr = self.model_size**(-0.5) * min(self.epoch**(-0.5), self.epoch * self.warmup**(-1.5))
+
+         for param in self.optimizer.param_groups:
+             param["lr"] = new_lr
+
+
+ """
+ Encoder stack that attends over sentence-level features.
+ """
+ class TransformerInterEncoder(nn.Module):
+     def __init__(self, d_model, d_ff=2048, nheads=8, num_encoders=2, dropout=0.1, max_len=512):
+         super().__init__()
+         self.positional_enc = PositionalEncoding(d_model, dropout, max_len)
+         self.encoders = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=d_ff),
+             num_layers=num_encoders
+         )
+         self.layer_norm = nn.LayerNorm(d_model)
+         self.linear = nn.Linear(d_model, 1)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         x = self.positional_enc(x)
+         x = self.encoders(x)
+         x = self.layer_norm(x)
+         logit = self.linear(x)
+         sentences_scores = self.sigmoid(logit)
+
+         return sentences_scores.squeeze(-1), logit.squeeze(-1)
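NoamScheduler sets lr = model_size^-0.5 * min(step^-0.5, step * warmup^-1.5) at each step, the schedule from "Attention Is All You Need". A hedged sketch of how these pieces could be exercised together; training code is not part of this commit, and the input is a dummy 2-D tensor of "sentence" embeddings, matching how model.py calls the encoder:

    import torch
    from transformerutils import TransformerInterEncoder, NoamScheduler   # local imports for illustration

    encoder = TransformerInterEncoder(d_model=768, max_len=4096)
    optimizer = torch.optim.Adam(encoder.parameters(), lr=0)              # lr is overwritten by the scheduler
    scheduler = NoamScheduler(optimizer, warmup=4000, model_size=768)

    for _ in range(3):
        scores, logits = encoder(torch.randn(10, 768))                    # 10 sentence embeddings
        scores.sum().backward()                                           # dummy loss, just to produce gradients
        scheduler.step()                                                  # set the learning rate for this update
        optimizer.step()
        optimizer.zero_grad()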
utilities.py ADDED
@@ -0,0 +1,5 @@
+ """
+ Pads a list to a given size
+ """
+ def padToSize(to_pad_list, pad_size, filler):
+     return to_pad_list + [filler]*(pad_size-len(to_pad_list))
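For clarity, a one-line example of padToSize with illustrative values:

    from utilities import padToSize   # local import for illustration

    padToSize([1, 2, 3], 6, 0)         # -> [1, 2, 3, 0, 0, 0]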
vocab.json ADDED
The diff for this file is too large to render. See raw diff