nicoladecao
commited on
Commit
•
f168350
1
Parent(s):
b47c541
Initial commit
Browse files- .gitattributes +3 -0
- README.md +77 -0
- config.json +53 -0
- kilt_titles_trie_dict.pkl +3 -0
- merges.txt +0 -0
- pytorch_model.bin +3 -0
- tf_model.h5 +3 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- trie.py +93 -0
- vocab.json +0 -0
.gitattributes
CHANGED
@@ -25,3 +25,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
29 |
+
tf_model.h5 filter=lfs diff=lfs merge=lfs -text
|
30 |
+
kilt_titles_trie_dict.pkl filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
|
6 |
+
tags:
|
7 |
+
- retrieval
|
8 |
+
- entity-retrieval
|
9 |
+
- named-entity-disambiguation
|
10 |
+
- entity-disambiguation
|
11 |
+
- named-entity-linking
|
12 |
+
- entity-linking
|
13 |
+
- text2text-generation
|
14 |
+
---
|
15 |
+
|
16 |
+
|
17 |
+
# GENRE
|
18 |
+
|
19 |
+
|
20 |
+
The GENRE (Generative ENtity REtrieval) system as presented in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) implemented in pytorch.
|
21 |
+
|
22 |
+
In a nutshell, GENRE uses a sequence-to-sequence approach to entity retrieval (e.g., linking), based on fine-tuned [BART](https://arxiv.org/abs/1910.13461) architecture. GENRE performs retrieval generating the unique entity name conditioned on the input text using constrained beam search to only generate valid identifiers. The model was first released in the [facebookresearch/GENRE](https://github.com/facebookresearch/GENRE) repository using `fairseq` (the `transformers` models are obtained with a conversion script similar to [this](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py).
|
23 |
+
|
24 |
+
This model was trained on the full training set of [BLINK](https://arxiv.org/abs/1911.03814) (i.e., 9M datapoints for entity-disambiguation grounded on Wikipedia).
|
25 |
+
|
26 |
+
## BibTeX entry and citation info
|
27 |
+
|
28 |
+
**Please consider citing our works if you use code from this repository.**
|
29 |
+
|
30 |
+
```bibtex
|
31 |
+
@inproceedings{decao2020autoregressive,
|
32 |
+
title={Autoregressive Entity Retrieval},
|
33 |
+
author={Nicola {De Cao} and Gautier Izacard and Sebastian Riedel and Fabio Petroni},
|
34 |
+
booktitle={International Conference on Learning Representations},
|
35 |
+
url={https://openreview.net/forum?id=5k8F6UU39V},
|
36 |
+
year={2021}
|
37 |
+
}
|
38 |
+
```
|
39 |
+
|
40 |
+
## Usage
|
41 |
+
|
42 |
+
Here is an example of generation for Wikipedia page disambiguation:
|
43 |
+
|
44 |
+
```python
|
45 |
+
import pickle
|
46 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
47 |
+
|
48 |
+
# OPTIONAL: load the prefix tree (trie), you need to additionally download
|
49 |
+
# https://huggingface.co/facebook/genre-kilt/blob/main/trie.py and
|
50 |
+
# https://huggingface.co/facebook/genre-kilt/blob/main/kilt_titles_trie_dict.pkl
|
51 |
+
# from trie import Trie
|
52 |
+
# with open("kilt_titles_trie_dict.pkl", "rb") as f:
|
53 |
+
# trie = Trie.load_from_dict(pickle.load(f))
|
54 |
+
|
55 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/genre-linking-blink")
|
56 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/genre-linking-blink").eval()
|
57 |
+
|
58 |
+
sentences = ["Einstein was a [START_ENT] German [END_ENT] physicist."]
|
59 |
+
|
60 |
+
outputs = model.generate(
|
61 |
+
**tokenizer(sentences, return_tensors="pt"),
|
62 |
+
num_beams=5,
|
63 |
+
num_return_sequences=5,
|
64 |
+
# OPTIONAL: use constrained beam search
|
65 |
+
# prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
|
66 |
+
)
|
67 |
+
|
68 |
+
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
69 |
+
```
|
70 |
+
which outputs the following top-5 predictions (using constrained beam search)
|
71 |
+
```
|
72 |
+
['Germans',
|
73 |
+
'Germany',
|
74 |
+
'German Empire',
|
75 |
+
'Weimar Republic',
|
76 |
+
'Greeks']
|
77 |
+
```
|
config.json
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "facebook/genre-kilt",
|
3 |
+
"activation_dropout": 0.0,
|
4 |
+
"activation_function": "gelu",
|
5 |
+
"add_bias_logits": false,
|
6 |
+
"add_final_layer_norm": false,
|
7 |
+
"architectures": [
|
8 |
+
"BartForConditionalGeneration"
|
9 |
+
],
|
10 |
+
"attention_dropout": 0.0,
|
11 |
+
"bos_token_id": 0,
|
12 |
+
"classif_dropout": 0.0,
|
13 |
+
"classifier_dropout": 0.0,
|
14 |
+
"d_model": 1024,
|
15 |
+
"decoder_attention_heads": 16,
|
16 |
+
"decoder_ffn_dim": 4096,
|
17 |
+
"decoder_layerdrop": 0.0,
|
18 |
+
"decoder_layers": 12,
|
19 |
+
"decoder_start_token_id": 2,
|
20 |
+
"dropout": 0.1,
|
21 |
+
"early_stopping": true,
|
22 |
+
"encoder_attention_heads": 16,
|
23 |
+
"encoder_ffn_dim": 4096,
|
24 |
+
"encoder_layerdrop": 0.0,
|
25 |
+
"encoder_layers": 12,
|
26 |
+
"eos_token_id": 2,
|
27 |
+
"eos_token_ids": [
|
28 |
+
2
|
29 |
+
],
|
30 |
+
"forced_eos_token_id": 2,
|
31 |
+
"gradient_checkpointing": false,
|
32 |
+
"init_std": 0.02,
|
33 |
+
"is_encoder_decoder": true,
|
34 |
+
"max_length": 1024,
|
35 |
+
"max_position_embeddings": 1024,
|
36 |
+
"min_length": 0,
|
37 |
+
"model_type": "bart",
|
38 |
+
"normalize_before": false,
|
39 |
+
"normalize_embedding": false,
|
40 |
+
"num_beams": 6,
|
41 |
+
"num_hidden_layers": 12,
|
42 |
+
"output_past": true,
|
43 |
+
"pad_token_id": 1,
|
44 |
+
"replacing_rate": 0,
|
45 |
+
"scale_embedding": false,
|
46 |
+
"static_position_embeddings": false,
|
47 |
+
"student_decoder_layers": null,
|
48 |
+
"student_encoder_layers": null,
|
49 |
+
"task_specific_params": {},
|
50 |
+
"transformers_version": "4.19.2",
|
51 |
+
"use_cache": true,
|
52 |
+
"vocab_size": 50264
|
53 |
+
}
|
kilt_titles_trie_dict.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:951db72cc702fcf6639419efcf917cb7f3c67cc6202ebe3ae3ca399c30614da2
|
3 |
+
size 215214973
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d105d545961fe8eec7183bab63dd5dea9acf4cd69783827a4151bda989635d1e
|
3 |
+
size 1625526529
|
tf_model.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a854657f17fa38492440f5111c7f78e1e1bdd75e58eff59b5260894ba183e58b
|
3 |
+
size 1625921384
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"model_max_length": 1024}
|
trie.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree at
|
6 |
+
# https://github.com/facebookresearch/GENRE .
|
7 |
+
|
8 |
+
|
9 |
+
from typing import Dict, List
|
10 |
+
|
11 |
+
|
12 |
+
class Trie(object):
|
13 |
+
def __init__(self, sequences: List[List[int]] = []):
|
14 |
+
self.trie_dict = {}
|
15 |
+
self.len = 0
|
16 |
+
if sequences:
|
17 |
+
for sequence in sequences:
|
18 |
+
Trie._add_to_trie(sequence, self.trie_dict)
|
19 |
+
self.len += 1
|
20 |
+
|
21 |
+
self.append_trie = None
|
22 |
+
self.bos_token_id = None
|
23 |
+
|
24 |
+
def append(self, trie, bos_token_id):
|
25 |
+
self.append_trie = trie
|
26 |
+
self.bos_token_id = bos_token_id
|
27 |
+
|
28 |
+
def add(self, sequence: List[int]):
|
29 |
+
Trie._add_to_trie(sequence, self.trie_dict)
|
30 |
+
self.len += 1
|
31 |
+
|
32 |
+
def get(self, prefix_sequence: List[int]):
|
33 |
+
return Trie._get_from_trie(
|
34 |
+
prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
|
35 |
+
)
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def load_from_dict(trie_dict):
|
39 |
+
trie = Trie()
|
40 |
+
trie.trie_dict = trie_dict
|
41 |
+
trie.len = sum(1 for _ in trie)
|
42 |
+
return trie
|
43 |
+
|
44 |
+
@staticmethod
|
45 |
+
def _add_to_trie(sequence: List[int], trie_dict: Dict):
|
46 |
+
if sequence:
|
47 |
+
if sequence[0] not in trie_dict:
|
48 |
+
trie_dict[sequence[0]] = {}
|
49 |
+
Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]])
|
50 |
+
|
51 |
+
@staticmethod
|
52 |
+
def _get_from_trie(
|
53 |
+
prefix_sequence: List[int],
|
54 |
+
trie_dict: Dict,
|
55 |
+
append_trie=None,
|
56 |
+
bos_token_id: int = None,
|
57 |
+
):
|
58 |
+
if len(prefix_sequence) == 0:
|
59 |
+
output = list(trie_dict.keys())
|
60 |
+
if append_trie and bos_token_id in output:
|
61 |
+
output.remove(bos_token_id)
|
62 |
+
output += list(append_trie.trie_dict.keys())
|
63 |
+
return output
|
64 |
+
elif prefix_sequence[0] in trie_dict:
|
65 |
+
return Trie._get_from_trie(
|
66 |
+
prefix_sequence[1:],
|
67 |
+
trie_dict[prefix_sequence[0]],
|
68 |
+
append_trie,
|
69 |
+
bos_token_id,
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
if append_trie:
|
73 |
+
return append_trie.get(prefix_sequence)
|
74 |
+
else:
|
75 |
+
return []
|
76 |
+
|
77 |
+
def __iter__(self):
|
78 |
+
def _traverse(prefix_sequence, trie_dict):
|
79 |
+
if trie_dict:
|
80 |
+
for next_token in trie_dict:
|
81 |
+
yield from _traverse(
|
82 |
+
prefix_sequence + [next_token], trie_dict[next_token]
|
83 |
+
)
|
84 |
+
else:
|
85 |
+
yield prefix_sequence
|
86 |
+
|
87 |
+
return _traverse([], self.trie_dict)
|
88 |
+
|
89 |
+
def __len__(self):
|
90 |
+
return self.len
|
91 |
+
|
92 |
+
def __getitem__(self, value):
|
93 |
+
return self.get(value)
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|