nicoladecao committed on
Commit
f168350
1 Parent(s): b47c541

Initial commit

Browse files
.gitattributes CHANGED
@@ -25,3 +25,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
29
+ tf_model.h5 filter=lfs diff=lfs merge=lfs -text
30
+ kilt_titles_trie_dict.pkl filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+
3
+ language:
4
+ - en
5
+
6
+ tags:
7
+ - retrieval
8
+ - entity-retrieval
9
+ - named-entity-disambiguation
10
+ - entity-disambiguation
11
+ - named-entity-linking
12
+ - entity-linking
13
+ - text2text-generation
14
+ ---
15
+
16
+
17
+ # GENRE
18
+
19
+
20
+ The GENRE (Generative ENtity REtrieval) system as presented in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) implemented in PyTorch.
21
+
22
+ In a nutshell, GENRE uses a sequence-to-sequence approach to entity retrieval (e.g., linking), based on a fine-tuned [BART](https://arxiv.org/abs/1910.13461) architecture. GENRE performs retrieval by generating the unique entity name conditioned on the input text, using constrained beam search to only generate valid identifiers. The model was first released in the [facebookresearch/GENRE](https://github.com/facebookresearch/GENRE) repository using `fairseq` (the `transformers` models are obtained with a conversion script similar to [this](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py)).
23
+
24
+ This model was trained on the full training set of [BLINK](https://arxiv.org/abs/1911.03814) (i.e., 9M datapoints for entity-disambiguation grounded on Wikipedia).
25
+
26
+ ## BibTeX entry and citation info
27
+
28
+ **Please consider citing our work if you use code from this repository.**
29
+
30
+ ```bibtex
31
+ @inproceedings{decao2020autoregressive,
32
+ title={Autoregressive Entity Retrieval},
33
+ author={Nicola {De Cao} and Gautier Izacard and Sebastian Riedel and Fabio Petroni},
34
+ booktitle={International Conference on Learning Representations},
35
+ url={https://openreview.net/forum?id=5k8F6UU39V},
36
+ year={2021}
37
+ }
38
+ ```
39
+
40
+ ## Usage
41
+
42
+ Here is an example of generation for Wikipedia page disambiguation:
43
+
44
+ ```python
45
+ import pickle
46
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
47
+
48
+ # OPTIONAL: load the prefix tree (trie), you need to additionally download
49
+ # https://huggingface.co/facebook/genre-kilt/blob/main/trie.py and
50
+ # https://huggingface.co/facebook/genre-kilt/blob/main/kilt_titles_trie_dict.pkl
51
+ # from trie import Trie
52
+ # with open("kilt_titles_trie_dict.pkl", "rb") as f:
53
+ # trie = Trie.load_from_dict(pickle.load(f))
54
+
55
+ tokenizer = AutoTokenizer.from_pretrained("facebook/genre-linking-blink")
56
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/genre-linking-blink").eval()
57
+
58
+ sentences = ["Einstein was a [START_ENT] German [END_ENT] physicist."]
59
+
60
+ outputs = model.generate(
61
+ **tokenizer(sentences, return_tensors="pt"),
62
+ num_beams=5,
63
+ num_return_sequences=5,
64
+ # OPTIONAL: use constrained beam search
65
+ # prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
66
+ )
67
+
68
+ tokenizer.batch_decode(outputs, skip_special_tokens=True)
69
+ ```
70
+ which outputs the following top-5 predictions (using constrained beam search):
71
+ ```
72
+ ['Germans',
73
+ 'Germany',
74
+ 'German Empire',
75
+ 'Weimar Republic',
76
+ 'Greeks']
77
+ ```
config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/genre-kilt",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "add_bias_logits": false,
6
+ "add_final_layer_norm": false,
7
+ "architectures": [
8
+ "BartForConditionalGeneration"
9
+ ],
10
+ "attention_dropout": 0.0,
11
+ "bos_token_id": 0,
12
+ "classif_dropout": 0.0,
13
+ "classifier_dropout": 0.0,
14
+ "d_model": 1024,
15
+ "decoder_attention_heads": 16,
16
+ "decoder_ffn_dim": 4096,
17
+ "decoder_layerdrop": 0.0,
18
+ "decoder_layers": 12,
19
+ "decoder_start_token_id": 2,
20
+ "dropout": 0.1,
21
+ "early_stopping": true,
22
+ "encoder_attention_heads": 16,
23
+ "encoder_ffn_dim": 4096,
24
+ "encoder_layerdrop": 0.0,
25
+ "encoder_layers": 12,
26
+ "eos_token_id": 2,
27
+ "eos_token_ids": [
28
+ 2
29
+ ],
30
+ "forced_eos_token_id": 2,
31
+ "gradient_checkpointing": false,
32
+ "init_std": 0.02,
33
+ "is_encoder_decoder": true,
34
+ "max_length": 1024,
35
+ "max_position_embeddings": 1024,
36
+ "min_length": 0,
37
+ "model_type": "bart",
38
+ "normalize_before": false,
39
+ "normalize_embedding": false,
40
+ "num_beams": 6,
41
+ "num_hidden_layers": 12,
42
+ "output_past": true,
43
+ "pad_token_id": 1,
44
+ "replacing_rate": 0,
45
+ "scale_embedding": false,
46
+ "static_position_embeddings": false,
47
+ "student_decoder_layers": null,
48
+ "student_encoder_layers": null,
49
+ "task_specific_params": {},
50
+ "transformers_version": "4.19.2",
51
+ "use_cache": true,
52
+ "vocab_size": 50264
53
+ }
kilt_titles_trie_dict.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:951db72cc702fcf6639419efcf917cb7f3c67cc6202ebe3ae3ca399c30614da2
3
+ size 215214973
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d105d545961fe8eec7183bab63dd5dea9acf4cd69783827a4151bda989635d1e
3
+ size 1625526529
tf_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a854657f17fa38492440f5111c7f78e1e1bdd75e58eff59b5260894ba183e58b
3
+ size 1625921384
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"model_max_length": 1024}
trie.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree at
6
+ # https://github.com/facebookresearch/GENRE .
7
+
8
+
9
+ from typing import Dict, List
10
+
11
+
12
class Trie(object):
    """Prefix tree over sequences of token ids, stored as nested dicts.

    Each node is a ``dict`` mapping a token id to the child node; a node with
    no children (an empty dict) marks the end of a stored sequence. ``get``
    returns the token ids that may follow a given prefix.

    All traversals are iterative, so the length of a stored sequence is not
    bounded by the interpreter's recursion limit, and no per-token list copies
    are made while walking the tree.
    """

    def __init__(self, sequences: List[List[int]] = None):
        """Build a trie, optionally pre-populated with ``sequences``.

        Args:
            sequences: Token-id sequences to insert. ``None`` or an empty
                collection creates an empty trie.
        """
        self.trie_dict = {}
        # Number of insertions performed, not the number of distinct
        # sequences: inserting the same sequence twice counts twice.
        self.len = 0
        if sequences:
            for sequence in sequences:
                Trie._add_to_trie(sequence, self.trie_dict)
                self.len += 1

        # Optional secondary trie consulted by ``get`` (see ``append``).
        self.append_trie = None
        self.bos_token_id = None

    def append(self, trie, bos_token_id):
        """Attach a secondary trie that ``get`` falls back to.

        Args:
            trie: The trie whose continuations are offered in place of
                ``bos_token_id`` and for prefixes that leave this trie.
            bos_token_id: Token id that, when it appears among the allowed
                continuations, is replaced by the root tokens of ``trie``.
        """
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        """Insert one token-id sequence and bump the insertion counter."""
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]):
        """Return the token ids allowed to follow ``prefix_sequence``.

        Returns an empty list when the prefix is not in the trie and no
        secondary trie is attached.
        """
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict):
        """Wrap an existing nested-dict trie without copying it.

        The length is recomputed by counting the sequences yielded when
        iterating the trie.
        """
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        """Insert ``sequence`` below ``trie_dict``, creating nodes as needed."""
        node = trie_dict
        for token in sequence:
            node = node.setdefault(token, {})

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie=None,
        bos_token_id: int = None,
    ):
        """Walk ``prefix_sequence`` from ``trie_dict`` and list the next tokens.

        Args:
            prefix_sequence: Token ids consumed so far.
            trie_dict: Root node to start walking from.
            append_trie: Optional secondary trie. It is tested for truthiness,
                so a trie whose length is 0 is treated as absent.
            bos_token_id: Token id replaced by the secondary trie's root
                tokens when it is among the allowed continuations.

        Returns:
            The list of allowed next token ids (possibly empty).
        """
        node = trie_dict
        for index, token in enumerate(prefix_sequence):
            if token not in node:
                # The prefix leaves this trie: hand the unmatched remainder
                # to the secondary trie, if one is attached.
                if append_trie:
                    return append_trie.get(prefix_sequence[index:])
                return []
            node = node[token]

        output = list(node)
        if append_trie and bos_token_id in output:
            output.remove(bos_token_id)
            output += list(append_trie.trie_dict)
        return output

    def __iter__(self):
        """Iterate lazily over the stored sequences in depth-first order.

        Children are visited in dict insertion order and every yielded
        sequence is a fresh list. An empty trie yields one empty sequence.
        """

        def _traverse(root):
            if not root:
                yield []
                return

            path = []
            # One item iterator per level of the current root-to-node path.
            stack = [iter(root.items())]
            while stack:
                try:
                    token, child = next(stack[-1])
                except StopIteration:
                    stack.pop()
                    if path:
                        path.pop()
                    continue

                path.append(token)
                if child:
                    stack.append(iter(child.items()))
                else:
                    yield list(path)
                    path.pop()

        return _traverse(self.trie_dict)

    def __len__(self):
        """Return the insertion counter (see ``__init__``)."""
        return self.len

    def __getitem__(self, value):
        """Alias for ``get``: ``trie[prefix]`` returns the allowed next tokens."""
        return self.get(value)
vocab.json ADDED
The diff for this file is too large to render. See raw diff