update
- app.py +1 -1
- lrt/__init__.py +2 -0
- lrt/academic_query/__init__.py +1 -0
- lrt/academic_query/academic.py +35 -0
- lrt/clustering/__init__.py +2 -0
- lrt/clustering/clustering_pipeline.py +99 -0
- lrt/clustering/clusters.py +62 -0
- lrt/clustering/config.py +11 -0
- lrt/clustering/models/__init__.py +1 -0
- lrt/clustering/models/adapter.py +25 -0
- lrt/clustering/models/keyBartPlus.py +411 -0
- lrt/lrt.py +101 -0
- lrt/utils/__init__.py +3 -0
- lrt/utils/article.py +394 -0
- lrt/utils/functions.py +125 -0
- lrt/utils/union_find.py +55 -0
- requirements.txt +1 -1
app.py
CHANGED
@@ -16,7 +16,7 @@ with st.form("my_form",clear_on_submit=False):
     query_input = st.text_input(
         'Enter your keyphrases',
         placeholder='''e.g. "Machine learning"''',
-        label_visibility='collapsed',
+        # label_visibility='collapsed',
         value=''
     )
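The change comments out the label_visibility keyword, which suggests the deployed Streamlit build rejects it. A hedged compatibility sketch, not part of the commit (the 1.12.0 threshold is an assumption, not something this commit states):

import streamlit as st
from packaging import version

kwargs = {}
# label_visibility is assumed to exist only from Streamlit 1.12.0 on (hypothetical threshold)
if version.parse(st.__version__) >= version.parse("1.12.0"):
    kwargs["label_visibility"] = "collapsed"

query_input = st.text_input(
    'Enter your keyphrases',
    placeholder='''e.g. "Machine learning"''',
    value='',
    **kwargs,
)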
lrt/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .lrt import LiteratureResearchTool
+from .clustering import Configuration
lrt/academic_query/__init__.py
ADDED
@@ -0,0 +1 @@
+from .academic import AcademicQuery
lrt/academic_query/academic.py
ADDED
@@ -0,0 +1,35 @@
+from requests_toolkit import ArxivQuery, IEEEQuery, PaperWithCodeQuery
+from typing import List
+
+class AcademicQuery:
+    @classmethod
+    def arxiv(cls,
+              query: str,
+              max_results: int = 50
+              ) -> List[dict]:
+        ret = ArxivQuery.query(query, '', 0, max_results)
+        if not isinstance(ret, list):
+            return [ret]
+        return ret
+
+    @classmethod
+    def ieee(cls,
+             query: str,
+             start_year: int,
+             end_year: int,
+             num_papers: int = 200
+             ) -> List[dict]:
+        IEEEQuery.__setup_api_key__('vpd9yy325enruv27zj2d353e')
+        ret = IEEEQuery.query(query, start_year, end_year, num_papers)
+        if not isinstance(ret, list):
+            return [ret]
+        return ret
+
+    @classmethod
+    def paper_with_code(cls,
+                        query: str,
+                        items_per_page: int = 50) -> List[dict]:
+        ret = PaperWithCodeQuery.query(query, 1, items_per_page)
+        if not isinstance(ret, list):
+            return [ret]
+        return ret
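Each classmethod normalizes the wrapper's return value into a List[dict] in the platform's native schema. A minimal usage sketch, assuming requests_toolkit is installed and the endpoints are reachable:

from lrt.academic_query import AcademicQuery

arxiv_hits = AcademicQuery.arxiv('machine learning', max_results=10)            # List[dict]
ieee_hits = AcademicQuery.ieee('machine learning', 2018, 2022, num_papers=20)   # List[dict]
pwc_hits = AcademicQuery.paper_with_code('machine learning', items_per_page=10) # List[dict]
print(len(arxiv_hits), len(ieee_hits), len(pwc_hits))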
lrt/clustering/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .clustering_pipeline import ClusterPipeline, ClusterList
+from .config import Configuration, BaselineConfig
lrt/clustering/clustering_pipeline.py
ADDED
@@ -0,0 +1,99 @@
+from typing import List
+from .config import BaselineConfig, Configuration
+from ..utils import __create_model__
+# import numpy as np
+from sklearn.cluster import KMeans
+# from yellowbrick.cluster import KElbowVisualizer
+from .clusters import ClusterList
+
+class ClusterPipeline:
+    def __init__(self, config: Configuration = None):
+        if config is None:
+            self.__setup__(BaselineConfig())
+        else:
+            self.__setup__(config)
+
+    def __setup__(self, config: Configuration):
+        self.PTM = __create_model__(config.plm)
+        self.dimension_reduction = __create_model__(config.dimension_reduction)  # TODO
+        self.clustering = __create_model__(config.clustering)
+        self.keywords_extraction = __create_model__(config.keywords_extraction)
+
+    def __1_generate_word_embeddings__(self, documents: List[str]):
+        '''
+        :param documents: a list of N strings
+        :return: np.ndarray: Nx384 (sentence-transformers)
+        '''
+        print(f'>>> start generating word embeddings...')
+        embeddings = self.PTM.encode(documents)
+        print(f'>>> successfully generated word embeddings...')
+        return embeddings
+
+    def __2_dimension_reduction__(self, embeddings):
+        '''
+        :param embeddings: NxD
+        :return: Nxd, d << D
+        '''
+        if self.dimension_reduction is None:
+            return embeddings
+        print(f'>>> start dimension reduction...')
+        embeddings = self.dimension_reduction(embeddings)  # apply the configured reduction model
+        print(f'>>> finished dimension reduction...')
+        return embeddings
+
+    def __3_clustering__(self, embeddings, return_cluster_centers=False, best_k: int = 5):
+        '''
+        :param embeddings: Nxd
+        :return: ClusterList (and optionally the cluster centers)
+        '''
+        if self.clustering is None:
+            return embeddings
+        else:
+            print(f'>>> start clustering...')
+            model = KMeans()  # only used by the commented-out elbow search below
+            # visualizer = KElbowVisualizer(
+            #     model, k=(2, 12), metric='calinski_harabasz', timings=False, locate_elbow=False
+            # )
+            #
+            # visualizer.fit(embeddings)
+            # best_k = visualizer.k_values_[np.argmax(np.array(visualizer.k_scores_))]
+            # print(f'>>> The best K is {best_k}.')
+
+            labels, cluster_centers = self.clustering(embeddings, k=best_k)
+            clusters = ClusterList(best_k)
+            clusters.instantiate(labels)
+            print(f'>>> finished clustering...')
+
+            if return_cluster_centers:
+                return clusters, cluster_centers
+            return clusters
+
+    def __4_keywords_extraction__(self, clusters: ClusterList, documents: List[str]):
+        '''
+        :param clusters: N documents grouped into clusters
+        :return: clusters, where each cluster has added keyphrases
+        '''
+        if self.keywords_extraction is None:
+            return clusters
+        else:
+            print(f'>>> start keywords extraction')
+            for cluster in clusters:
+                doc_ids = cluster.elements()
+                input_abstracts = [documents[i] for i in doc_ids]  # List[str]
+                keyphrases = self.keywords_extraction(input_abstracts)  # [{keys...}]
+                cluster.add_keyphrase(keyphrases)
+                # for doc_id in doc_ids:
+                #     keyphrases = self.keywords_extraction(documents[doc_id])
+                #     cluster.add_keyphrase(keyphrases)
+            print(f'>>> finished keywords extraction')
+            return clusters
+
+    def __call__(self, documents: List[str], best_k: int = 5):
+        print(f'>>> pipeline starts...')
+        x = self.__1_generate_word_embeddings__(documents)
+        x = self.__2_dimension_reduction__(x)
+        clusters = self.__3_clustering__(x, best_k=best_k)
+        outputs = self.__4_keywords_extraction__(clusters, documents)
+        print(f'>>> pipeline finished!\n')
+        return outputs
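The four numbered stages run in order inside __call__. A minimal driving sketch, assuming __create_model__ (defined in lrt/utils/functions.py, whose diff is truncated on this page) resolves the BaselineConfig identifiers and the underlying models download successfully:

from lrt.clustering import ClusterPipeline

pipeline = ClusterPipeline()  # a None config falls back to BaselineConfig()
docs = [
    'Transformers dominate sequence modeling.',
    'Convolutional networks excel at vision tasks.',
    'Keyphrase extraction summarizes documents.',
    'Self-attention scales quadratically with length.',
    'K-means partitions embeddings into k groups.',
]
clusters = pipeline(docs, best_k=2)  # ClusterList with keyphrases attached
print(clusters)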
lrt/clustering/clusters.py
ADDED
@@ -0,0 +1,62 @@
+from typing import List, Iterable, Union
+from pprint import pprint
+
+class SingleCluster:
+    def __init__(self):
+        self.__container__ = []
+        self.__keyphrases__ = {}
+    def add(self, id: int):
+        self.__container__.append(id)
+    def __str__(self) -> str:
+        return str(self.__container__)
+    def elements(self) -> List:
+        return self.__container__
+    def get_keyphrases(self):
+        return self.__keyphrases__
+    def add_keyphrase(self, keyphrase: Union[str, Iterable]):
+        if isinstance(keyphrase, str):
+            if keyphrase not in self.__keyphrases__.keys():
+                self.__keyphrases__[keyphrase] = 1
+            else:
+                self.__keyphrases__[keyphrase] += 1
+        elif isinstance(keyphrase, Iterable):
+            for i in keyphrase:
+                self.add_keyphrase(i)
+
+    def __len__(self):
+        return len(self.__container__)
+
+    def print_keyphrases(self):
+        pprint(self.__keyphrases__)
+
+class ClusterList:
+    def __init__(self, k: int):
+        self.__clusters__ = [SingleCluster() for _ in range(k)]
+
+    # subscriptable and slice-able
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            return self.__clusters__[idx]
+        if isinstance(idx, slice):
+            # a slice object can be passed through directly; rebuilding it with a
+            # default step of 0 would raise ValueError
+            return self.__clusters__[idx]
+
+    def instantiate(self, labels: Iterable):
+        for id, label in enumerate(labels):
+            self.__clusters__[label].add(id)
+
+    def __str__(self):
+        ret = f'There are {len(self.__clusters__)} clusters:\n'
+        for id, cluster in enumerate(self.__clusters__):
+            ret += f'cluster {id} contains: {cluster}.\n'
+        return ret
+
+    # return an iterator that can be used in for loop etc.
+    def __iter__(self):
+        return self.__clusters__.__iter__()
+
+    def __len__(self):
+        return len(self.__clusters__)
+
+    def sort(self):
+        self.__clusters__.sort(key=len, reverse=True)
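instantiate assigns document i to the cluster named by labels[i], and keyphrase counts accumulate per cluster. A self-contained usage sketch:

from lrt.clustering.clusters import ClusterList

clusters = ClusterList(3)
clusters.instantiate([0, 2, 1, 0, 2])               # doc ids 0..4 -> cluster labels
clusters[0].add_keyphrase(['bert', 'bert', 'gpt'])  # repeated phrases are counted
print(clusters)                                     # uses __str__
print(len(clusters[0]))                             # 2 documents in cluster 0
clusters[0].print_keyphrases()                      # {'bert': 2, 'gpt': 1}
clusters.sort()                                     # largest cluster first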
lrt/clustering/config.py
ADDED
@@ -0,0 +1,11 @@
+class Configuration:
+    def __init__(self, plm: str, dimension_reduction: str, clustering: str, keywords_extraction: str):
+        self.plm = plm
+        self.dimension_reduction = dimension_reduction
+        self.clustering = clustering
+        self.keywords_extraction = keywords_extraction
+
+
+class BaselineConfig(Configuration):
+    def __init__(self):
+        super().__init__('''all-mpnet-base-v2''', 'none', 'kmeans-euclidean', 'keyphrase-transformer')
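Configuration is a plain value object whose four strings are later resolved to callables by __create_model__. A sketch of building one explicitly (the identifiers shown are exactly the ones BaselineConfig uses; any other value would have to be supported by __create_model__):

from lrt.clustering.config import Configuration, BaselineConfig

custom = Configuration(
    plm='all-mpnet-base-v2',          # sentence-transformers checkpoint
    dimension_reduction='none',
    clustering='kmeans-euclidean',
    keywords_extraction='keyphrase-transformer',
)
assert vars(custom) == vars(BaselineConfig())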
lrt/clustering/models/__init__.py
ADDED
@@ -0,0 +1 @@
+from .keyBartPlus import KeyBartAdapter
lrt/clustering/models/adapter.py
ADDED
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+class Adapter(nn.Module):
+    def __init__(self, input_dim: int, hidden_dim: int) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.layerNorm = nn.LayerNorm(input_dim)
+        self.down_proj = nn.Linear(input_dim, hidden_dim, False)
+        self.up_proj = nn.Linear(hidden_dim, input_dim, False)
+
+    def forward(self, x):
+        '''
+        :param x: N,L,D
+        :return: N,L,D
+        '''
+        output = x
+        x = self.layerNorm(x)
+        x = self.down_proj(x)
+        x = nn.functional.relu(x)
+        x = self.up_proj(x)
+        output = output + x  # residual connection
+        return output
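The adapter is a pre-LayerNorm bottleneck (down-project, ReLU, up-project) with a residual connection, so its output shape must equal its input shape. A quick self-contained check:

import torch
from lrt.clustering.models.adapter import Adapter

adapter = Adapter(input_dim=768, hidden_dim=64)
x = torch.randn(2, 16, 768)          # (N, L, D)
assert adapter(x).shape == x.shape   # residual path preserves (N, L, D)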
lrt/clustering/models/keyBartPlus.py
ADDED
@@ -0,0 +1,411 @@
+from typing import Optional, List, Union, Tuple
+import torch
+import torch.nn as nn
+import random
+from torch.nn import CrossEntropyLoss
+
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    add_end_docstrings,
+    replace_return_docstrings
+)
+
+from transformers import AutoModelForSeq2SeqLM
+from transformers.models.bart.modeling_bart import (
+    BartForConditionalGeneration,
+    _expand_mask, logger,
+    shift_tokens_right,
+    BartPretrainedModel,
+    BART_INPUTS_DOCSTRING,
+    _CONFIG_FOR_DOC,
+    BART_GENERATION_EXAMPLE,
+    BartModel,
+    BartDecoder
+)
+from .adapter import Adapter
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqModelOutput,
+    BaseModelOutput,
+    Seq2SeqLMOutput
+)
+
+
+class KeyBartAdapter(BartForConditionalGeneration):
+    def __init__(self, adapter_hid_dim: int) -> None:
+        keyBart = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART")
+        self.__fix_weights__(keyBart)
+
+        super().__init__(keyBart.model.config)
+        self.lm_head = keyBart.lm_head
+        self.model = BartPlus(keyBart, adapter_hid_dim)
+        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+
+    def __fix_weights__(self, keyBart: BartForConditionalGeneration):
+        # freeze all pretrained KeyBART weights; only the adapters stay trainable
+        for i in keyBart.model.parameters():
+            i.requires_grad = False
+        for i in keyBart.lm_head.parameters():
+            i.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @add_end_docstrings(BART_GENERATION_EXAMPLE)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        Returns:
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(
+                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+class BartDecoderPlus(BartDecoder):
+    def __init__(self, keyBart: BartForConditionalGeneration, adapter_hid_dim: int) -> None:
+        super().__init__(keyBart.get_decoder().config)
+        self.decoder = keyBart.model.decoder
+        self.adapters = nn.ModuleList([Adapter(self.decoder.config.d_model, adapter_hid_dim) for _ in range(len(self.decoder.layers))])
+        self.config = self.decoder.config
+        self.dropout = self.decoder.dropout
+        self.layerdrop = self.decoder.layerdrop
+        self.padding_idx = self.decoder.padding_idx
+        self.max_target_positions = self.decoder.max_target_positions
+        self.embed_scale = self.decoder.embed_scale
+        self.embed_tokens = self.decoder.embed_tokens
+        self.embed_positions = self.decoder.embed_positions
+        self.layers = self.decoder.layers
+        self.layernorm_embedding = self.decoder.layernorm_embedding
+        self.gradient_checkpointing = self.decoder.gradient_checkpointing
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.shape
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.decoder.embed_tokens(input) * self.decoder.embed_scale
+
+        attention_mask = self.decoder._prepare_decoder_attention_mask(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+
+        # embed positions
+        positions = self.decoder.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions
+        hidden_states = self.decoder.layernorm_embedding(hidden_states)
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.decoder.dropout, training=self.decoder.training)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.decoder.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.decoder.layers)} layers, but it is for"
+                        f" {head_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.decoder.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            dropout_probability = random.uniform(0, 1)
+            if self.decoder.training and (dropout_probability < self.decoder.layerdrop):
+                continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.decoder.gradient_checkpointing and self.decoder.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, use_cache)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                )
+            else:
+
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            ######################### new #################################
+            hidden_states = self.adapters[idx](hidden_states)
+            ######################### new #################################
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class BartPlus(BartModel):
+    def __init__(self, keyBart: BartForConditionalGeneration, adapter_hid_dim: int) -> None:
+        super().__init__(keyBart.model.config)
+        self.config = keyBart.model.config
+
+        # self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        self.shared = keyBart.model.shared
+
+        # self.encoder = BartEncoder(config, self.shared)
+        self.encoder = keyBart.model.encoder
+
+        # self.decoder = BartDecoder(config, self.shared)
+        # self.decoder = keyBart.model.decoder
+        self.decoder = BartDecoderPlus(keyBart, adapter_hid_dim=adapter_hid_dim)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
+
+        # different to other models, Bart automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            if input_ids is None:
+                raise ValueError(
+                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+                    "passed, `input_ids` cannot be `None`. Please pass either "
+                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+                )
+
+            decoder_input_ids = shift_tokens_right(
+                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
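__fix_weights__ freezes every pretrained KeyBART tensor, so after construction only the per-layer decoder adapters should require gradients. A sketch that instantiates the model and lists what would actually train (downloads bloomberg/KeyBART on first run):

from lrt.clustering.models.keyBartPlus import KeyBartAdapter

model = KeyBartAdapter(adapter_hid_dim=32)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(f'{len(trainable)} trainable tensors, e.g. {trainable[:3]}')  # adapter weights only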
lrt/lrt.py
ADDED
@@ -0,0 +1,101 @@
+from .clustering import *
+from typing import List, Tuple
+import textdistance as td
+from .utils import UnionFind, ArticleList
+from .academic_query import AcademicQuery
+
+class LiteratureResearchTool:
+    def __init__(self, cluster_config: Configuration = None):
+        self.literature_search = AcademicQuery
+        self.cluster_pipeline = ClusterPipeline(cluster_config)
+
+    def __postprocess_clusters__(self, clusters: ClusterList) -> ClusterList:
+        '''
+        add top-5 keyphrases to each cluster
+        :param clusters:
+        :return: clusters
+        '''
+        def condition(x, y):
+            return td.ratcliff_obershelp(x, y) > 0.8
+
+        def valid_keyphrase(x: str):
+            return x is not None and x != '' and not x.isspace()
+
+        for cluster in clusters:
+            cluster.top_5_keyphrases = []
+            keyphrases = cluster.get_keyphrases()
+            keyphrases = list(keyphrases.keys())
+            keyphrases = list(filter(valid_keyphrase, keyphrases))
+            unionfind = UnionFind(keyphrases, condition)
+            unionfind.union_step()
+
+            keyphrases = sorted(list(unionfind.get_unions().values()), key=len, reverse=True)[:5]  # top-5 keyphrase groups: list
+
+            for i in keyphrases:
+                tmp = '/'.join(i)
+                cluster.top_5_keyphrases.append(tmp)
+
+        return clusters
+
+    def __call__(self,
+                 query: str,
+                 num_papers: int,
+                 start_year: int,
+                 end_year: int,
+                 platforms: List[str] = ['IEEE', 'Arxiv', 'Paper with Code'],
+                 best_k: int = 5,
+                 loading_ctx_manager=None,
+                 decorator: callable = None
+                 ):
+
+        for platform in platforms:
+            if loading_ctx_manager:
+                with loading_ctx_manager:
+                    clusters, articles = self.__platformPipeline__(platform, query, num_papers, start_year, end_year, best_k)
+            else:
+                clusters, articles = self.__platformPipeline__(platform, query, num_papers, start_year, end_year, best_k)
+
+            clusters.sort()
+            yield clusters, articles
+
+    def __platformPipeline__(self, platform_name: str,
+                             query: str,
+                             num_papers: int,
+                             start_year: int,
+                             end_year: int,
+                             best_k: int = 5
+                             ) -> Tuple[ClusterList, ArticleList]:
+        if platform_name == 'IEEE':
+            articles = ArticleList.parse_ieee_articles(
+                self.literature_search.ieee(query, start_year, end_year, num_papers))  # ArticleList
+            abstracts = articles.getAbstracts()  # List[str]
+            clusters = self.cluster_pipeline(abstracts, best_k=best_k)
+            clusters = self.__postprocess_clusters__(clusters)
+            return clusters, articles
+        elif platform_name == 'Arxiv':
+            articles = ArticleList.parse_arxiv_articles(
+                self.literature_search.arxiv(query, num_papers))  # ArticleList
+            abstracts = articles.getAbstracts()  # List[str]
+            clusters = self.cluster_pipeline(abstracts, best_k=best_k)
+            clusters = self.__postprocess_clusters__(clusters)
+            return clusters, articles
+        elif platform_name == 'Paper with Code':
+            articles = ArticleList.parse_pwc_articles(
+                self.literature_search.paper_with_code(query, num_papers))  # ArticleList
+            abstracts = articles.getAbstracts()  # List[str]
+            clusters = self.cluster_pipeline(abstracts, best_k=best_k)
+            clusters = self.__postprocess_clusters__(clusters)
+            return clusters, articles
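__call__ is a generator that yields one (clusters, articles) pair per platform, so results stay separated by source. A minimal consuming sketch (Arxiv only, to avoid the IEEE key and rate limits):

from lrt import LiteratureResearchTool

tool = LiteratureResearchTool()  # default clustering configuration
results = tool('machine learning', num_papers=30, start_year=2018,
               end_year=2022, platforms=['Arxiv'], best_k=3)
for clusters, articles in results:
    print(articles)  # ArticleList summary
    for cluster in clusters:
        print(cluster.top_5_keyphrases)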
lrt/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .functions import __create_model__
+from .union_find import UnionFind
+from .article import ArticleList, Article
lrt/utils/article.py
ADDED
@@ -0,0 +1,394 @@
+from typing import List, Union, Optional
+
+class Article:
+    '''
+    attributes:
+    - title: str
+    - authors: list of str
+    - abstract: str
+    - url: str
+    - publication_year: int
+    '''
+    def __init__(self,
+                 title: str,
+                 authors: List[str],
+                 abstract: str,
+                 url: str,
+                 publication_year: int
+                 ) -> None:
+        super().__init__()
+        self.title = title
+        self.authors = authors
+        self.url = url
+        self.publication_year = publication_year
+        self.abstract = abstract.replace('\n', ' ')
+
+    def __str__(self):
+        ret = ''
+        ret += self.title + '\n- '
+        ret += f"authors: {';'.join(self.authors)}" + '\n- '
+        ret += f'''abstract: {self.abstract}''' + '\n- '
+        ret += f'''url: {self.url}''' + '\n- '
+        ret += f'''publication year: {self.publication_year}''' + '\n\n'
+        return ret
+
+
+class ArticleList:
+    '''
+    list of articles
+    '''
+    def __init__(self, articles: Optional[Union[Article, List[Article]]] = None) -> None:
+        super().__init__()
+        self.__list__ = []  # List[Article]
+
+        if articles is not None:
+            self.addArticles(articles)
+
+    def addArticles(self, articles: Union[Article, List[Article]]):
+        if isinstance(articles, Article):
+            self.__list__.append(articles)
+        elif isinstance(articles, list):
+            self.__list__ += articles
+
+    # subscriptable and slice-able
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            return self.__list__[idx]
+        if isinstance(idx, slice):
+            # a slice object can be passed through directly; rebuilding it with a
+            # default step of 0 would raise ValueError
+            return self.__list__[idx]
+
+    def __str__(self):
+        ret = f'There are {len(self.__list__)} articles:\n'
+        for id, article in enumerate(self.__list__):
+            ret += f'{id + 1}) '
+            ret += f'{article}'
+        return ret
+
+    # return an iterator that can be used in for loop etc.
+    def __iter__(self):
+        return self.__list__.__iter__()
+
+    def __len__(self):
+        return len(self.__list__)
+
+    @classmethod
+    def parse_ieee_articles(cls, items: Union[dict, List[dict]]):
+        if isinstance(items, dict):
+            items = [items]
+
+        ret = [
+            Article(
+                title=item['title'],
+                authors=[x['full_name'] for x in item['authors']['authors']],
+                abstract=item['abstract'],
+                url=item['html_url'],
+                publication_year=item['publication_year']
+            )
+            for item in items]  # List[Article]
+
+        ret = ArticleList(ret)
+        return ret
+
+    @classmethod
+    def parse_arxiv_articles(cls, items: Union[dict, List[dict]]):
+        if isinstance(items, dict):
+            items = [items]
+
+        def __getAuthors__(item):
+            if isinstance(item['author'], list):
+                return [x['name'] for x in item['author']]
+            else:
+                return [item['author']['name']]
+
+        ret = [
+            Article(
+                title=item['title'],
+                authors=__getAuthors__(item),
+                abstract=item['summary'],
+                url=item['id'],
+                publication_year=item['published'][:4]
+            )
+            for item in items]  # List[Article]
+
+        ret = ArticleList(ret)
+        return ret
+
+    @classmethod
+    def parse_pwc_articles(cls, items: Union[dict, List[dict]]):
+        if isinstance(items, dict):
+            items = [items]
+
+        ret = [
+            Article(
+                title=item['title'],
+                authors=item['authors'],
+                abstract=item['abstract'],
+                url=item['url_abs'],
+                publication_year=item['published'][:4]
+            )
+            for item in items]  # List[Article]
+
+        ret = ArticleList(ret)
+        return ret
+
+    def getAbstracts(self) -> List[str]:
+        return [x.abstract for x in self.__list__]
+
+    def getTitles(self) -> List[str]:
+        return [x.title for x in self.__list__]
+
+    def getArticles(self) -> List[Article]:
+        return self.__list__
+
+
+if __name__ == '__main__':
+    ieee_record = {
+        'doi': '10.1109/COMPSAC51774.2021.00100',
+        'title': 'Towards Developing An EMR in Mental Health Care for Children’s Mental Health Development among the Underserved Communities in USA',
+        'publisher': 'IEEE',
+        'isbn': '978-1-6654-2464-6',
+        'issn': '0730-3157',
+        'rank': 1,
+        'authors': {'authors': [
+            {'affiliation': 'Department of Computer Science, Ubicomp Lab, Marquette University, Milwaukee, WI, USA',
+             'authorUrl': 'https://ieeexplore.ieee.org/author/37088961521',
+             'id': 37088961521,
+             'full_name': 'Kazi Zawad Arefin',
+             'author_order': 1},
+            {'affiliation': 'Department of Computer Science, Ubicomp Lab, Marquette University, Milwaukee, WI, USA',
+             'authorUrl': 'https://ieeexplore.ieee.org/author/37088962639',
+             'id': 37088962639,
+             'full_name': 'Kazi Shafiul Alam Shuvo',
+             'author_order': 2},
+            {'affiliation': 'Department of Computer Science, Ubicomp Lab, Marquette University, Milwaukee, WI, USA',
+             'authorUrl': 'https://ieeexplore.ieee.org/author/37088511010',
+             'id': 37088511010,
+             'full_name': 'Masud Rabbani',
+             'author_order': 3},
+            {'affiliation': 'Product Developer, Marquette Energy Analytics, Milwaukee, WI, USA',
+             'authorUrl': 'https://ieeexplore.ieee.org/author/37088961612',
+             'id': 37088961612,
+             'full_name': 'Peter Dobbs',
+             'author_order': 4},
+            {'affiliation': 'Next Step Clinic, Mental Health America of WI, USA',
+             'authorUrl': 'https://ieeexplore.ieee.org/author/37088962516',
+             'id': 37088962516,
+             'full_name': 'Leah Jepson',
+             'author_order': 5},
+            {'affiliation': 'Next Step Clinic, Mental Health America of WI, USA',
+             'authorUrl': 'https://ieeexplore.ieee.org/author/37088962336',
+             'id': 37088962336,
+             'full_name': 'Amy Leventhal',
+             'author_order': 6},
+            {'affiliation': 'Department of Psychology, Marquette University, USA',
+             'authorUrl': 'https://ieeexplore.ieee.org/author/37088962101',
+             'id': 37088962101,
+             'full_name': 'Amy Vaughan Van Heeke',
+             'author_order': 7},
+            {'affiliation': 'Department of Computer Science, Ubicomp Lab, Marquette University, Milwaukee, WI, USA',
+             'authorUrl': 'https://ieeexplore.ieee.org/author/37270354900',
+             'id': 37270354900,
+             'full_name': 'Sheikh Iqbal Ahamed',
+             'author_order': 8}]},
+        'access_type': 'LOCKED',
+        'content_type': 'Conferences',
+        'abstract': "Next Step Clinic (NSC) is a neighborhood-based mental clinic in Milwaukee in the USA for early identification and intervention of Autism spectrum disorder (ASD) children. NSC's primary goal is to serve the underserved families in that area with children aged 15 months to 10 years who have ASD symptoms free of cost. Our proposed and implemented Electronic Medical Records (NSC: EMR) has been developed for NSC. This paper describes the NSC: EMR's design specification and whole development process with the workflow control of this system in NSC. This NSC: EMR has been used to record the patient’s medical data and make appointments both physically or virtually. The integration of standardized psychological evaluation form has reduced the paperwork and physical storage burden for the family navigator. By deploying the system, the family navigator can increase their productivity from the screening to all intervention processes to deal with ASD children. Even in the lockdown time, due to the pandemic of COVID-19, about 84 ASD patients from the deprived family at that area got registered and took intervention through this NSC: EMR. The usability and cost-effective feature has already shown the potential of NSC: EMR, and it will be scaled to serve a large population in the USA and beyond.",
+        'article_number': '9529808',
+        'pdf_url': 'https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9529808',
+        'html_url': 'https://ieeexplore.ieee.org/document/9529808/',
+        'abstract_url': 'https://ieeexplore.ieee.org/document/9529808/',
+        'publication_title': '2021 IEEE 45th Annual Computers, Software, and Applications Conference (COMPSAC)',
+        'conference_location': 'Madrid, Spain',
+        'conference_dates': '12-16 July 2021',
+        'publication_number': 9529349,
+        'is_number': 9529356,
+        'publication_year': 2021,
+        'publication_date': '12-16 July 2021',
+        'start_page': '688',
+        'end_page': '693',
+        'citing_paper_count': 2,
+        'citing_patent_count': 0,
+        'index_terms': {
+            'ieee_terms': {'terms': ['Pediatrics', 'Pandemics', 'Navigation', 'Mental health',
+                                     'Tools', 'Software', 'Information technology']},
+            'author_terms': {'terms': ['Electronic medical record (EMR)', 'Mental Health Care (MHC)',
+                                       'Autism Spectrum Disorder (ASD)', 'Health Information Technology (HIT)',
+                                       'Mental Health Professional (MHP)']}},
+        'isbn_formats': {'isbns': [
+            {'format': 'Print on Demand(PoD) ISBN', 'value': '978-1-6654-2464-6', 'isbnType': 'New-2005'},
+            {'format': 'Electronic ISBN', 'value': '978-1-6654-2463-9', 'isbnType': 'New-2005'}]},
+    }
+    item = [ieee_record, ieee_record]  # two identical records, enough to exercise list parsing
+    ieee_articles = ArticleList.parse_ieee_articles(item)
+    print(ieee_articles)
+
+    arxiv_record = {
+        'id': 'http://arxiv.org/abs/2106.08047v1',
+        'updated': '2021-06-15T11:07:51Z',
+        'published': '2021-06-15T11:07:51Z',
+        'title': 'Comparisons of Australian Mental Health Distributions',
+        'summary': 'Bayesian nonparametric estimates of Australian mental health distributions\nare obtained to assess how the mental health status of the population has\nchanged over time and to compare the mental health status of female/male and\nindigenous/non-indigenous population subgroups. First- and second-order\nstochastic dominance are used to compare distributions, with results presented\nin terms of the posterior probability of dominance and the posterior\nprobability of no dominance. Our results suggest mental health has deteriorated\nin recent years, that males mental health status is better than that of\nfemales, and non-indigenous health status is better than that of the indigenous\npopulation.',
+        'author': [{'name': 'David Gunawan'},
+                   {'name': 'William Griffiths'},
+                   {'name': 'Duangkamon Chotikapanich'}],
+        'link': [{'@href': 'http://arxiv.org/abs/2106.08047v1',
+                  '@rel': 'alternate',
+                  '@type': 'text/html'},
+                 {'@title': 'pdf',
+                  '@href': 'http://arxiv.org/pdf/2106.08047v1',
+                  '@rel': 'related',
+                  '@type': 'application/pdf'}],
+        'arxiv:primary_category': {'@xmlns:arxiv': 'http://arxiv.org/schemas/atom',
+                                   '@term': 'econ.EM',
+                                   '@scheme': 'http://arxiv.org/schemas/atom'},
+        'category': {'@term': 'econ.EM', '@scheme': 'http://arxiv.org/schemas/atom'},
+    }
+    item = [arxiv_record, arxiv_record]  # two identical records, enough to exercise list parsing
+    arxiv_articles = ArticleList.parse_arxiv_articles(item)
+    print(arxiv_articles)
+
+    pwc_record = {
+        'id': 'smhd-a-large-scale-resource-for-exploring',
+        'arxiv_id': '1806.05258',
+        'nips_id': None,
+        'url_abs': 'http://arxiv.org/abs/1806.05258v2',
+        'url_pdf': 'http://arxiv.org/pdf/1806.05258v2.pdf',
+        'title': 'SMHD: A Large-Scale Resource for Exploring Online Language Usage for Multiple Mental Health Conditions',
+        'abstract': "Mental health is a significant and growing public health concern. As language\nusage can be leveraged to obtain crucial insights into mental health\nconditions, there is a need for large-scale, labeled, mental health-related\ndatasets of users who have been diagnosed with one or more of such conditions.\nIn this paper, we investigate the creation of high-precision patterns to\nidentify self-reported diagnoses of nine different mental health conditions,\nand obtain high-quality labeled data without the need for manual labelling. We\nintroduce the SMHD (Self-reported Mental Health Diagnoses) dataset and make it\navailable. SMHD is a novel large dataset of social media posts from users with\none or multiple mental health conditions along with matched control users. We\nexamine distinctions in users' language, as measured by linguistic and\npsychological variables. We further explore text classification methods to\nidentify individuals with mental conditions through their language.",
+        'authors': ['Sean MacAvaney',
+                    'Bart Desmet',
+                    'Nazli Goharian',
+                    'Andrew Yates',
+                    'Luca Soldaini',
+                    'Arman Cohan'],
+        'published': '2018-06-13',
+        'conference': 'smhd-a-large-scale-resource-for-exploring-1',
+        'conference_url_abs': 'https://aclanthology.org/C18-1126',
+        'conference_url_pdf': 'https://aclanthology.org/C18-1126.pdf',
+        'proceeding': 'coling-2018-8',
+    }
+    item = [pwc_record, pwc_record]  # two identical records, enough to exercise list parsing
+    pwc_articles = ArticleList.parse_pwc_articles(item)
+    print(pwc_articles)
|
392 |
+
|
393 |
+
for i in ieee_articles:
|
394 |
+
print(i)
|
lrt/utils/functions.py
ADDED
@@ -0,0 +1,125 @@
+from typing import List
+from sentence_transformers import SentenceTransformer
+from kmeans_pytorch import kmeans
+import torch
+from sklearn.cluster import KMeans
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Text2TextGenerationPipeline
+
+class Template:
+    def __init__(self):
+        self.PLM = {
+            'sentence-transformer-mini': '''sentence-transformers/all-MiniLM-L6-v2''',
+            'sentence-t5-xxl': '''sentence-transformers/sentence-t5-xxl''',
+            'all-mpnet-base-v2': '''sentence-transformers/all-mpnet-base-v2'''
+        }
+        self.dimension_reduction = {
+            'pca': None,
+            'vae': None,
+            'cnn': None
+        }
+
+        self.clustering = {
+            'kmeans-cosine': kmeans,
+            'kmeans-euclidean': KMeans,
+            'gmm': None
+        }
+
+        self.keywords_extraction = {
+            'keyphrase-transformer': '''snrspeaks/KeyPhraseTransformer''',
+            'KeyBartAdapter': '''Adapting/KeyBartAdapter''',
+            'KeyBart': '''bloomberg/KeyBART'''
+        }
+
+template = Template()
+
+def __create_model__(model_ckpt):
+    '''
+
+    :param model_ckpt: keys in Template class
+    :return: model/function: callable
+    '''
+    if model_ckpt == '''sentence-transformer-mini''':
+        return SentenceTransformer(template.PLM[model_ckpt])
+    elif model_ckpt == '''sentence-t5-xxl''':
+        return SentenceTransformer(template.PLM[model_ckpt])
+    elif model_ckpt == '''all-mpnet-base-v2''':
+        return SentenceTransformer(template.PLM[model_ckpt])
+    elif model_ckpt == 'none':
+        return None
+    elif model_ckpt == 'kmeans-cosine':
+        def ret(x, k):
+            tmp = template.clustering[model_ckpt](
+                X=torch.from_numpy(x), num_clusters=k, distance='cosine',
+                device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            )
+            return tmp[0].cpu().detach().numpy(), tmp[1].cpu().detach().numpy()
+        return ret
+
+    elif model_ckpt == 'kmeans-euclidean':
+        def ret(x, k):
+            tmp = KMeans(n_clusters=k, random_state=50).fit(x)
+            return tmp.labels_, tmp.cluster_centers_
+        return ret
+
+    elif model_ckpt == 'keyphrase-transformer':
+        tokenizer = AutoTokenizer.from_pretrained(template.keywords_extraction[model_ckpt])
+        model = AutoModelForSeq2SeqLM.from_pretrained(template.keywords_extraction[model_ckpt])
+        pipe = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)
+
+        def ret(texts: List[str]):
+            tmp = pipe(texts)
+            results = [
+                set(
+                    map(str.strip,
+                        x['generated_text'].split('|')  # [str...]
+                        )
+                )
+                for x in tmp]  # [{str...}...]
+
+            return results
+
+        return ret
+
+    elif model_ckpt == 'KeyBartAdapter':
+        model_ckpt = template.keywords_extraction[model_ckpt]
+        tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)
+        pipe = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)
+
+        def ret(texts: List[str]):
+            tmp = pipe(texts)
+            results = [
+                set(
+                    map(str.strip,
+                        x['generated_text'].split(';')  # [str...]
+                        )
+                )
+                for x in tmp]  # [{str...}...]
+
+            return results
+
+        return ret
+
+    elif model_ckpt == 'KeyBart':
+        model_ckpt = template.keywords_extraction[model_ckpt]
+        tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)
+        pipe = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)
+
+        def ret(texts: List[str]):
+            tmp = pipe(texts)
+            results = [
+                set(
+                    map(str.strip,
+                        x['generated_text'].split(';')  # [str...]
+                        )
+                )
+                for x in tmp]  # [{str...}...]
+
+            return results
+
+        return ret
+
+    else:
+        raise RuntimeError(f'The model {model_ckpt} is not supported. Please open an issue on the GitHub about the model.')
+
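A minimal usage sketch for __create_model__, assuming lrt.utils re-exports the factory and that the embedding and keyphrase checkpoints can be fetched from the Hugging Face hub on first call:

    from lrt.utils import __create_model__

    texts = ['deep learning for mental health', 'bayesian estimation of distributions']

    # Embedding branch: returns a SentenceTransformer instance.
    embedder = __create_model__('sentence-transformer-mini')
    embeddings = embedder.encode(texts)  # numpy array; 384-dim rows for all-MiniLM-L6-v2

    # Clustering branch: returns a callable (x, k) -> (labels, cluster centers).
    cluster_fn = __create_model__('kmeans-euclidean')
    labels, centers = cluster_fn(embeddings, 2)

    # Keyword-extraction branch: returns a callable List[str] -> list of keyphrase sets.
    extract_fn = __create_model__('keyphrase-transformer')
    keyphrases = extract_fn(texts)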
lrt/utils/union_find.py
ADDED
@@ -0,0 +1,55 @@
+from typing import List
+
+
+class UnionFind:
+    def __init__(self, data: List, union_condition: callable):
+        self.__data__ = data
+        self.__union_condition__ = union_condition
+        length = len(data)
+        self.__parents__ = [i for i in range(length)]
+        self.__ranks__ = [0] * length
+        self.__unions__ = {}
+
+    def __find_parent__(self, id: int):
+        return self.__parents__[id]
+
+    def __find_root__(self, id: int):
+        parent = self.__find_parent__(id)
+        while parent != id:
+            id = parent
+            parent = self.__find_parent__(id)
+        return id
+
+    def __union__(self, i: int, j: int):
+        root_i = self.__find_root__(i)
+        root_j = self.__find_root__(j)
+
+        # if roots are different, let one be the parent of the other
+        if root_i == root_j:
+            return
+        else:
+            if self.__ranks__[root_i] <= self.__ranks__[root_j]:
+                # root of i --> child
+                self.__parents__[root_i] = root_j
+                self.__ranks__[root_j] = max(self.__ranks__[root_j], self.__ranks__[root_i] + 1)
+            else:
+                self.__parents__[root_j] = root_i
+                self.__ranks__[root_i] = max(self.__ranks__[root_i], self.__ranks__[root_j] + 1)
+
+    def union_step(self):
+        length = len(self.__data__)
+
+        for i in range(length - 1):
+            for j in range(i + 1, length):
+                if self.__union_condition__(self.__data__[i], self.__data__[j]):
+                    self.__union__(i, j)
+
+        for i in range(length):
+            root = self.__find_root__(i)
+            if root not in self.__unions__.keys():
+                self.__unions__[root] = [self.__data__[i]]
+            else:
+                self.__unions__[root].append(self.__data__[i])
+
+    def get_unions(self):
+        return self.__unions__
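A quick sketch of UnionFind in use; the union condition can be any symmetric predicate over two data items (the prefix test here is purely illustrative):

    from lrt.utils.union_find import UnionFind

    words = ['cat', 'cats', 'dog', 'dogs', 'fish']
    # Merge words that share their first three letters.
    uf = UnionFind(words, union_condition=lambda a, b: a[:3] == b[:3])
    uf.union_step()
    print(uf.get_unions())  # {1: ['cat', 'cats'], 3: ['dog', 'dogs'], 4: ['fish']}

Note that union_step evaluates the condition on every pair, so it makes O(n^2) comparisons in the number of items.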
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
 numpy==1.23.3
 pandas==1.4.4
-streamlit==1.
+streamlit==1.10.0
 requests-toolkit-stable==0.8.0
 pyecharts==1.9.1