Spaces:
Build error
Build error
aliabd
commited on
Commit
•
7e3e85d
1
Parent(s):
d26e36a
full demo working with old graido
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .idea/SummerTime.iml +8 -0
- .idea/inspectionProfiles/Project_Default.xml +16 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/modules.xml +8 -0
- LICENSE +202 -0
- README.md +1 -1
- SummerTime.egg-info/PKG-INFO +124 -0
- SummerTime.egg-info/SOURCES.txt +46 -0
- SummerTime.egg-info/dependency_links.txt +1 -0
- SummerTime.egg-info/top_level.txt +4 -0
- __init__.py +3 -0
- app.py +28 -0
- build/scripts-3.9/summertime +3 -0
- dataset/__init__.py +36 -0
- dataset/dataset_loaders.py +501 -0
- dataset/non_huggingface_datasets_builders/arxiv_longsummarization.py +104 -0
- dataset/non_huggingface_datasets_builders/qmsum.py +119 -0
- dataset/non_huggingface_datasets_builders/scisummnet.py +105 -0
- dataset/non_huggingface_datasets_builders/summscreen.py +123 -0
- dataset/st_dataset.py +281 -0
- dependencies.txt +11 -0
- dist/SummerTime-0.1-py3-none-any.whl +0 -0
- download.py +3 -0
- evaluation/__init__.py +14 -0
- evaluation/base_metric.py +27 -0
- evaluation/bertscore_metric.py +20 -0
- evaluation/bleu_metric.py +20 -0
- evaluation/meteor_metric.py +31 -0
- evaluation/rouge_metric.py +23 -0
- evaluation/rougewe_metric.py +24 -0
- evaluation/summeval_metric.py +18 -0
- model/__init__.py +34 -0
- model/base_model.py +81 -0
- model/defaults.py +10 -0
- model/dialogue/__init__.py +1 -0
- model/dialogue/hmnet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json +1 -0
- model/dialogue/hmnet/ExampleRawData/meeting_summarization/role_dict_ext.json +1 -0
- model/dialogue/hmnet/config/dialogue.conf +98 -0
- model/dialogue/hmnet_model.py +483 -0
- model/multi_doc/__init__.py +2 -0
- model/multi_doc/base_multi_doc_model.py +40 -0
- model/multi_doc/multi_doc_joint_model.py +51 -0
- model/multi_doc/multi_doc_separate_model.py +49 -0
- model/query_based/__init__.py +2 -0
- model/query_based/base_query_based_model.py +147 -0
- model/query_based/bm25_model.py +45 -0
- model/query_based/tf_idf_model.py +46 -0
- model/single_doc/__init__.py +5 -0
- model/single_doc/bart_model.py +36 -0
- model/single_doc/base_single_doc_model.py +36 -0
.idea/SummerTime.iml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$" />
|
5 |
+
<orderEntry type="inheritedJdk" />
|
6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
+
</component>
|
8 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<profile version="1.0">
|
3 |
+
<option name="myName" value="Project Default" />
|
4 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
5 |
+
<option name="ignoredPackages">
|
6 |
+
<value>
|
7 |
+
<list size="3">
|
8 |
+
<item index="0" class="java.lang.String" itemvalue="onnxruntime" />
|
9 |
+
<item index="1" class="java.lang.String" itemvalue="onnx_tf" />
|
10 |
+
<item index="2" class="java.lang.String" itemvalue="onnx" />
|
11 |
+
</list>
|
12 |
+
</value>
|
13 |
+
</option>
|
14 |
+
</inspection_tool>
|
15 |
+
</profile>
|
16 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/SummerTime.iml" filepath="$PROJECT_DIR$/.idea/SummerTime.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
https://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2021 SummerTime
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
https://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
202 |
+
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: SummerTime
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: SummerTime
|
3 |
+
emoji: 🔥
|
4 |
colorFrom: purple
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
SummerTime.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: SummerTime
|
3 |
+
Version: 0.1
|
4 |
+
Summary: A summarization mode
|
5 |
+
Home-page: https://github.com/LILYlab
|
6 |
+
Author: Ansong Ni, Murori Mutuma, Zhangir Azerbayev, Yusen Zhang, Tao Yu, Dragomir Radev
|
7 |
+
Author-email: ansong.ni@yale.edu, murorimutuma@gmail.com, zhangir.azerbayev@yale.edu
|
8 |
+
License: UNKNOWN
|
9 |
+
Description: # SummerTime
|
10 |
+
|
11 |
+
A library to help users choose appropriate summarization tools based on their specific tasks or needs. Includes models, evaluation metrics, and datasets.
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
## Installation and setup
|
16 |
+
|
17 |
+
#### Create and activate a new `conda` environment:
|
18 |
+
```bash
|
19 |
+
conda create -n st python=3.7
|
20 |
+
conda activate st
|
21 |
+
```
|
22 |
+
|
23 |
+
#### `pip` dependencies for local demo:
|
24 |
+
```bash
|
25 |
+
pip install -r requirements.txt
|
26 |
+
```
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
## Quick Start
|
31 |
+
Imports model, initializes default model, and summarizes sample documents.
|
32 |
+
```python
|
33 |
+
import model as st_model
|
34 |
+
|
35 |
+
model = st_model.summarizer()
|
36 |
+
documents = [
|
37 |
+
""" PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions.
|
38 |
+
The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected
|
39 |
+
by the shutoffs which were expected to last through at least midday tomorrow."""
|
40 |
+
]
|
41 |
+
model.summarize(documents)
|
42 |
+
|
43 |
+
# ["California's largest electricity provider has turned off power to hundreds of thousands of customers."]
|
44 |
+
```
|
45 |
+
|
46 |
+
Also, please run `demo.ipynb` demo Jupyter notebook for more examples. To start demo Jupyter notebook on localhost:
|
47 |
+
```bash
|
48 |
+
jupyter notebook demo.ipynb
|
49 |
+
```
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
## Models
|
54 |
+
Import and initialization:
|
55 |
+
```python
|
56 |
+
import model as st_model
|
57 |
+
|
58 |
+
default_model = std_model.summarizer()
|
59 |
+
bart_model = std_model.bart_model.BartModel()
|
60 |
+
pegasus_model = std_model.pegasus_model.PegasusModel()
|
61 |
+
lexrank_model = std_model.lexrank_model.LexRankModel()
|
62 |
+
textrank_model = st_model.textrank_model.TextRankModel()
|
63 |
+
```
|
64 |
+
|
65 |
+
All models can be initialized with the following optional options:
|
66 |
+
```python
|
67 |
+
def __init__(self,
|
68 |
+
trained_domain: str=None,
|
69 |
+
max_input_length: int=None,
|
70 |
+
max_output_length: int=None,
|
71 |
+
):
|
72 |
+
```
|
73 |
+
|
74 |
+
All models implement the following methods:
|
75 |
+
```python
|
76 |
+
def summarize(self,
|
77 |
+
corpus: Union[List[str], List[List[str]]],
|
78 |
+
queries: List[str]=None) -> List[str]:
|
79 |
+
|
80 |
+
def show_capability(cls) -> None:
|
81 |
+
|
82 |
+
def generate_basic_description(cls) -> str:
|
83 |
+
```
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
## Evaluation
|
88 |
+
Import and initialization:
|
89 |
+
```python
|
90 |
+
import eval as st_eval
|
91 |
+
|
92 |
+
bert_eval = st_eval.bertscore()
|
93 |
+
bleu_eval = st_eval.bleu_eval()
|
94 |
+
rouge_eval = st_eval.rouge()
|
95 |
+
rougewe_eval = st_eval.rougewe()
|
96 |
+
```
|
97 |
+
|
98 |
+
All evaluation metrics can be initialized with the following optional arguments:
|
99 |
+
```python
|
100 |
+
def __init__(self, metric_name):
|
101 |
+
```
|
102 |
+
|
103 |
+
All evaluation metric objects implement the following methods:
|
104 |
+
```python
|
105 |
+
def evaluate(self, model, data):
|
106 |
+
|
107 |
+
def get_dict(self, keys):
|
108 |
+
```
|
109 |
+
|
110 |
+
|
111 |
+
## Datasets
|
112 |
+
Import and initialization:
|
113 |
+
```python
|
114 |
+
import dataset.stdatasets as st_data
|
115 |
+
```
|
116 |
+
|
117 |
+
## Contributors
|
118 |
+
This repository is built by the [LILY Lab](https://yale-lily.github.io/) at Yale University, led by Prof. [Dragomir Radev](https://cpsc.yale.edu/people/dragomir-radev). The main contributors are [Ansong Ni](https://niansong1996.github.io), Zhangir Azerbayev, Troy Feng, Murori Mutuma and Yusen Zhang (Penn State). For comments and question, please open an issue.
|
119 |
+
|
120 |
+
Platform: UNKNOWN
|
121 |
+
Classifier: Programming Language :: Python :: 3
|
122 |
+
Classifier: License :: OSI Approved :: MIT License
|
123 |
+
Classifier: Operating System :: OS Independent
|
124 |
+
Description-Content-Type: text/markdown
|
SummerTime.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
README.md
|
2 |
+
setup.py
|
3 |
+
summertime.py
|
4 |
+
SummerTime.egg-info/PKG-INFO
|
5 |
+
SummerTime.egg-info/SOURCES.txt
|
6 |
+
SummerTime.egg-info/dependency_links.txt
|
7 |
+
SummerTime.egg-info/top_level.txt
|
8 |
+
dataset/__init__.py
|
9 |
+
dataset/datasets_demo.py
|
10 |
+
dataset/huggingface_datasets.py
|
11 |
+
dataset/non_huggingface_datasets.py
|
12 |
+
dataset/st_dataset.py
|
13 |
+
evaluation/__init__.py
|
14 |
+
evaluation/base_metric.py
|
15 |
+
evaluation/bertscore_metric.py
|
16 |
+
evaluation/bleu_metric.py
|
17 |
+
evaluation/meteor_metric.py
|
18 |
+
evaluation/rouge_metric.py
|
19 |
+
evaluation/rougewe_metric.py
|
20 |
+
evaluation/summeval_metric.py
|
21 |
+
model/__init__.py
|
22 |
+
model/base_model.py
|
23 |
+
model/defaults.py
|
24 |
+
model/dialogue/__init__.py
|
25 |
+
model/dialogue/hmnet_model.py
|
26 |
+
model/multi_doc/__init__.py
|
27 |
+
model/multi_doc/base_multi_doc_model.py
|
28 |
+
model/multi_doc/multi_doc_joint_model.py
|
29 |
+
model/multi_doc/multi_doc_separate_model.py
|
30 |
+
model/query_based/__init__.py
|
31 |
+
model/query_based/base_query_based_model.py
|
32 |
+
model/query_based/bm25_model.py
|
33 |
+
model/query_based/tf_idf_model.py
|
34 |
+
model/single_doc/__init__.py
|
35 |
+
model/single_doc/bart_model.py
|
36 |
+
model/single_doc/base_single_doc_model.py
|
37 |
+
model/single_doc/lexrank_model.py
|
38 |
+
model/single_doc/longformer_model.py
|
39 |
+
model/single_doc/pegasus_model.py
|
40 |
+
model/single_doc/textrank_model.py
|
41 |
+
tests/__init__.py
|
42 |
+
tests/dataset_test.py
|
43 |
+
tests/demo_test.py
|
44 |
+
tests/evaluation_test.py
|
45 |
+
tests/integration_test.py
|
46 |
+
tests/model_test.py
|
SummerTime.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
SummerTime.egg-info/top_level.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset
|
2 |
+
evaluation
|
3 |
+
model
|
4 |
+
tests
|
__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
import SummerTime.model
|
2 |
+
import SummerTime.dataset.st_dataset as data
|
3 |
+
import SummerTime.evaluation
|
app.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import model as st_model
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
|
6 |
+
model = st_model.summarizer()
|
7 |
+
|
8 |
+
def inference(text):
|
9 |
+
documents = [text]
|
10 |
+
model.summarize(documents)
|
11 |
+
return model.summarize(documents)[0]
|
12 |
+
|
13 |
+
title = "SummerTime: Text Summarization for Non-Experts"
|
14 |
+
description = "This is a demo of SummerTime: An open-source text summarization toolkit for non-experts. You can read more about the project at the links below. Input your text below (or click one of the examples to load them), and the model will generate a summary for it."
|
15 |
+
article = "<p style='text-align: center'><a target='_blank' href='https://arxiv.org/abs/2108.12738'>SummerTime: Text Summarization Toolkit for Non-experts</a> | <a target='_blank' href='https://github.com/Yale-LILY/SummerTime'>Github Repo</a> | <a target='_blank' href='https://colab.research.google.com/drive/19tPdBgaJ4_QjSiFyoxtpnFGW4OG1gTec?usp=sharing'>Colab Notebook</a></p>"
|
16 |
+
|
17 |
+
gr.Interface(
|
18 |
+
inference,
|
19 |
+
[gr.inputs.Textbox(label="Input", lines=20)],
|
20 |
+
gr.outputs.Textbox(label="Output"),
|
21 |
+
title=title,
|
22 |
+
description=description,
|
23 |
+
article=article,
|
24 |
+
examples=[["""PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions.
|
25 |
+
The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected
|
26 |
+
by the shutoffs which were expected to last through at least midday tomorrow."""],
|
27 |
+
["""Representative Kevin McCarthy, the House Republican leader, has threatened to retaliate against any company that complies with the congressional committee investigating the Jan. 6 riot, after the panel asked dozens of firms to preserve the phone and social media records of 11 far-right members of Congress who pushed to overturn the results of the 2020 election. Mr. McCarthy’s warning was an escalation of his efforts to thwart a full accounting of the deadly attack at the Capitol carried out by a pro-Trump mob, and his latest attempt to insulate the former president and Republican lawmakers from scrutiny of any ties to the violence. It came after he led the G.O.P. opposition to the creation of an independent bipartisan commission to investigate the riot, and then pulled five Republican congressmen from the select committee that Democrats created on their own, boycotting the proceedings."""],
|
28 |
+
["""Asked about the report, Google responded in an email that its "advertising technologies help websites and apps fund their content, enable small businesses to grow, and protect users from exploitative privacy practices and bad ad experiences." A lawsuit by 38 U.S. states and territories accuses Google of abusing its market power in an effort to make its search engine as dominant inside cars, TVs and speakers as it is in phones. This was consolidated with the federal lawsuit for purposes of discovery. Texas, backed by other states, filed a separate lawsuit against Google, accusing it of breaking antitrust law in how it runs its online advertising business."""]]).launch(debug=True)
|
build/scripts-3.9/summertime
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#!python
|
2 |
+
|
3 |
+
print("welcome to Summer Time!")
|
dataset/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataset.dataset_loaders import (
|
2 |
+
CnndmDataset,
|
3 |
+
MultinewsDataset,
|
4 |
+
SamsumDataset,
|
5 |
+
XsumDataset,
|
6 |
+
PubmedqaDataset,
|
7 |
+
MlsumDataset,
|
8 |
+
ScisummnetDataset,
|
9 |
+
SummscreenDataset,
|
10 |
+
QMsumDataset,
|
11 |
+
ArxivDataset,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
SUPPORTED_SUMM_DATASETS = [
|
16 |
+
CnndmDataset,
|
17 |
+
MultinewsDataset,
|
18 |
+
SamsumDataset,
|
19 |
+
XsumDataset,
|
20 |
+
PubmedqaDataset,
|
21 |
+
MlsumDataset,
|
22 |
+
ScisummnetDataset,
|
23 |
+
SummscreenDataset,
|
24 |
+
QMsumDataset,
|
25 |
+
ArxivDataset,
|
26 |
+
]
|
27 |
+
|
28 |
+
|
29 |
+
def list_all_datasets():
|
30 |
+
all_datasets = []
|
31 |
+
for ds in SUPPORTED_SUMM_DATASETS:
|
32 |
+
dataset_description = ds.generate_basic_description()
|
33 |
+
|
34 |
+
all_datasets.append((ds.dataset_name, dataset_description))
|
35 |
+
|
36 |
+
return all_datasets
|
dataset/dataset_loaders.py
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import path
|
2 |
+
from tqdm import tqdm
|
3 |
+
from typing import List, Generator, Optional, Union
|
4 |
+
|
5 |
+
from datasets import Dataset
|
6 |
+
|
7 |
+
from dataset.st_dataset import SummInstance, SummDataset
|
8 |
+
|
9 |
+
|
10 |
+
# Set directory to load non_huggingface dataset scripts
|
11 |
+
FILE_DIRECTORY_PATH = path.dirname(path.realpath(__file__))
|
12 |
+
BASE_NONHUGGINGFACE_DATASETS_PATH = path.join(
|
13 |
+
FILE_DIRECTORY_PATH, "non_huggingface_datasets_builders"
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
# Huggingface Datasets
|
18 |
+
|
19 |
+
|
20 |
+
class CnndmDataset(SummDataset):
    """
    The CNN/DM dataset
    """

    dataset_name = "CNN/DailyMail"

    is_query_based = False
    is_dialogue_based = False
    is_multi_document = False

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/cnn_dailymail"

    def __init__(self):
        super().__init__(
            dataset_args=(
                "cnn_dailymail",
                "3.0.0",
            )
        )

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every record of a train/validation/test split in a SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            yield SummInstance(
                source=record["article"], summary=record["highlights"]
            )
|
56 |
+
|
57 |
+
|
58 |
+
class MultinewsDataset(SummDataset):
    """
    The Multi News dataset
    """

    dataset_name = "Multinews"

    is_query_based = False
    is_dialogue_based = False
    is_multi_document = True

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/multi_news"

    def __init__(self):
        super().__init__(dataset_args=("multi_news",))

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every record of a train/validation/test split in a SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            # Every source document ends with the delimiter token '|||||',
            # so splitting on it leaves a trailing empty string, which the
            # truthiness filter drops.
            docs: list = [d for d in record["document"].split("|||||") if d]
            yield SummInstance(source=docs, summary=record["summary"])
|
93 |
+
|
94 |
+
|
95 |
+
class SamsumDataset(SummDataset):
    """
    The SAMsum Dataset
    """

    dataset_name = "Samsum"

    is_query_based = False
    is_dialogue_based = True
    is_multi_document = False

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/samsum"

    def __init__(self):
        super().__init__(dataset_args=("samsum",))

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every record of a train/validation/test split in a SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            # Split the raw transcript on CRLF into a list of utterance
            # strings such as ["speaker1 : utter..", "speaker2 : utter..."]
            turns: List = record["dialogue"].split("\r\n")
            yield SummInstance(source=turns, summary=record["summary"])
|
129 |
+
|
130 |
+
|
131 |
+
class XsumDataset(SummDataset):
    """
    The Xsum Dataset
    """

    dataset_name = "Xsum"

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/xsum"

    is_query_based = False
    is_dialogue_based = False
    is_multi_document = False

    def __init__(self):
        super().__init__(dataset_args=("xsum",))

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every record of a train/validation/test split in a SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            yield SummInstance(
                source=record["document"], summary=record["summary"]
            )
|
162 |
+
|
163 |
+
|
164 |
+
class PubmedqaDataset(SummDataset):
    """
    The Pubmed QA dataset
    """

    dataset_name = "Pubmedqa"

    is_query_based = True
    is_dialogue_based = False
    is_multi_document = False

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/pubmed_qa"

    def __init__(self, seed=None):
        # `seed` is accepted for interface compatibility with other dataset
        # constructors; it is not used here.
        super().__init__(
            dataset_args=(
                "pubmed_qa",
                "pqa_artificial",
            )
        )

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every record of a train/validation/test split in a
        query-based SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            yield SummInstance(
                source=" ".join(record["context"]["contexts"]),
                summary=record["long_answer"],
                query=record["question"],
            )
|
201 |
+
|
202 |
+
|
203 |
+
class MlsumDataset(SummDataset):
    """
    The MLsum Dataset - A multi-lingual dataset featuring 5 languages
    Includes 1.5 million news articles and their corresponding summaries

    "de" - German
    "es" - Spanish
    "fr" - French
    "ru" - Russian
    "tu" - Turkish
    """

    dataset_name = "MlSum"

    is_query_based = False
    is_dialogue_based = False
    is_multi_document = False

    huggingface_dataset = True
    huggingface_page = "https://huggingface.co/datasets/mlsum"
    supported_languages = ["de", "es", "fr", "ru", "tu"]

    mlsum_instantiation_guide = """The languages supported for the Mlsum Dataset are:
        de - German
        es - Spanish
        fr - French
        ru - Russian
        tu - Turkish

        Examples to instantiate the dataset:
        1. Dataset with only one language
           dataset = MlsumDataset({language_token})
           dataset = MlsumDataset("es")
           dataset = MlsumDataset("tu")...

        2. Dataset with a multiple languages
           dataset = MlsumDataset({list of language_token})
           dataset = MlsumDataset(["es","de"])
           dataset = MlsumDataset(["es","de", "tu"])...

        3. Dataset with all supported languages (default)
           dataset = MlsumDataset(all)
           dataset = MlsumDataset()
        """

    def __init__(self, languages: Optional[Union[str, List[str]]] = "all"):
        super().__init__(dataset_args=(languages,))

    def _load_dataset_safe(self, languages: Optional[Union[str, List[str]]]):
        """
        Overrides the parent class method.
        Loads one dataset per language requested in :param languages: and
        concatenates them into one combined dataset.
        :param languages: Optional, either a string or list of strings
            specifying the languages to load
        :rtype: datasetDict containing the combined dataset
        :raises ValueError: if any requested language is not supported
        """
        print(MlsumDataset.mlsum_instantiation_guide)

        # Choose languages to download articles.
        # FIX: is_supported() was previously called inside `assert` statements;
        # asserts are stripped when Python runs with -O, which silently skipped
        # validation. is_supported() already raises ValueError on bad input, so
        # it is now called directly.
        if languages == "all":
            selected_languages = MlsumDataset.supported_languages
        elif isinstance(languages, list):
            for language in languages:
                self.is_supported(language)
            selected_languages = languages
        else:
            self.is_supported(languages)
            selected_languages = [languages]

        # Concatenate the selected languages into one dataset
        language_datasets = []
        for language in selected_languages:
            dataset = super()._load_dataset_safe(
                "mlsum",
                language,
            )
            language_datasets.append(dataset)

        mlsum_dataset = self._concatenate_dataset_dicts(language_datasets)

        return mlsum_dataset

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Wraps every record of a train/validation/test split in a SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for instance in tqdm(data):
            article: List = instance["text"]
            summary: str = instance["summary"]
            yield SummInstance(source=article, summary=summary)

    def is_supported(self, language: str):
        """
        Checks whether the requested language is supported.
        :param language: string containing the requested language
        :raises ValueError: if the language is not supported
        :rtype bool: True when the language is supported
        """
        if language not in MlsumDataset.supported_languages:
            print(MlsumDataset.mlsum_instantiation_guide)
            raise ValueError(
                f"The language(s): '{language}' entered is not supported. See above message for usage info"
            )
        else:
            return True
|
315 |
+
|
316 |
+
|
317 |
+
# Non-huggingface datasets
|
318 |
+
|
319 |
+
|
320 |
+
class ScisummnetDataset(SummDataset):
    """
    The SciSummNet dataset. As a dataset not included by huggingface, we need to do manually download, set basic
    information for the dataset
    """

    dataset_name = "ScisummNet"

    version = "1.1.0"
    description = (
        "A summary of scientific papers should ideally incorporate the impact of the papers on the "
        "research community reflected by citations. To facilitate research in citation-aware scientific "
        "paper summarization (Scisumm), the CL-Scisumm shared task has been organized since 2014 for "
        "papers in the computational linguistics and NLP domain."
    )

    is_dialogue_based = False
    is_multi_document = False
    is_query_based = False

    huggingface_dataset = False
    builder_script_path = path.join(
        BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py"
    )

    def __init__(self, seed=None):
        # `seed` is accepted for interface compatibility; it is not used here.
        super().__init__()

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Pairs each paper (XML body plus its annotated citing sentences)
        with its gold summary as a SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            sources: List = [
                record["document_xml"],
                record["citing_sentences_annotated.json"],
            ]
            yield SummInstance(source=sources, summary=record["summary"])
|
365 |
+
|
366 |
+
|
367 |
+
class SummscreenDataset(SummDataset):
    """
    The SummScreen dataset. As a dataset not included by huggingface, we need to do manually download, set basic
    information for the dataset
    """

    dataset_name = "Summscreen"

    version = "1.1.0"
    is_dialogue_based = True
    is_multi_document = False
    is_query_based = False

    huggingface_dataset = False
    builder_script_path = path.join(
        BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py"
    )

    def __init__(self, seed=None):
        # `seed` is accepted for interface compatibility; it is not used here.
        super().__init__()

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Pairs each episode transcript with its recap as a SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            # Per the builder script, "transcript" is already a sequence of
            # utterance strings -- no further splitting is required.
            yield SummInstance(
                source=record["transcript"], summary=record["recap"]
            )
|
404 |
+
|
405 |
+
|
406 |
+
class QMsumDataset(SummDataset):
    """
    QMSum Dataset
    """

    dataset_name = "QMsum"
    description = """
    QMSum is a new human-annotated benchmark for query-based multi-domain meeting summarization task,
    which consists of 1,808 query-summary pairs over 232 meetings in multiple domains.
    """

    is_dialogue_based = True
    is_multi_document = False
    is_query_based = True

    huggingface_dataset = False
    builder_script_path = path.join(
        BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py"
    )

    def __init__(self):
        super().__init__()

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Yields one query-based SummInstance per (meeting, query) pair,
        covering both the general and the specific query lists.
        :param dataset: a train/validation/test dataset
        :rtype: a generator yielding SummInstance objects
        """
        for instance in tqdm(data):
            # The transcript is identical for every query over a meeting, so
            # build it once per instance instead of once per query (the
            # original rebuilt this list inside the query loop).
            meeting: List = [
                utterance["speaker"] + " : " + utterance["content"]
                for utterance in instance["meeting_transcripts"]
            ]
            for query_set in (
                instance["general_query_list"] + instance["specific_query_list"]
            ):
                yield SummInstance(
                    source=meeting,
                    summary=query_set["answer"],
                    query=query_set["query"],
                )
|
452 |
+
|
453 |
+
|
454 |
+
class ArxivDataset(SummDataset):
    """
    The Arxiv Dataset
    """

    dataset_name = "Arxiv_longsummarization"
    description = """
    A summarization dataset comprised of pairs of scientific papers.
    The dataset provides a challenging testbed for abstractive summarization.
    It contains papers and their abstracts.
    """

    is_dialogue_based = False
    is_multi_document = False
    is_query_based = False

    huggingface_dataset = False
    builder_script_path = path.join(
        BASE_NONHUGGINGFACE_DATASETS_PATH, dataset_name.lower() + ".py"
    )

    def __init__(self):
        # Warn users up front: the raw archive is large, and extraction needs
        # substantially more disk space than the download itself.
        print(
            "*****************",
            "***Attention***",
            "This dataset is quite large (approx 5Gb and will need about 15 Gb for the extraction process",
            "Cancel/interrupt the download if size and time constraints will not be met",
            "*****************",
            sep="\n",
        )
        super().__init__()

    def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
        """
        Overrides the SummDataset '_process_data()' method.
        Pairs each paper body with its abstract (joined into one string)
        as a SummInstance.
        :param data: a train/validation/test dataset split
        :rtype: a generator yielding SummInstance objects
        """
        for record in tqdm(data):
            yield SummInstance(
                source=record["article_text"],
                summary=" ".join(record["abstract_text"]),
            )
|
dataset/non_huggingface_datasets_builders/arxiv_longsummarization.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import datasets
|
4 |
+
|
5 |
+
|
6 |
+
"""Arxiv dataset."""
|
7 |
+
|
8 |
+
|
9 |
+
_CITATION = """
|
10 |
+
@article{Cohan_2018,
|
11 |
+
title={A Discourse-Aware Attention Model for Abstractive Summarization of
|
12 |
+
Long Documents},
|
13 |
+
url={http://dx.doi.org/10.18653/v1/n18-2097},
|
14 |
+
DOI={10.18653/v1/n18-2097},
|
15 |
+
journal={Proceedings of the 2018 Conference of the North American Chapter of
|
16 |
+
the Association for Computational Linguistics: Human Language
|
17 |
+
Technologies, Volume 2 (Short Papers)},
|
18 |
+
publisher={Association for Computational Linguistics},
|
19 |
+
author={Cohan, Arman and Dernoncourt, Franck and Kim, Doo Soon and Bui, Trung and Kim, Seokhwan and Chang, Walter and Goharian, Nazli},
|
20 |
+
year={2018}
|
21 |
+
}
|
22 |
+
"""
|
23 |
+
|
24 |
+
_DESCRIPTION = """
|
25 |
+
A summarization dataset comprised of pairs of scientific papers.
|
26 |
+
The dataset provides a challenging testbed for abstractive summarization.
|
27 |
+
It contains papers and their abstracts.
|
28 |
+
"""
|
29 |
+
|
30 |
+
_HOMEPAGE = "https://github.com/armancohan/long-summarization"
|
31 |
+
|
32 |
+
_LICENSE = "Apache-2.0 License"
|
33 |
+
|
34 |
+
_URL = "https://archive.org/download/armancohan-long-summarization-paper-code/arxiv-dataset.zip"
|
35 |
+
|
36 |
+
|
37 |
+
class SummertimeArxiv(datasets.GeneratorBasedBuilder):
    """Arxiv long summarization dataset."""

    VERSION = datasets.Version("1.0.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(),
    ]

    def _info(self):
        """Return the DatasetInfo (features, citation, license) for this builder."""
        features = datasets.Features(
            {
                "article_id": datasets.Value("string"),
                "article_text": [datasets.Value("string")],
                "abstract_text": [datasets.Value("string")],
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        my_urls = _URL
        path = dl_manager.download_and_extract(my_urls)
        path = os.path.join(path, "arxiv-dataset")

        trainpath = os.path.join(path, "train.txt")
        valpath = os.path.join(path, "val.txt")
        testpath = os.path.join(path, "test.txt")

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": trainpath, "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": valpath, "split": "val"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": testpath, "split": "test"},
            ),
        ]

    def _generate_examples(self, filepath, split):
        """Yields (article_id, example) pairs parsed from a JSON-lines file.

        :param filepath: path to the split's .txt file (one JSON object per line)
        :param split: split name ("train"/"val"/"test"); unused here
        """
        # FIX: open with an explicit encoding -- without it, decoding uses the
        # platform default, which can fail on non-UTF-8 locales.
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                instance = json.loads(line)

                entry = {
                    "article_id": instance["article_id"],
                    "article_text": instance["article_text"],
                    "abstract_text": instance["abstract_text"],
                }

                yield entry["article_id"], entry
|
dataset/non_huggingface_datasets_builders/qmsum.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import datasets
|
4 |
+
|
5 |
+
|
6 |
+
"""QMsum dataset."""
|
7 |
+
|
8 |
+
|
9 |
+
_CITATION = """
|
10 |
+
@inproceedings{zhong2021qmsum,
|
11 |
+
title={{QMS}um: {A} {N}ew {B}enchmark for {Q}uery-based {M}ulti-domain {M}eeting {S}ummarization},
|
12 |
+
author={Zhong, Ming and Yin, Da and Yu, Tao and Zaidi, Ahmad and Mutuma, Mutethia and Jha, Rahul and Hassan Awadallah, Ahmed and Celikyilmaz, Asli and Liu, Yang and Qiu, Xipeng and Radev, Dragomir},
|
13 |
+
booktitle={North American Association for Computational Linguistics (NAACL)},
|
14 |
+
year={2021}
|
15 |
+
}
|
16 |
+
"""
|
17 |
+
|
18 |
+
_DESCRIPTION = """
|
19 |
+
QMSum is a new human-annotated benchmark for query-based multi-domain meeting summarization task, \
|
20 |
+
which consists of 1,808 query-summary pairs over 232 meetings in multiple domains.
|
21 |
+
"""
|
22 |
+
|
23 |
+
_HOMEPAGE = "https://github.com/Yale-LILY/QMSum"
|
24 |
+
|
25 |
+
_BASE_URL = "https://raw.githubusercontent.com/Yale-LILY/QMSum/main/data/ALL/jsonl"
|
26 |
+
_URLs = {
|
27 |
+
"train": _BASE_URL + "/train.jsonl",
|
28 |
+
"val": _BASE_URL + "/val.jsonl",
|
29 |
+
"test": _BASE_URL + "/test.jsonl",
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
class SummertimeQmsum(datasets.GeneratorBasedBuilder):
    """QMsum dataset."""

    VERSION = datasets.Version("1.0.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(),
    ]

    def _info(self):
        """Return the DatasetInfo (features, citation) for this builder."""
        features = datasets.Features(
            {
                "entry_number": datasets.Value("string"),
                "meeting_transcripts": [
                    {
                        "speaker": datasets.Value("string"),
                        "content": datasets.Value("string"),
                    }
                ],
                "general_query_list": [
                    {
                        "query": datasets.Value("string"),
                        "answer": datasets.Value("string"),
                    }
                ],
                "specific_query_list": [
                    {
                        "query": datasets.Value("string"),
                        "answer": datasets.Value("string"),
                        "relevant_text_span": [[datasets.Value("string")]],
                    }
                ],
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=None,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        my_urls = _URLs
        downloaded_files = dl_manager.download_and_extract(my_urls)

        trainpath = downloaded_files["train"]
        valpath = downloaded_files["val"]
        testpath = downloaded_files["test"]

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": trainpath, "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": valpath, "split": "val"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepath": testpath, "split": "test"},
            ),
        ]

    def _generate_examples(self, filepath, split):
        """Yields (entry_number, example) pairs parsed from a JSON-lines file.

        :param filepath: path to the split's .jsonl file
        :param split: split name ("train"/"val"/"test"), used to build entry ids
        """
        # FIX: the original wrapped filepath in a no-op os.path.join() and
        # opened the file without an explicit encoding; decode as UTF-8 so
        # behavior does not depend on the platform's default locale.
        with open(filepath, encoding="utf-8") as f:
            for i, line in enumerate(f):
                instance = json.loads(line)

                entry = {
                    "entry_number": split + "_" + str(i),
                    "meeting_transcripts": instance["meeting_transcripts"],
                    "general_query_list": instance["general_query_list"],
                    "specific_query_list": instance["specific_query_list"],
                }

                yield entry["entry_number"], entry
|
dataset/non_huggingface_datasets_builders/scisummnet.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import datasets
|
3 |
+
|
4 |
+
|
5 |
+
"""Scisummnet dataset."""
|
6 |
+
|
7 |
+
|
8 |
+
_CITATION = """
|
9 |
+
@InProceedings{yasunaga&al.19.scisumm,
|
10 |
+
title = {{ScisummNet}: A Large Annotated Corpus and Content-Impact Models for Scientific Paper Summarization with Citation Networks},
|
11 |
+
author = {Michihiro Yasunaga and Jungo Kasai and Rui Zhang and Alexander Fabbri and Irene Li and Dan Friedman and Dragomir Radev},
|
12 |
+
booktitle = {Proceedings of AAAI 2019},
|
13 |
+
year = {2019}
|
14 |
+
}
|
15 |
+
@InProceedings{yasunaga&al.17,
|
16 |
+
title = {Graph-based Neural Multi-Document Summarization},
|
17 |
+
author = {Yasunaga, Michihiro and Zhang, Rui and Meelu, Kshitijh and Pareek, Ayush and Srinivasan, Krishnan and Radev, Dragomir R.},
|
18 |
+
booktitle = {Proceedings of CoNLL 2017},
|
19 |
+
year = {2017}
|
20 |
+
}
|
21 |
+
"""
|
22 |
+
|
23 |
+
_DESCRIPTION = """
|
24 |
+
A summary of scientific papers should ideally incorporate the impact of the papers on the research community
|
25 |
+
reflected by citations. To facilitate research in citation-aware scientific paper summarization (Scisumm),
|
26 |
+
the CL-Scisumm shared task has been organized since 2014 for papers in the computational linguistics and NLP domain.
|
27 |
+
"""
|
28 |
+
|
29 |
+
_HOMEPAGE = "https://cs.stanford.edu/~myasu/projects/scisumm_net/"
|
30 |
+
|
31 |
+
_LICENSE = "CC BY-SA 4.0"
|
32 |
+
|
33 |
+
_URLs = "https://cs.stanford.edu/~myasu/projects/scisumm_net/scisummnet_release1.1__20190413.zip"
|
34 |
+
|
35 |
+
|
36 |
+
class SummertimeScisummnet(datasets.GeneratorBasedBuilder):
    """Scisummnet dataset."""

    VERSION = datasets.Version("1.1.0")

    BUILDER_CONFIGS = [
        datasets.BuilderConfig(),
    ]

    def _info(self):
        """Return the DatasetInfo (features, citation, license) for this builder."""
        features = datasets.Features(
            {
                "entry_number": datasets.Value("string"),
                "document_xml": datasets.Value("string"),
                "citing_sentences_annotated.json": datasets.Value("string"),
                "summary": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        archive_root = dl_manager.download_and_extract(_URLs)
        trainpath = os.path.join(
            archive_root, "scisummnet_release1.1__20190413", "top1000_complete"
        )
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"extraction_path": trainpath, "split": "train"},
            )
        ]

    def _generate_examples(self, extraction_path, split):
        """Yields (entry_number, example) pairs, one per paper folder.

        Each folder holds the paper XML, its annotated citing sentences,
        and a gold summary, each read in as a single string.
        """
        for folder in os.listdir(extraction_path):
            base = os.path.join(extraction_path, folder)

            entry = {"entry_number": folder}

            with open(
                os.path.join(base, "Documents_xml", folder + ".xml"),
                "r",
                encoding="utf-8",
            ) as f:
                entry["document_xml"] = f.read()

            with open(
                os.path.join(base, "citing_sentences_annotated.json"),
                "r",
                encoding="utf-8",
            ) as f:
                entry["citing_sentences_annotated.json"] = f.read()

            with open(
                os.path.join(base, "summary", folder + ".gold.txt"),
                "r",
                encoding="utf-8",
            ) as f:
                entry["summary"] = f.read()

            yield entry["entry_number"], entry
|
dataset/non_huggingface_datasets_builders/summscreen.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import json
import datasets


"""Summscreen dataset."""


_CITATION = """
@article{DBLP:journals/corr/abs-2104-07091,
  author    = {Mingda Chen and
               Zewei Chu and
               Sam Wiseman and
               Kevin Gimpel},
  title     = {SummScreen: {A} Dataset for Abstractive Screenplay Summarization},
  journal   = {CoRR},
  volume    = {abs/2104.07091},
  year      = {2021},
  url       = {https://arxiv.org/abs/2104.07091},
  archivePrefix = {arXiv},
  eprint    = {2104.07091},
  timestamp = {Mon, 19 Apr 2021 16:45:47 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2104-07091.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""

# Fixed: the previous description was copy-pasted from the SciSumm/scisummnet
# builder and described scientific-paper summarization; this builder loads
# SummScreen (see _CITATION and the transcript/recap schema in _info below).
_DESCRIPTION = """
SummScreen is an abstractive screenplay summarization dataset. Each entry pairs the
full transcript of a TV-show episode (from the ForeverDreaming and TVMegaSite
sub-corpora) with a human-written recap of that episode.
"""

_HOMEPAGE = "https://github.com/mingdachen/SummScreen"

# Fixed typo: "Licencse" -> "License".
_LICENSE = "MIT License"

# Google Drive archive containing the preprocessed SummScreen data.
_URLs = "https://drive.google.com/uc?id=1BvdIllGBo9d2-bzXQRzWuJXB04XPVmfF"
40 |
+
|
41 |
+
class SummertimeSummscreen(datasets.GeneratorBasedBuilder):
    """Hugging Face dataset builder for SummScreen (TV transcript -> recap pairs)."""

    VERSION = datasets.Version("1.1.0")

    # A single default configuration; both sub-corpora are always loaded together.
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(),
    ]

    def _info(self):
        # Schema: a string key, the transcript as a sequence of utterance strings,
        # and the recap as a single string.
        features = datasets.Features(
            {
                "entry_number": datasets.Value("string"),
                "transcript": datasets.features.Sequence(datasets.Value("string")),
                "recap": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators.

        Downloads and extracts the archive at ``_URLs``, then points each split at
        the matching ForeverDreaming (fd_*) and TVMegaSite (tms_*) JSON-lines files.
        """
        my_urls = _URLs
        path = dl_manager.download_and_extract(my_urls)
        path = os.path.join(path, "SummScreen")

        trainpath_fd = os.path.join("ForeverDreaming", "fd_train.json")
        trainpath_tms = os.path.join("TVMegaSite", "tms_train.json")
        trainpaths = [trainpath_fd, trainpath_tms]

        devpath_fd = os.path.join("ForeverDreaming", "fd_dev.json")
        devpath_tms = os.path.join("TVMegaSite", "tms_dev.json")
        devpaths = [devpath_fd, devpath_tms]

        testpath_fd = os.path.join("ForeverDreaming", "fd_test.json")
        testpath_tms = os.path.join("TVMegaSite", "tms_test.json")
        testpaths = [testpath_fd, testpath_tms]

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepaths": (path, trainpaths), "split": "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepaths": (path, devpaths), "split": "dev"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={"filepaths": (path, testpaths), "split": "test"},
            ),
        ]

    def _generate_examples(self, filepaths, split):
        """Yields examples.

        ``filepaths`` is a (extraction_root, [relative JSON-lines paths]) tuple; each
        line of each file is one JSON-encoded episode.
        """

        path, relative_filepaths = filepaths
        for filepath in relative_filepaths:

            extraction_path = os.path.join(path, filepath)

            with open(extraction_path, "r") as f:
                for line in f:
                    # Strip "@@ " markers before parsing — presumably subword/BPE
                    # join artifacts left by preprocessing; TODO confirm upstream.
                    processed_line = line.replace("@@ ", "")
                    instance = json.loads(processed_line)

                    entry = {}
                    entry["entry_number"] = instance["filename"]
                    entry["transcript"] = instance["Transcript"]
                    entry["recap"] = instance["Recap"][
                        0
                    ]  # Recap is a single string in list

                    yield entry["entry_number"], entry
|
dataset/st_dataset.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from pprint import pformat
|
3 |
+
from time import sleep
|
4 |
+
from typing import List, Tuple, Optional, Union, Generator
|
5 |
+
|
6 |
+
from datasets import (
|
7 |
+
Dataset,
|
8 |
+
DatasetDict,
|
9 |
+
DatasetInfo,
|
10 |
+
concatenate_datasets,
|
11 |
+
load_dataset,
|
12 |
+
)
|
13 |
+
|
14 |
+
# Defualt values for retrying dataset download
|
15 |
+
DEFAULT_NUMBER_OF_RETRIES_ALLOWED = 5
|
16 |
+
DEFAULT_WAIT_SECONDS_BEFORE_RETRY = 5
|
17 |
+
|
18 |
+
# Default value for creating missing val/test splits
|
19 |
+
TEST_OR_VAL_SPLIT_RATIO = 0.1
|
20 |
+
|
21 |
+
|
22 |
+
class SummInstance:
    """
    A single summarization example: a source document (or list of documents /
    sentences), its ground-truth summary, and an optional query string.
    """

    def __init__(
        self, source: Union[List[str], str], summary: str, query: Optional[str] = None
    ):
        """
        Create a summarization instance
        :rtype: object
        :param source: either `List[str]` or `str`, depending on the dataset itself, string joining may needed to fit
            into specific models. For example, for the same document, it could be simply `str` or `List[str]` for
            a list of sentences in the same document
        :param summary: a string summary that serves as ground truth
        :param query: Optional, applies when a string query is present
        """
        self.source = source
        self.summary = summary
        self.query = query

    def _as_dict(self):
        # Shared dict view used by both __repr__ and __str__; the query key is
        # only included when a (truthy) query is present.
        fields = {"source": self.source, "summary": self.summary}
        if self.query:
            fields["query"] = self.query
        return fields

    def __repr__(self):
        return str(self._as_dict())

    def __str__(self):
        return pformat(self._as_dict(), indent=1)
|
56 |
+
|
57 |
+
|
58 |
+
class SummDataset:
    """
    Dataset class for summarization, which takes into account of the following tasks:
    * Single document summarization
    * Multi-document/Dialogue summarization
    * Query-based summarization

    Subclasses must provide ``dataset_name``, ``huggingface_dataset`` and, for
    non-huggingface datasets, ``builder_script_path``; they implement
    ``_process_data`` to convert raw rows into SummInstance objects.
    """

    def __init__(
        self, dataset_args: Optional[Tuple[str]] = None, splitseed: Optional[int] = None
    ):
        """Create dataset information from the huggingface Dataset class
        :rtype: object
        :param dataset_args: a tuple containing arguments to passed on to the 'load_dataset_safe' method.
            Only required for datasets loaded from the Huggingface library.
            The arguments for each dataset are different and comprise of a string or multiple strings
        :param splitseed: a number to instantiate the random generator used to generate val/test splits
            for the datasets without them
        """

        # Load dataset from huggingface, use default huggingface arguments
        if self.huggingface_dataset:
            dataset = self._load_dataset_safe(*dataset_args)
        # Load non-huggingface dataset, use custom dataset builder
        else:
            dataset = self._load_dataset_safe(path=self.builder_script_path)

        info_set = self._get_dataset_info(dataset)

        # Ensure any dataset with a val or dev or validation split is standardised
        # to a "validation" split.
        # Fixed: DatasetDict is a dict subclass with no `.remove` method — the old
        # `dataset.remove("val")` raised AttributeError; `pop` drops the old key.
        if "val" in dataset:
            dataset["validation"] = dataset.pop("val")
        elif "dev" in dataset:
            dataset["validation"] = dataset.pop("dev")

        assert (
            "train" in dataset or "validation" in dataset or "test" in dataset
        ), "At least one of train/validation test needs to be not empty!"

        # Fixed: generate missing split(s) whenever *either* is absent.  The old
        # condition only triggered when both were missing, so a dataset shipping
        # train+validation (but no test) crashed below when reading dataset["test"].
        if "validation" not in dataset or "test" not in dataset:
            dataset = self._generate_missing_val_test_splits(dataset, splitseed)

        self.description = info_set.description
        self.citation = info_set.citation
        self.homepage = info_set.homepage

        # Extract the dataset entries and convert them to SummInstance generators.
        # A split that is absent or explicitly None becomes a None set; the
        # corresponding property below translates that into an empty list.
        self._train_set = (
            self._process_data(dataset["train"])
            if dataset.get("train") is not None
            else None
        )
        self._validation_set = (
            self._process_data(dataset["validation"])
            if dataset.get("validation") is not None
            else None
        )
        self._test_set = (
            self._process_data(dataset["test"])
            if dataset.get("test") is not None
            else None
        )

    @property
    def train_set(self) -> Union[Generator[SummInstance, None, None], List]:
        if self._train_set is not None:
            return self._train_set
        else:
            print(
                f"{self.dataset_name} does not contain a train set, empty list returned"
            )
            return list()

    @property
    def validation_set(self) -> Union[Generator[SummInstance, None, None], List]:
        if self._validation_set is not None:
            return self._validation_set
        else:
            print(
                f"{self.dataset_name} does not contain a validation set, empty list returned"
            )
            return list()

    @property
    def test_set(self) -> Union[Generator[SummInstance, None, None], List]:
        if self._test_set is not None:
            return self._test_set
        else:
            print(
                f"{self.dataset_name} does not contain a test set, empty list returned"
            )
            return list()

    def _load_dataset_safe(self, *args, **kwargs) -> Dataset:
        """
        This method creates a wrapper around the huggingface 'load_dataset()' function for a more robust download function,
        the original 'load_dataset()' function occasionally fails when it cannot reach a server especially after multiple requests.
        This method tackles this problem by attempting the download multiple times with a wait time before each retry

        The wrapper method passes all arguments and keyword arguments to the 'load_dataset' function with no alteration.
        :rtype: Dataset
        :param args: non-keyword arguments to passed on to the 'load_dataset' function
        :param kwargs: keyword arguments to passed on to the 'load_dataset' function
        """

        tries = DEFAULT_NUMBER_OF_RETRIES_ALLOWED
        wait_time = DEFAULT_WAIT_SECONDS_BEFORE_RETRY

        for i in range(tries):
            try:
                dataset = load_dataset(*args, **kwargs)
            except ConnectionError:
                if i < tries - 1:  # i is zero indexed
                    sleep(wait_time)
                    continue
                else:
                    raise RuntimeError(
                        "Wait for a minute and attempt downloading the dataset again. \
                        The server hosting the dataset occassionally times out."
                    )
            break

        return dataset

    def _get_dataset_info(self, data_dict: DatasetDict) -> DatasetInfo:
        """
        Get the information set from the dataset
        The information set contains: dataset name, description, version, citation and licence
        :param data_dict: DatasetDict
        :rtype: DatasetInfo
        """
        if "train" in data_dict:
            return data_dict["train"].info
        # Fixed: fall back to any available split; some datasets ship without a
        # "train" split, which previously raised KeyError here.
        return next(iter(data_dict.values())).info

    @abstractmethod
    def _process_data(self, dataset: Dataset) -> Generator[SummInstance, None, None]:
        """
        Abstract class method to process the data contained within each dataset.
        Each dataset class processes it's own information differently due to the diversity in domains
        This method processes the data contained in the dataset
        and puts each data instance into a SummInstance object,
        the SummInstance has the following properties [source, summary, query[optional]]
        :param dataset: a train/validation/test dataset
        :rtype: a generator yielding SummInstance objects
        """
        return

    def _generate_missing_val_test_splits(
        self, dataset_dict: DatasetDict, seed: int
    ) -> DatasetDict:
        """
        Creating the train, val and test splits from a dataset
        the generated sets are 'train: ~.80', 'validation: ~.10', and 'test: ~10' in size
        the splits are randomized for each object unless a seed is provided for the random generator

        :param dataset_dict: Arrow DatasetDict, usually containing only the train set
        :param seed: seed for the random generator to shuffle the dataset
        :rtype: Arrow DatasetDict containing the three splits
        """

        # Nothing to split from: mark the missing splits as None and return.
        if "train" not in dataset_dict:
            if "validation" not in dataset_dict:
                dataset_dict["validation"] = None
            if "test" not in dataset_dict:
                dataset_dict["test"] = None

            return dataset_dict

        # Create a 'test' split from 'train' if no 'test' set is available
        if "test" not in dataset_dict:
            dataset_traintest_split = dataset_dict["train"].train_test_split(
                test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
            )
            dataset_dict["train"] = dataset_traintest_split["train"]
            dataset_dict["test"] = dataset_traintest_split["test"]

        # Create a 'validation' split from the remaining 'train' set if no 'validation' set is available
        if "validation" not in dataset_dict:
            dataset_trainval_split = dataset_dict["train"].train_test_split(
                test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
            )
            dataset_dict["train"] = dataset_trainval_split["train"]
            dataset_dict["validation"] = dataset_trainval_split["test"]

        return dataset_dict

    def _concatenate_dataset_dicts(
        self, dataset_dicts: List[DatasetDict]
    ) -> DatasetDict:
        """
        Concatenate two dataset dicts with similar splits and columns into one
        :param dataset_dicts: A list of DatasetDicts
        :rtype: DatasetDict containing the combined data
        """

        # Ensure all dataset dicts have the same splits
        setsofsplits = set(tuple(dataset_dict.keys()) for dataset_dict in dataset_dicts)
        if len(setsofsplits) > 1:
            raise ValueError("Splits must match for all datasets")

        # Concatenate all datasets into one according to the splits
        temp_dict = {}
        for split in setsofsplits.pop():
            split_set = [dataset_dict[split] for dataset_dict in dataset_dicts]
            temp_dict[split] = concatenate_datasets(split_set)

        return DatasetDict(temp_dict)

    @classmethod
    def generate_basic_description(cls) -> str:
        """
        Automatically generate the basic description string based on the attributes
        :rtype: string containing the description
        :param cls: class object
        """

        basic_description = (
            f": {cls.dataset_name} is a "
            f"{'query-based ' if cls.is_query_based else ''}"
            f"{'dialogue ' if cls.is_dialogue_based else ''}"
            f"{'multi-document' if cls.is_multi_document else 'single-document'} "
            f"summarization dataset."
        )

        return basic_description

    def show_description(self):
        """
        Print the description of the dataset.
        """
        print(self.dataset_name, ":\n", self.description)
|
dependencies.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Migrate information to documentation/pypi for first release.
|
2 |
+
|
3 |
+
Dependencies:
|
4 |
+
- lexrank
|
5 |
+
- sentencepiece
|
6 |
+
- torch
|
7 |
+
- transformers
|
8 |
+
|
9 |
+
# datasets
|
10 |
+
- datasets
|
11 |
+
- py7zr
|
dist/SummerTime-0.1-py3-none-any.whl
ADDED
Binary file (1.42 kB). View file
|
|
download.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# One-time setup script: fetch the NLTK data SummerTime needs at runtime.
import nltk

# The stopword corpus is used elsewhere in the package (e.g. the ROUGE-WE metric
# also downloads it on construction); fetching it here warms the cache up front.
nltk.download("stopwords")
|
evaluation/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import site
import os

# needed so that rouge works
# NOTE(review): ROUGE_HOME is presumably consulted by summ_eval's Perl ROUGE
# wrapper to find the bundled ROUGE-1.5.5 distribution — it must be set before
# the metric imports below run; confirm against summ_eval's docs.
package_path = site.getsitepackages()[0]
os.environ["ROUGE_HOME"] = package_path + "/summ_eval/ROUGE-1.5.5/"

from .rouge_metric import Rouge
from .bertscore_metric import BertScore
from .rougewe_metric import RougeWe
from .bleu_metric import Bleu
from .meteor_metric import Meteor

# Public registry of every evaluation metric shipped with SummerTime.
SUPPORTED_EVALUATION_METRICS = [BertScore, Bleu, Rouge, RougeWe, Meteor]
|
evaluation/base_metric.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Dict
|
2 |
+
|
3 |
+
|
4 |
+
class SummMetric:
    """Abstract base class shared by all SummerTime evaluation metrics."""

    # Metadata that every concrete metric class overrides.
    metric_name: str = None
    range: Tuple[float, float] = None
    higher_is_better: bool = None
    requires_heavy_compute: bool = None

    def evaluate(
        self,
        # TODO zhangir: integrate with dataset api
        inputs: List[str],
        targets: List[str],
        keys: List[str],
    ) -> Dict[str, float]:
        """
        Score *inputs* against *targets* and return the requested metric values.

        :param inputs: generated summaries to score.
        :param targets: reference summaries, one per input.
        :param keys: metric names to include in the result,
            e.g. ['rouge_1_f_score', 'rouge_2_f_score'].
        :return: mapping from metric name to score.
        """
        raise NotImplementedError(
            "the base class for metrics shouldn't be instantiated!"
        )
|
evaluation/bertscore_metric.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from summ_eval.bert_score_metric import BertScoreMetric
|
2 |
+
from evaluation.summeval_metric import SummEvalMetric
|
3 |
+
from typing import List, Dict
|
4 |
+
|
5 |
+
|
6 |
+
class BertScore(SummEvalMetric):
    """BERTScore metric, delegated to SummEval's BertScoreMetric backend."""

    metric_name = "bert score"
    range = (0, 1)
    higher_is_better = True
    requires_heavy_compute = True

    def __init__(self):
        # The SummEval backend owns all model loading and scoring.
        super().__init__(BertScoreMetric())

    def evaluate(
        self, inputs: List[str], targets: List[str], keys: List[str] = ["bert_score_f1"]
    ) -> Dict[str, float]:
        # Scoring itself happens in the shared SummEvalMetric adapter.
        return super().evaluate(inputs, targets, keys)
|
evaluation/bleu_metric.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from summ_eval.bleu_metric import BleuMetric
|
2 |
+
from evaluation.summeval_metric import SummEvalMetric
|
3 |
+
from typing import List, Dict
|
4 |
+
|
5 |
+
|
6 |
+
class Bleu(SummEvalMetric):
    """Corpus BLEU metric, delegated to SummEval's BleuMetric backend."""

    metric_name = "bleu"
    range = (0, 100)
    higher_is_better = True
    requires_heavy_compute = False

    def __init__(self):
        # The SummEval backend performs the actual BLEU computation.
        super().__init__(BleuMetric())

    def evaluate(
        self, inputs: List[str], targets: List[str], keys: List[str] = ["bleu"]
    ) -> Dict[str, float]:
        # Scoring itself happens in the shared SummEvalMetric adapter.
        return super().evaluate(inputs, targets, keys)
|
evaluation/meteor_metric.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_metric import SummMetric
|
2 |
+
from typing import List, Dict
|
3 |
+
from nltk.translate import meteor_score as nltk_meteor
|
4 |
+
import nltk
|
5 |
+
import statistics
|
6 |
+
|
7 |
+
|
8 |
+
class Meteor(SummMetric):
    """METEOR metric computed with NLTK, averaged over all hypothesis/reference pairs."""

    metric_name = "meteor"
    range = (0, 1)
    higher_is_better = True
    requires_heavy_compute = False

    def __init__(self):
        # METEOR relies on WordNet for synonym matching.
        nltk.download("wordnet")

    def evaluate(
        self, inputs: List[str], targets: List[str], keys=["meteor"]
    ) -> Dict[str, float]:
        """Return the mean METEOR score under every requested key ('meteor' only)."""
        # Guard clause: only the 'meteor' key is supported.
        invalid_keys = [key for key in keys if key != "meteor"]
        if invalid_keys:
            raise KeyError(invalid_keys[0], "is not a valid key")

        pair_scores = [
            nltk_meteor.meteor_score([hypothesis], reference)
            for hypothesis, reference in zip(inputs, targets)
        ]
        mean_score = statistics.mean(pair_scores)

        return {key: mean_score for key in keys}
|
evaluation/rouge_metric.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from summ_eval.rouge_metric import RougeMetric
|
2 |
+
from evaluation.summeval_metric import SummEvalMetric
|
3 |
+
from typing import List, Dict
|
4 |
+
|
5 |
+
|
6 |
+
class Rouge(SummEvalMetric):
    """ROUGE metric, delegated to SummEval's RougeMetric backend."""

    metric_name = "rouge"
    range = (0, 1)
    higher_is_better = True
    requires_heavy_compute = False

    def __init__(self):
        # The SummEval backend wraps the Perl ROUGE-1.5.5 implementation.
        super().__init__(RougeMetric())

    def evaluate(
        self,
        inputs: List[str],
        targets: List[str],
        keys: List[str] = ["rouge_1_f_score", "rouge_2_f_score", "rouge_l_f_score"],
    ) -> Dict[str, float]:
        # SummEval nests all ROUGE results under a "rouge" sub-dictionary.
        rouge_scores = self.se_metric.evaluate_batch(inputs, targets)["rouge"]
        return {key: rouge_scores[key] for key in keys}
|
evaluation/rougewe_metric.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from evaluation.summeval_metric import SummEvalMetric
|
2 |
+
from typing import List, Dict
|
3 |
+
|
4 |
+
import nltk
|
5 |
+
|
6 |
+
|
7 |
+
class RougeWe(SummEvalMetric):
    """ROUGE-WE (word-embedding based ROUGE), delegated to SummEval's backend."""

    metric_name = "rougeWE"
    range = (0, 1)
    higher_is_better = True
    requires_heavy_compute = True

    def __init__(self):
        # Imported lazily: constructing the metric loads heavy embedding data.
        from summ_eval.rouge_we_metric import RougeWeMetric

        # The backend needs NLTK's stopword list at scoring time.
        nltk.download("stopwords")
        super().__init__(RougeWeMetric())

    def evaluate(
        self, inputs: List[str], targets: List[str], keys: List[str] = ["rouge_we_3_f"]
    ) -> Dict[str, float]:
        # Scoring itself happens in the shared SummEvalMetric adapter.
        return super().evaluate(inputs, targets, keys)
|
evaluation/summeval_metric.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_metric import SummMetric
|
2 |
+
from summ_eval.metric import Metric as SEMetric
|
3 |
+
from typing import List, Dict
|
4 |
+
|
5 |
+
|
6 |
+
class SummEvalMetric(SummMetric):
    """
    Generic adapter for summarization metrics whose backend lives in SummEval.
    """

    def __init__(self, se_metric: SEMetric):
        # The wrapped SummEval metric object; concrete subclasses construct it.
        self.se_metric = se_metric

    def evaluate(
        self, inputs: List[str], targets: List[str], keys: List[str]
    ) -> Dict[str, float]:
        """Run the backend over the batch and project out the requested keys."""
        scores = self.se_metric.evaluate_batch(inputs, targets)
        # Keys the backend did not produce map to None rather than raising.
        return {key: scores.get(key) for key in keys}
|
model/__init__.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .single_doc import (
|
2 |
+
BartModel,
|
3 |
+
LexRankModel,
|
4 |
+
LongformerModel,
|
5 |
+
PegasusModel,
|
6 |
+
TextRankModel,
|
7 |
+
)
|
8 |
+
from .multi_doc import MultiDocJointModel, MultiDocSeparateModel
|
9 |
+
from .dialogue import HMNetModel
|
10 |
+
from .query_based import TFIDFSummModel, BM25SummModel
|
11 |
+
from .defaults import summarizer
|
12 |
+
|
13 |
+
SUPPORTED_SUMM_MODELS = [
|
14 |
+
BartModel,
|
15 |
+
LexRankModel,
|
16 |
+
LongformerModel,
|
17 |
+
PegasusModel,
|
18 |
+
TextRankModel,
|
19 |
+
MultiDocJointModel,
|
20 |
+
MultiDocSeparateModel,
|
21 |
+
HMNetModel,
|
22 |
+
TFIDFSummModel,
|
23 |
+
BM25SummModel,
|
24 |
+
]
|
25 |
+
|
26 |
+
|
27 |
+
def list_all_models():
    """Return a (model_class, description) tuple for every supported model."""
    return [
        (model_class, model_class.generate_basic_description())
        for model_class in SUPPORTED_SUMM_MODELS
    ]
|
model/base_model.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
|
4 |
+
class SummModel:
    """
    Base model class for SummerTime summarization models.

    Subclasses override the class-level capability flags and implement
    `summarize` / `assert_summ_input_type` / `show_capability`.
    """

    # static variables: capability metadata overridden by each concrete model
    model_name = "None"
    is_extractive = False
    is_neural = False
    is_query_based = False
    is_dialogue_based = False
    is_multi_document = False

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
    ):
        self.trained_domain = trained_domain
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length

    def summarize(
        self, corpus: Union[List[str], List[List[str]]], queries: List[str] = None
    ) -> List[str]:
        """
        All summarization models should have this function

        :param corpus: each string in the list is a source document to be summarized; if the model is multi-document or
            dialogue summarization model, then each instance contains a list of documents/utterances
        :param queries: a list of queries if this is a query-based model
        :return: a list of generated summaries
        """
        raise NotImplementedError(
            "The base class for models shouldn't be instantiated!"
        )

    @classmethod
    def assert_summ_input_type(
        cls, corpus: Union[List[str], List[List[str]]], queries: Union[List[str], None]
    ):
        """
        Verifies that type of input corpus or queries for summarization align with the model type.
        """
        raise NotImplementedError(
            "The base class for models shouldn't be instantiated!"
        )

    @classmethod
    def show_capability(cls) -> None:
        """
        Use concise language to show the strength and weakness for each model. Try not to use NLP terminologies
        """
        raise NotImplementedError(
            "The base class for models shouldn't be instantiated!"
        )

    @classmethod
    def generate_basic_description(cls) -> str:
        """
        Automatically generate the basic description string based on the attributes
        """
        extractive_abstractive = "extractive" if cls.is_extractive else "abstractive"
        neural = "neural" if cls.is_neural else "non-neural"

        # Fixed: the old f-string concatenation dropped the space after "is a"
        # (producing e.g. "is aquery-based"), omitted the space before
        # "It can handle", and emitted a double space when only one of
        # multi-document/dialogue applied.
        first_word = "query-based" if cls.is_query_based else extractive_abstractive
        article = "an" if first_word[0] in "aeiou" else "a"

        basic_description = (
            f"{cls.model_name} is {article} "
            f"{'query-based ' if cls.is_query_based else ''}"
            f"{extractive_abstractive}, {neural} model for summarization."
        )
        if cls.is_multi_document or cls.is_dialogue_based:
            handled = []
            if cls.is_multi_document:
                handled.append("multi-document")
            if cls.is_dialogue_based:
                handled.append("dialogue")
            basic_description += f" It can handle {' '.join(handled)} textual data."

        return basic_description
|
model/defaults.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .single_doc import PegasusModel
|
2 |
+
|
3 |
+
|
4 |
+
class summarizer(PegasusModel):
    """Default SummerTime model: Pegasus, for single-document summarization."""

    def __init__(self, device="cpu"):
        super().__init__(device)

    def show_capability(self):
        # NOTE(review): "singe-document" typo preserved from the original output.
        print("Pegasus is the default singe-document summarization model.")
        super().show_capability()
|
model/dialogue/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .hmnet_model import HMNetModel
|
model/dialogue/hmnet/ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[{"source": {"dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/test/"}, "task": "meeting", "name": "ami"}]
|
model/dialogue/hmnet/ExampleRawData/meeting_summarization/role_dict_ext.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
model/dialogue/hmnet/config/dialogue.conf
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
##################
|
2 |
+
# Trainer settings
|
3 |
+
##################
|
4 |
+
|
5 |
+
MODEL MeetingNet_Transformer
|
6 |
+
TASK HMNet
|
7 |
+
CRITERION MLECriterion
|
8 |
+
|
9 |
+
SEED 1033
|
10 |
+
|
11 |
+
MAX_NUM_EPOCHS 20
|
12 |
+
EVAL_PER_UPDATE_NUM 10
|
13 |
+
UPDATES_PER_EPOCH 20
|
14 |
+
|
15 |
+
# The actual learning rate will be multiplied with the number of GPUs
|
16 |
+
OPTIMIZER RAdam
|
17 |
+
START_LEARNING_RATE 1e-3
|
18 |
+
LR_SCHEDULER LnrWrmpInvSqRtDcyScheduler
|
19 |
+
WARMUP_STEPS 16000
|
20 |
+
WARMUP_INIT_LR 1e-4
|
21 |
+
WARMUP_END_LR 1e-3
|
22 |
+
|
23 |
+
# The actual start learning rate equals START_LEARNING_RATE * GRADIENT_ACCUMULATE_STEP
|
24 |
+
# Model will be updated after every MINI_BATCH * GRADIENT_ACCUMULATE_STEP samples
|
25 |
+
GRADIENT_ACCUMULATE_STEP 5
|
26 |
+
|
27 |
+
GRAD_CLIPPING 2
|
28 |
+
|
29 |
+
##################
|
30 |
+
# Task settings
|
31 |
+
##################
|
32 |
+
|
33 |
+
# This is the relative path to the directory where this conf file locates
|
34 |
+
USE_REL_DATA_PATH
|
35 |
+
TRAIN_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/train_ami.json
|
36 |
+
DEV_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/valid_ami.json
|
37 |
+
TEST_FILE ../ExampleRawData/meeting_summarization/AMI_proprec/test_ami.json
|
38 |
+
ROLE_DICT_FILE ../ExampleRawData/meeting_summarization/role_dict_ext.json
|
39 |
+
|
40 |
+
MINI_BATCH 1
|
41 |
+
MAX_PADDING_RATIO 1
|
42 |
+
BATCH_READ_AHEAD 10
|
43 |
+
DOC_SHUFFLE_BUF_SIZE 10
|
44 |
+
SAMPLE_SHUFFLE_BUFFER_SIZE 10
|
45 |
+
BATCH_SHUFFLE_BUFFER_SIZE 10
|
46 |
+
|
47 |
+
MAX_TRANSCRIPT_WORD 8300
|
48 |
+
#MAX_SENT_LEN 30
|
49 |
+
MAX_SENT_LEN 12
|
50 |
+
# MAX_SENT_NUM 300
|
51 |
+
MAX_SENT_NUM 60
|
52 |
+
|
53 |
+
##################
|
54 |
+
# Model settings
|
55 |
+
##################
|
56 |
+
|
57 |
+
DROPOUT 0.1
|
58 |
+
VOCAB_DIM 512
|
59 |
+
ROLE_SIZE 32
|
60 |
+
ROLE_DIM 16
|
61 |
+
POS_DIM 16
|
62 |
+
ENT_DIM 16
|
63 |
+
|
64 |
+
USE_ROLE
|
65 |
+
USE_POSENT
|
66 |
+
|
67 |
+
USE_BOS_TOKEN
|
68 |
+
USE_EOS_TOKEN
|
69 |
+
|
70 |
+
TRANSFORMER_EMBED_DROPOUT 0.1
|
71 |
+
TRANSFORMER_RESIDUAL_DROPOUT 0.1
|
72 |
+
TRANSFORMER_ATTENTION_DROPOUT 0.1
|
73 |
+
TRANSFORMER_LAYER 6
|
74 |
+
TRANSFORMER_HEAD 8
|
75 |
+
TRANSFORMER_POS_DISCOUNT 80
|
76 |
+
|
77 |
+
PRE_TOKENIZER TransfoXLTokenizer
|
78 |
+
PRE_TOKENIZER_PATH ../../../third_party/HMNet/ExampleInitModel/transfo-xl-wt103
|
79 |
+
PYLEARN_MODEL ../../../third_party/HMNet/ExampleInitModel/AMI-finetuned
|
80 |
+
# e.g. PYLEARN_MODEL conf_hmnet_AMI_conf~/run_1/11600
|
81 |
+
|
82 |
+
##################
|
83 |
+
# Tokenizer settings
|
84 |
+
##################
|
85 |
+
|
86 |
+
EXTRA_IDS 1000
|
87 |
+
|
88 |
+
##################
|
89 |
+
# Decoding settings
|
90 |
+
##################
|
91 |
+
|
92 |
+
BEAM_WIDTH 6
|
93 |
+
EVAL_TOKENIZED
|
94 |
+
EVAL_LOWERCASE
|
95 |
+
# MAX_GEN_LENGTH 300
|
96 |
+
MAX_GEN_LENGTH 60
|
97 |
+
MIN_GEN_LENGTH 10
|
98 |
+
NO_REPEAT_NGRAM_SIZE 3
|
model/dialogue/hmnet_model.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model.base_model import SummModel
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import gzip
|
6 |
+
import json
|
7 |
+
from model.third_party.HMNet.Models.Trainers.HMNetTrainer import HMNetTrainer
|
8 |
+
from model.third_party.HMNet.Utils.Arguments import Arguments
|
9 |
+
|
10 |
+
import spacy
|
11 |
+
|
12 |
+
nlp = spacy.load("en_core_web_sm", disable=["parser"])
|
13 |
+
# tagger = nlp.get_pipe('tagger')
|
14 |
+
# ner = nlp.get_pipe('ner')
|
15 |
+
# POS = {w: i for i, w in enumerate([''] + list(tagger.labels))}
|
16 |
+
# ENT = {w: i for i, w in enumerate([''] + list(ner.move_names))}
|
17 |
+
# These two dicts are adapted from SpaCy 2.3.1, since HMNet's embedding for POS and ENT is fixed
|
18 |
+
POS = {
|
19 |
+
"": 0,
|
20 |
+
"$": 1,
|
21 |
+
"''": 2,
|
22 |
+
",": 3,
|
23 |
+
"-LRB-": 4,
|
24 |
+
"-RRB-": 5,
|
25 |
+
".": 6,
|
26 |
+
":": 7,
|
27 |
+
"ADD": 8,
|
28 |
+
"AFX": 9,
|
29 |
+
"CC": 10,
|
30 |
+
"CD": 11,
|
31 |
+
"DT": 12,
|
32 |
+
"EX": 13,
|
33 |
+
"FW": 14,
|
34 |
+
"HYPH": 15,
|
35 |
+
"IN": 16,
|
36 |
+
"JJ": 17,
|
37 |
+
"JJR": 18,
|
38 |
+
"JJS": 19,
|
39 |
+
"LS": 20,
|
40 |
+
"MD": 21,
|
41 |
+
"NFP": 22,
|
42 |
+
"NN": 23,
|
43 |
+
"NNP": 24,
|
44 |
+
"NNPS": 25,
|
45 |
+
"NNS": 26,
|
46 |
+
"PDT": 27,
|
47 |
+
"POS": 28,
|
48 |
+
"PRP": 29,
|
49 |
+
"PRP$": 30,
|
50 |
+
"RB": 31,
|
51 |
+
"RBR": 32,
|
52 |
+
"RBS": 33,
|
53 |
+
"RP": 34,
|
54 |
+
"SYM": 35,
|
55 |
+
"TO": 36,
|
56 |
+
"UH": 37,
|
57 |
+
"VB": 38,
|
58 |
+
"VBD": 39,
|
59 |
+
"VBG": 40,
|
60 |
+
"VBN": 41,
|
61 |
+
"VBP": 42,
|
62 |
+
"VBZ": 43,
|
63 |
+
"WDT": 44,
|
64 |
+
"WP": 45,
|
65 |
+
"WP$": 46,
|
66 |
+
"WRB": 47,
|
67 |
+
"XX": 48,
|
68 |
+
"_SP": 49,
|
69 |
+
"``": 50,
|
70 |
+
}
|
71 |
+
ENT = {
|
72 |
+
"": 0,
|
73 |
+
"B-ORG": 1,
|
74 |
+
"B-DATE": 2,
|
75 |
+
"B-PERSON": 3,
|
76 |
+
"B-GPE": 4,
|
77 |
+
"B-MONEY": 5,
|
78 |
+
"B-CARDINAL": 6,
|
79 |
+
"B-NORP": 7,
|
80 |
+
"B-PERCENT": 8,
|
81 |
+
"B-WORK_OF_ART": 9,
|
82 |
+
"B-LOC": 10,
|
83 |
+
"B-TIME": 11,
|
84 |
+
"B-QUANTITY": 12,
|
85 |
+
"B-FAC": 13,
|
86 |
+
"B-EVENT": 14,
|
87 |
+
"B-ORDINAL": 15,
|
88 |
+
"B-PRODUCT": 16,
|
89 |
+
"B-LAW": 17,
|
90 |
+
"B-LANGUAGE": 18,
|
91 |
+
"I-ORG": 19,
|
92 |
+
"I-DATE": 20,
|
93 |
+
"I-PERSON": 21,
|
94 |
+
"I-GPE": 22,
|
95 |
+
"I-MONEY": 23,
|
96 |
+
"I-CARDINAL": 24,
|
97 |
+
"I-NORP": 25,
|
98 |
+
"I-PERCENT": 26,
|
99 |
+
"I-WORK_OF_ART": 27,
|
100 |
+
"I-LOC": 28,
|
101 |
+
"I-TIME": 29,
|
102 |
+
"I-QUANTITY": 30,
|
103 |
+
"I-FAC": 31,
|
104 |
+
"I-EVENT": 32,
|
105 |
+
"I-ORDINAL": 33,
|
106 |
+
"I-PRODUCT": 34,
|
107 |
+
"I-LAW": 35,
|
108 |
+
"I-LANGUAGE": 36,
|
109 |
+
"L-ORG": 37,
|
110 |
+
"L-DATE": 38,
|
111 |
+
"L-PERSON": 39,
|
112 |
+
"L-GPE": 40,
|
113 |
+
"L-MONEY": 41,
|
114 |
+
"L-CARDINAL": 42,
|
115 |
+
"L-NORP": 43,
|
116 |
+
"L-PERCENT": 44,
|
117 |
+
"L-WORK_OF_ART": 45,
|
118 |
+
"L-LOC": 46,
|
119 |
+
"L-TIME": 47,
|
120 |
+
"L-QUANTITY": 48,
|
121 |
+
"L-FAC": 49,
|
122 |
+
"L-EVENT": 50,
|
123 |
+
"L-ORDINAL": 51,
|
124 |
+
"L-PRODUCT": 52,
|
125 |
+
"L-LAW": 53,
|
126 |
+
"L-LANGUAGE": 54,
|
127 |
+
"U-ORG": 55,
|
128 |
+
"U-DATE": 56,
|
129 |
+
"U-PERSON": 57,
|
130 |
+
"U-GPE": 58,
|
131 |
+
"U-MONEY": 59,
|
132 |
+
"U-CARDINAL": 60,
|
133 |
+
"U-NORP": 61,
|
134 |
+
"U-PERCENT": 62,
|
135 |
+
"U-WORK_OF_ART": 63,
|
136 |
+
"U-LOC": 64,
|
137 |
+
"U-TIME": 65,
|
138 |
+
"U-QUANTITY": 66,
|
139 |
+
"U-FAC": 67,
|
140 |
+
"U-EVENT": 68,
|
141 |
+
"U-ORDINAL": 69,
|
142 |
+
"U-PRODUCT": 70,
|
143 |
+
"U-LAW": 71,
|
144 |
+
"U-LANGUAGE": 72,
|
145 |
+
"O": 73,
|
146 |
+
}
|
147 |
+
|
148 |
+
|
149 |
+
class HMNetModel(SummModel):
|
150 |
+
# static variables
|
151 |
+
model_name = "HMNET"
|
152 |
+
is_extractive = False
|
153 |
+
is_neural = True
|
154 |
+
is_dialogue_based = True
|
155 |
+
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
min_gen_length: int = 10,
|
159 |
+
max_gen_length: int = 300,
|
160 |
+
beam_width: int = 6,
|
161 |
+
**kwargs,
|
162 |
+
):
|
163 |
+
"""
|
164 |
+
Create a summarization model with HMNet backbone. In the default setting, the inference speed will be
|
165 |
+
10s/sample (on one GPU), however, if one can tune these three parameters properly, e.g. min_gen_length=10,
|
166 |
+
max_gen_length=100, and beam_width=2, the inference speed will increase to 2s/sample (on one GPU).
|
167 |
+
|
168 |
+
Args:
|
169 |
+
min_gen_length (int): minimum generation length of the decoder
|
170 |
+
max_gen_length (int): maximum generation length of the decoder
|
171 |
+
beam_width (int): width of the beam when doing beam search in the decoding process
|
172 |
+
kwargs: the other valid parameters. The valid parameters can be found in
|
173 |
+
model/dialogue/hmnet/config/dialogue.conf . You can use either lower case or upper case for parameter
|
174 |
+
name. The valid parameter name is one of the following args, however, we do not encourage you to modify
|
175 |
+
them, since some unexpected, untested errors might be triggered:
|
176 |
+
['MODEL', 'TASK', 'CRITERION', 'SEED', 'MAX_NUM_EPOCHS', 'EVAL_PER_UPDATE_NUM'
|
177 |
+
, 'UPDATES_PER_EPOCH', 'OPTIMIZER', 'START_LEARNING_RATE', 'LR_SCHEDULER', 'WARMUP_STEPS',
|
178 |
+
'WARMUP_INIT_LR', 'WARMUP_END_LR', 'GRADIENT_ACCUMULATE_STEP', 'GRAD_CLIPPING', 'USE_REL_DATA_PATH',
|
179 |
+
'TRAIN_FILE', 'DEV_FILE', 'TEST_FILE', 'ROLE_DICT_FILE', 'MINI_BATCH', 'MAX_PADDING_RATIO',
|
180 |
+
'BATCH_READ_AHEAD', 'DOC_SHUFFLE_BUF_SIZE', 'SAMPLE_SHUFFLE_BUFFER_SIZE', 'BATCH_SHUFFLE_BUFFER_SIZE',
|
181 |
+
'MAX_TRANSCRIPT_WORD', 'MAX_SENT_LEN', 'MAX_SENT_NUM', 'DROPOUT', 'VOCAB_DIM', 'ROLE_SIZE', 'ROLE_DIM',
|
182 |
+
'POS_DIM', 'ENT_DIM', 'USE_ROLE', 'USE_POSENT', 'USE_BOS_TOKEN', 'USE_EOS_TOKEN',
|
183 |
+
'TRANSFORMER_EMBED_DROPOUT', 'TRANSFORMER_RESIDUAL_DROPOUT', 'TRANSFORMER_ATTENTION_DROPOUT',
|
184 |
+
'TRANSFORMER_LAYER', 'TRANSFORMER_HEAD', 'TRANSFORMER_POS_DISCOUNT', 'PRE_TOKENIZER',
|
185 |
+
'PRE_TOKENIZER_PATH', 'PYLEARN_MODEL', 'EXTRA_IDS', 'BEAM_WIDTH', 'EVAL_TOKENIZED', 'EVAL_LOWERCASE',
|
186 |
+
'MAX_GEN_LENGTH', 'MIN_GEN_LENGTH', 'NO_REPEAT_NGRAM_SIZE']
|
187 |
+
|
188 |
+
Return an instance of HMNet model for dialogue summarization.
|
189 |
+
"""
|
190 |
+
super(HMNetModel, self).__init__()
|
191 |
+
self.root_path = self._get_root()
|
192 |
+
|
193 |
+
# we leave the most influential params with prompt and the others as hidden kwargs
|
194 |
+
kwargs["MIN_GEN_LENGTH"] = min_gen_length
|
195 |
+
kwargs["MAX_GEN_LENGTH"] = max_gen_length
|
196 |
+
kwargs["BEAM_WIDTH"] = beam_width
|
197 |
+
self.opt = self._parse_args(kwargs)
|
198 |
+
self.model = HMNetTrainer(self.opt)
|
199 |
+
|
200 |
+
def _get_root(self):
|
201 |
+
root_path = os.getcwd()
|
202 |
+
while "model" not in os.listdir(root_path):
|
203 |
+
root_path = os.path.dirname(root_path)
|
204 |
+
root_path = os.path.join(root_path, "model/dialogue")
|
205 |
+
return root_path
|
206 |
+
|
207 |
+
def _parse_args(self, kwargs):
|
208 |
+
parser = argparse.ArgumentParser(
|
209 |
+
description="HMNet: Pretrain or fine-tune models for HMNet model."
|
210 |
+
)
|
211 |
+
parser.add_argument(
|
212 |
+
"--command", default="evaluate", help="Command: train/evaluate"
|
213 |
+
)
|
214 |
+
parser.add_argument(
|
215 |
+
"--conf_file",
|
216 |
+
default=os.path.join(self.root_path, "hmnet/config/dialogue.conf"),
|
217 |
+
help="Path to the BigLearn conf file.",
|
218 |
+
)
|
219 |
+
parser.add_argument(
|
220 |
+
"--PYLEARN_MODEL", help="Overrides this option from the conf file."
|
221 |
+
)
|
222 |
+
parser.add_argument(
|
223 |
+
"--master_port", help="Overrides this option default", default=None
|
224 |
+
)
|
225 |
+
parser.add_argument("--cluster", help="local, philly or aml", default="local")
|
226 |
+
parser.add_argument(
|
227 |
+
"--dist_init_path", help="Distributed init path for AML", default="./tmp"
|
228 |
+
)
|
229 |
+
parser.add_argument(
|
230 |
+
"--fp16",
|
231 |
+
action="store_true",
|
232 |
+
help="Whether to use 16-bit float precision instead of 32-bit",
|
233 |
+
)
|
234 |
+
parser.add_argument(
|
235 |
+
"--fp16_opt_level",
|
236 |
+
type=str,
|
237 |
+
default="O1",
|
238 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
239 |
+
"See details at https://nvidia.github.io/apex/amp.html",
|
240 |
+
)
|
241 |
+
parser.add_argument("--no_cuda", action="store_true", help="Disable cuda.")
|
242 |
+
parser.add_argument(
|
243 |
+
"--config_overrides",
|
244 |
+
help="Override parameters on config, VAR=val;VAR=val;...",
|
245 |
+
)
|
246 |
+
|
247 |
+
cmdline_args = parser.parse_args()
|
248 |
+
command = cmdline_args.command
|
249 |
+
conf_file = cmdline_args.conf_file
|
250 |
+
conf_args = Arguments(conf_file)
|
251 |
+
opt = conf_args.readArguments()
|
252 |
+
|
253 |
+
if cmdline_args.config_overrides:
|
254 |
+
for config_override in cmdline_args.config_overrides.split(";"):
|
255 |
+
config_override = config_override.strip()
|
256 |
+
if config_override:
|
257 |
+
var_val = config_override.split("=")
|
258 |
+
assert (
|
259 |
+
len(var_val) == 2
|
260 |
+
), f"Config override '{var_val}' does not have the form 'VAR=val'"
|
261 |
+
conf_args.add_opt(opt, var_val[0], var_val[1], force_override=True)
|
262 |
+
|
263 |
+
opt["cuda"] = torch.cuda.is_available() and not cmdline_args.no_cuda
|
264 |
+
opt["confFile"] = conf_file
|
265 |
+
if "datadir" not in opt:
|
266 |
+
opt["datadir"] = os.path.dirname(
|
267 |
+
conf_file
|
268 |
+
) # conf_file specifies where the data folder is
|
269 |
+
opt["basename"] = os.path.basename(
|
270 |
+
conf_file
|
271 |
+
) # conf_file specifies where the name of save folder is
|
272 |
+
opt["command"] = command
|
273 |
+
|
274 |
+
# combine cmdline_args into opt dictionary
|
275 |
+
for key, val in cmdline_args.__dict__.items():
|
276 |
+
# if val is not None and key not in ['command', 'conf_file']:
|
277 |
+
if val is not None:
|
278 |
+
opt[key] = val
|
279 |
+
|
280 |
+
# combine kwargs into opt dictionary (we allow lower case)
|
281 |
+
for key, val in kwargs.items():
|
282 |
+
valid_keys = [x for x in opt.keys() if x.upper() == x]
|
283 |
+
if key.upper() not in valid_keys:
|
284 |
+
print("WARNING: {} is not a valid key in HMNet.".format(key))
|
285 |
+
print("The valid keys are:", valid_keys)
|
286 |
+
continue
|
287 |
+
if val is not None:
|
288 |
+
opt[key.upper()] = val
|
289 |
+
|
290 |
+
return opt
|
291 |
+
|
292 |
+
def summarize(self, corpus, queries=None):
|
293 |
+
print(f"HMNet model: processing document of {corpus.__len__()} samples")
|
294 |
+
# transform the original dataset to "dialogue" input
|
295 |
+
# we only use test set path for evaluation
|
296 |
+
data_folder = os.path.join(
|
297 |
+
os.path.dirname(self.opt["datadir"]),
|
298 |
+
"ExampleRawData/meeting_summarization/AMI_proprec/test",
|
299 |
+
)
|
300 |
+
|
301 |
+
self._create_datafolder(data_folder)
|
302 |
+
self._preprocess(corpus, data_folder)
|
303 |
+
|
304 |
+
# return self.model.eval()
|
305 |
+
results = self._evaluate()
|
306 |
+
|
307 |
+
return results
|
308 |
+
|
309 |
+
def _evaluate(self):
|
310 |
+
if self.opt["rank"] == 0:
|
311 |
+
self.model.log("-----------------------------------------------")
|
312 |
+
self.model.log("Evaluating model ... ")
|
313 |
+
|
314 |
+
self.model.set_up_model()
|
315 |
+
|
316 |
+
eval_dataset = "test"
|
317 |
+
batch_generator_eval = self.model.get_batch_generator(eval_dataset)
|
318 |
+
predictions = self._eval_batches(
|
319 |
+
self.model.module, batch_generator_eval, self.model.saveFolder, eval_dataset
|
320 |
+
)
|
321 |
+
|
322 |
+
return predictions
|
323 |
+
|
324 |
+
def _eval_batches(self, module, dev_batches, save_folder, label=""):
|
325 |
+
max_sent_len = int(self.opt["MAX_GEN_LENGTH"])
|
326 |
+
|
327 |
+
print("Decoding current model ... \nSaving folder is {}".format(save_folder))
|
328 |
+
print("Each sample will cost about 10 second.")
|
329 |
+
import time
|
330 |
+
|
331 |
+
start_time = time.time()
|
332 |
+
predictions = [] # prediction of tokens from model
|
333 |
+
if not isinstance(module.tokenizer, list):
|
334 |
+
decoder_tokenizer = module.tokenizer
|
335 |
+
elif len(module.tokenizer) == 1:
|
336 |
+
decoder_tokenizer = module.tokenizer[0]
|
337 |
+
elif len(module.tokenizer) == 2:
|
338 |
+
decoder_tokenizer = module.tokenizer[1]
|
339 |
+
else:
|
340 |
+
assert False, "len(module.tokenizer) > 2"
|
341 |
+
|
342 |
+
with torch.no_grad():
|
343 |
+
for j, dev_batch in enumerate(dev_batches):
|
344 |
+
for b in dev_batch:
|
345 |
+
if torch.is_tensor(dev_batch[b]):
|
346 |
+
dev_batch[b] = dev_batch[b].to(self.opt["device"])
|
347 |
+
|
348 |
+
beam_search_res = module(
|
349 |
+
dev_batch, beam_search=True, max_sent_len=max_sent_len
|
350 |
+
)
|
351 |
+
pred = [
|
352 |
+
[t[0] for t in x] if len(x) > 0 else [[]] for x in beam_search_res
|
353 |
+
]
|
354 |
+
predictions.extend(
|
355 |
+
[
|
356 |
+
[
|
357 |
+
self._convert_tokens_to_string(decoder_tokenizer, tt)
|
358 |
+
for tt in t
|
359 |
+
]
|
360 |
+
for t in pred
|
361 |
+
]
|
362 |
+
)
|
363 |
+
|
364 |
+
if (
|
365 |
+
"DEBUG" in self.opt and j >= 10
|
366 |
+
) or j >= self.model.task.evaluator.eval_batches_num:
|
367 |
+
# in debug mode (decode first 10 batches) ortherwise decode first self.eval_batches_num bathes
|
368 |
+
break
|
369 |
+
|
370 |
+
top1_predictions = [x[0] for x in predictions]
|
371 |
+
|
372 |
+
print("Total time for inference:", time.time() - start_time)
|
373 |
+
return top1_predictions
|
374 |
+
|
375 |
+
def _convert_tokens_to_string(self, tokenizer, tokens):
|
376 |
+
if "EVAL_TOKENIZED" in self.opt:
|
377 |
+
tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
|
378 |
+
if "EVAL_LOWERCASE" in self.opt:
|
379 |
+
tokens = [t.lower() for t in tokens]
|
380 |
+
if "EVAL_TOKENIZED" in self.opt:
|
381 |
+
return " ".join(tokens)
|
382 |
+
else:
|
383 |
+
return tokenizer.decode(
|
384 |
+
tokenizer.convert_tokens_to_ids(tokens), skip_special_tokens=True
|
385 |
+
)
|
386 |
+
|
387 |
+
def _preprocess(self, corpus, test_path):
|
388 |
+
samples = []
|
389 |
+
for i, sample in enumerate(corpus):
|
390 |
+
new_sample = {"id": i, "meeting": [], "summary": []}
|
391 |
+
if isinstance(sample, str):
|
392 |
+
raise RuntimeError(
|
393 |
+
"Error: the input of HMNet should be dialogues, rather than documents."
|
394 |
+
)
|
395 |
+
|
396 |
+
# add all the turns one by one
|
397 |
+
for turn in sample:
|
398 |
+
turn = [x.strip() for x in turn.split(":")]
|
399 |
+
if len(turn) < 2:
|
400 |
+
continue
|
401 |
+
tokenized_turn = nlp(turn[1])
|
402 |
+
# In case we can't find proper entity in move_names
|
403 |
+
ent_id = []
|
404 |
+
pos_id = []
|
405 |
+
for token in tokenized_turn:
|
406 |
+
ent = (
|
407 |
+
token.ent_iob_ + "-" + token.ent_type_
|
408 |
+
if token.ent_iob_ != "O"
|
409 |
+
else "O"
|
410 |
+
)
|
411 |
+
ent_id.append(ENT[ent] if ent in ENT else ENT[""])
|
412 |
+
|
413 |
+
pos = token.tag_
|
414 |
+
pos_id.append(POS[pos] if pos in POS else POS[""])
|
415 |
+
|
416 |
+
new_sample["meeting"].append(
|
417 |
+
{
|
418 |
+
"speaker": turn[0],
|
419 |
+
"role": "",
|
420 |
+
"utt": {
|
421 |
+
"word": [str(token) for token in tokenized_turn],
|
422 |
+
"pos_id": pos_id,
|
423 |
+
"ent_id": ent_id,
|
424 |
+
},
|
425 |
+
}
|
426 |
+
)
|
427 |
+
new_sample["summary"].append(
|
428 |
+
"This is a dummy summary. HMNet will filter out the sample w/o summary!"
|
429 |
+
)
|
430 |
+
samples.append(new_sample)
|
431 |
+
# save to the gzip
|
432 |
+
file_path = os.path.join(test_path, "split_{}.jsonl.gz".format(i))
|
433 |
+
with gzip.open(file_path, "wt", encoding="utf-8") as file:
|
434 |
+
file.write(json.dumps(new_sample))
|
435 |
+
|
436 |
+
def _clean_datafolder(self, data_folder):
|
437 |
+
for name in os.listdir(data_folder):
|
438 |
+
name = os.path.join(data_folder, name)
|
439 |
+
if ".gz" in name:
|
440 |
+
os.remove(name)
|
441 |
+
|
442 |
+
def _create_datafolder(self, data_folder):
|
443 |
+
if os.path.exists(data_folder):
|
444 |
+
self._clean_datafolder(data_folder)
|
445 |
+
else:
|
446 |
+
os.makedirs(data_folder)
|
447 |
+
with open(
|
448 |
+
os.path.join(os.path.dirname(data_folder), "test_ami.json"),
|
449 |
+
"w",
|
450 |
+
encoding="utf-8",
|
451 |
+
) as file:
|
452 |
+
json.dump(
|
453 |
+
[
|
454 |
+
{
|
455 |
+
"source": {
|
456 |
+
"dataset": "../ExampleRawData/meeting_summarization/AMI_proprec/test/"
|
457 |
+
},
|
458 |
+
"task": "meeting",
|
459 |
+
"name": "ami",
|
460 |
+
}
|
461 |
+
],
|
462 |
+
file,
|
463 |
+
)
|
464 |
+
|
465 |
+
with open(
|
466 |
+
os.path.join(
|
467 |
+
os.path.dirname(os.path.dirname(data_folder)), "role_dict_ext.json"
|
468 |
+
),
|
469 |
+
"w",
|
470 |
+
) as file:
|
471 |
+
json.dump({}, file)
|
472 |
+
|
473 |
+
@classmethod
|
474 |
+
def show_capability(cls) -> None:
|
475 |
+
basic_description = cls.generate_basic_description()
|
476 |
+
more_details = (
|
477 |
+
"A HMNet model finetuned on CNN-DM dataset for summarization.\n\n"
|
478 |
+
"Strengths:\n - High performance on dialogue summarization task.\n\n"
|
479 |
+
"Weaknesses:\n - Not suitable for datasets other than dialogues.\n\n"
|
480 |
+
"Initialization arguments:\n "
|
481 |
+
" - `corpus`: Unlabelled corpus of documents.\n"
|
482 |
+
)
|
483 |
+
print(f"{basic_description} \n {'#' * 20} \n {more_details}")
|
model/multi_doc/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .multi_doc_joint_model import MultiDocJointModel
|
2 |
+
from .multi_doc_separate_model import MultiDocSeparateModel
|
model/multi_doc/base_multi_doc_model.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model.base_model import SummModel
|
2 |
+
|
3 |
+
|
4 |
+
class MultiDocSummModel(SummModel):
|
5 |
+
|
6 |
+
is_multi_document = True
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
trained_domain: str = None,
|
11 |
+
max_input_length: int = None,
|
12 |
+
max_output_length: int = None,
|
13 |
+
):
|
14 |
+
super(MultiDocSummModel, self).__init__(
|
15 |
+
trained_domain=trained_domain,
|
16 |
+
max_input_length=max_input_length,
|
17 |
+
max_output_length=max_output_length,
|
18 |
+
)
|
19 |
+
|
20 |
+
@classmethod
|
21 |
+
def assert_summ_input_type(cls, corpus, query):
|
22 |
+
if not all(
|
23 |
+
[
|
24 |
+
isinstance(ins, list) and all([isinstance(doc, str) for doc in ins])
|
25 |
+
for ins in corpus
|
26 |
+
]
|
27 |
+
):
|
28 |
+
raise TypeError(
|
29 |
+
"Multi-document summarization models summarize instances of multiple documents (`List[List[str]]`)."
|
30 |
+
)
|
31 |
+
|
32 |
+
if query is not None:
|
33 |
+
if not isinstance(query, list):
|
34 |
+
raise TypeError(
|
35 |
+
"Query-based single-document summarization requires query of `List[str]`."
|
36 |
+
)
|
37 |
+
if not all([isinstance(q, str) for q in query]):
|
38 |
+
raise TypeError(
|
39 |
+
"Query-based single-document summarization requires query of `List[str]`."
|
40 |
+
)
|
model/multi_doc/multi_doc_joint_model.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_multi_doc_model import MultiDocSummModel
|
2 |
+
from model.base_model import SummModel
|
3 |
+
from model.single_doc import TextRankModel
|
4 |
+
from typing import Union, List
|
5 |
+
|
6 |
+
|
7 |
+
class MultiDocJointModel(MultiDocSummModel):
|
8 |
+
|
9 |
+
model_name = "Multi-document joint"
|
10 |
+
is_multi_document = True
|
11 |
+
|
12 |
+
def __init__(self, model_backend: SummModel = TextRankModel, **kwargs):
|
13 |
+
super(MultiDocJointModel, self).__init__()
|
14 |
+
model = model_backend(**kwargs)
|
15 |
+
self.model = model
|
16 |
+
|
17 |
+
def summarize(
|
18 |
+
self,
|
19 |
+
corpus: Union[List[str], List[List[str]]],
|
20 |
+
query: Union[List[str], List[List[str]]] = None,
|
21 |
+
) -> List[str]:
|
22 |
+
self.assert_summ_input_type(corpus, None)
|
23 |
+
joint_corpus = []
|
24 |
+
for instance in corpus:
|
25 |
+
joint_corpus.append(" ".join(instance))
|
26 |
+
|
27 |
+
summaries = self.model.summarize(joint_corpus)
|
28 |
+
|
29 |
+
return summaries
|
30 |
+
|
31 |
+
@classmethod
|
32 |
+
def generate_basic_description(cls) -> str:
|
33 |
+
basic_description = (
|
34 |
+
"MultiDocJointModel performs multi-document summarization by"
|
35 |
+
" first concatenating all documents,"
|
36 |
+
" and then performing single-document summarization on the concatenation."
|
37 |
+
)
|
38 |
+
return basic_description
|
39 |
+
|
40 |
+
@classmethod
|
41 |
+
def show_capability(cls):
|
42 |
+
basic_description = cls.generate_basic_description()
|
43 |
+
more_details = (
|
44 |
+
"A multi-document summarization model."
|
45 |
+
" Allows for custom model backend selection at initialization."
|
46 |
+
" Concatenates each document in corpus and returns single-document summarization of joint corpus.\n"
|
47 |
+
"Strengths: \n - Allows for control of backend model.\n"
|
48 |
+
"Weaknesses: \n - Assumes all documents are equally weighted.\n"
|
49 |
+
" - May fail to extract information from certain documents.\n"
|
50 |
+
)
|
51 |
+
print(f"{basic_description}\n{'#' * 20}\n{more_details}")
|
model/multi_doc/multi_doc_separate_model.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_multi_doc_model import MultiDocSummModel
|
2 |
+
from model.base_model import SummModel
|
3 |
+
from model.single_doc import TextRankModel
|
4 |
+
from typing import Union, List
|
5 |
+
|
6 |
+
|
7 |
+
class MultiDocSeparateModel(MultiDocSummModel):
|
8 |
+
|
9 |
+
model_name = "Multi-document separate"
|
10 |
+
is_multi_document = True
|
11 |
+
|
12 |
+
def __init__(self, model_backend: SummModel = TextRankModel, **kwargs):
|
13 |
+
super(MultiDocSeparateModel, self).__init__()
|
14 |
+
model = model_backend(**kwargs)
|
15 |
+
self.model = model
|
16 |
+
|
17 |
+
def summarize(
|
18 |
+
self,
|
19 |
+
corpus: Union[List[str], List[List[str]]],
|
20 |
+
query: Union[List[str], List[List[str]]] = None,
|
21 |
+
) -> List[str]:
|
22 |
+
self.assert_summ_input_type(corpus, None)
|
23 |
+
summaries = []
|
24 |
+
for instance in corpus:
|
25 |
+
instance_summaries = self.model.summarize(instance)
|
26 |
+
summaries.append(" ".join(instance_summaries))
|
27 |
+
|
28 |
+
return summaries
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def generate_basic_description(cls) -> str:
|
32 |
+
basic_description = (
|
33 |
+
"MultiDocSeparateModel performs multi-document summarization by"
|
34 |
+
" first performing single-document summarization on each document,"
|
35 |
+
" and then concatenating the results."
|
36 |
+
)
|
37 |
+
return basic_description
|
38 |
+
|
39 |
+
@classmethod
|
40 |
+
def show_capability(cls):
|
41 |
+
basic_description = cls.generate_basic_description()
|
42 |
+
more_details = (
|
43 |
+
"A multi-document summarization model."
|
44 |
+
" Allows for custom model backend selection at initialization."
|
45 |
+
" Performs single-document summarization on each document in corpus and returns concatenated result.\n"
|
46 |
+
"Strengths: \n - Allows for control of backend model.\n"
|
47 |
+
"Weaknesses: \n - Assumes all documents are equally weighted.\n - May produce redundant information for similar documents.\n"
|
48 |
+
)
|
49 |
+
print(f"{basic_description}\n{'#' * 20}\n{more_details}")
|
model/query_based/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .bm25_model import BM25SummModel
|
2 |
+
from .tf_idf_model import TFIDFSummModel
|
model/query_based/base_query_based_model.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model.base_model import SummModel
|
2 |
+
from model.single_doc import TextRankModel
|
3 |
+
from typing import List, Union
|
4 |
+
|
5 |
+
from nltk import sent_tokenize, word_tokenize
|
6 |
+
from nltk.corpus import stopwords
|
7 |
+
from nltk.stem import PorterStemmer
|
8 |
+
|
9 |
+
|
10 |
+
class QueryBasedSummModel(SummModel):
|
11 |
+
|
12 |
+
is_query_based = True
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
trained_domain: str = None,
|
17 |
+
max_input_length: int = None,
|
18 |
+
max_output_length: int = None,
|
19 |
+
model_backend: SummModel = TextRankModel,
|
20 |
+
retrieval_ratio: float = 0.5,
|
21 |
+
preprocess: bool = True,
|
22 |
+
**kwargs,
|
23 |
+
):
|
24 |
+
super(QueryBasedSummModel, self).__init__(
|
25 |
+
trained_domain=trained_domain,
|
26 |
+
max_input_length=max_input_length,
|
27 |
+
max_output_length=max_output_length,
|
28 |
+
)
|
29 |
+
self.model = model_backend(**kwargs)
|
30 |
+
self.retrieval_ratio = retrieval_ratio
|
31 |
+
self.preprocess = preprocess
|
32 |
+
|
33 |
+
def _retrieve(self, instance: List[str], query: List[str], n_best) -> List[str]:
|
34 |
+
raise NotImplementedError()
|
35 |
+
|
36 |
+
def summarize(
|
37 |
+
self,
|
38 |
+
corpus: Union[List[str], List[List[str]]],
|
39 |
+
queries: List[str] = None,
|
40 |
+
) -> List[str]:
|
41 |
+
self.assert_summ_input_type(corpus, queries)
|
42 |
+
|
43 |
+
retrieval_output = [] # List[str]
|
44 |
+
for instance, query in zip(corpus, queries):
|
45 |
+
if isinstance(instance, str):
|
46 |
+
is_dialogue = False
|
47 |
+
instance = sent_tokenize(instance)
|
48 |
+
else:
|
49 |
+
is_dialogue = True
|
50 |
+
query = [query]
|
51 |
+
|
52 |
+
# instance & query now are List[str] for sure
|
53 |
+
if self.preprocess:
|
54 |
+
preprocessor = Preprocessor()
|
55 |
+
instance = preprocessor.preprocess(instance)
|
56 |
+
query = preprocessor.preprocess(query)
|
57 |
+
|
58 |
+
n_best = max(int(len(instance) * self.retrieval_ratio), 1)
|
59 |
+
top_n_sent = self._retrieve(instance, query, n_best)
|
60 |
+
|
61 |
+
if not is_dialogue:
|
62 |
+
top_n_sent = " ".join(top_n_sent) # str
|
63 |
+
retrieval_output.append(top_n_sent)
|
64 |
+
|
65 |
+
summaries = self.model.summarize(
|
66 |
+
retrieval_output
|
67 |
+
) # List[str] or List[List[str]]
|
68 |
+
return summaries
|
69 |
+
|
70 |
+
def generate_specific_description(self):
|
71 |
+
is_neural = self.model.is_neural & self.is_neural
|
72 |
+
is_extractive = self.model.is_extractive | self.is_extractive
|
73 |
+
model_name = "Pipeline with retriever: {}, summarizer: {}".format(
|
74 |
+
self.model_name, self.model.model_name
|
75 |
+
)
|
76 |
+
|
77 |
+
extractive_abstractive = "extractive" if is_extractive else "abstractive"
|
78 |
+
neural = "neural" if is_neural else "non-neural"
|
79 |
+
|
80 |
+
basic_description = (
|
81 |
+
f"{model_name} is a "
|
82 |
+
f"{'query-based' if self.is_query_based else ''} "
|
83 |
+
f"{extractive_abstractive}, {neural} model for summarization."
|
84 |
+
)
|
85 |
+
|
86 |
+
return basic_description
|
87 |
+
|
88 |
+
@classmethod
|
89 |
+
def assert_summ_input_type(cls, corpus, query):
|
90 |
+
if query is None:
|
91 |
+
raise TypeError(
|
92 |
+
"Query-based summarization models summarize instances of query-text pairs, however, query is missing."
|
93 |
+
)
|
94 |
+
|
95 |
+
if not isinstance(query, list):
|
96 |
+
raise TypeError(
|
97 |
+
"Query-based single-document summarization requires query of `List[str]`."
|
98 |
+
)
|
99 |
+
if not all([isinstance(q, str) for q in query]):
|
100 |
+
raise TypeError(
|
101 |
+
"Query-based single-document summarization requires query of `List[str]`."
|
102 |
+
)
|
103 |
+
|
104 |
+
@classmethod
|
105 |
+
def generate_basic_description(cls) -> str:
|
106 |
+
basic_description = (
|
107 |
+
"QueryBasedSummModel performs query-based summarization. Given a query-text pair,"
|
108 |
+
"the model will first extract the most relevant sentences in articles or turns in "
|
109 |
+
"dialogues, then use the single document summarization model to generate the summary"
|
110 |
+
)
|
111 |
+
return basic_description
|
112 |
+
|
113 |
+
@classmethod
def show_capability(cls):
    """Print the model's basic description plus strengths and weaknesses."""
    details = (
        "A query-based summarization model."
        " Allows for custom model backend selection at initialization."
        " Retrieve relevant turns and then summarize the retrieved turns\n"
        "Strengths: \n - Allows for control of backend model.\n"
        "Weaknesses: \n - Heavily depends on the performance of both retriever and summarizer.\n"
    )
    print(f"{cls.generate_basic_description()}\n{'#' * 20}\n{details}")
|
124 |
+
|
125 |
+
|
126 |
+
class Preprocessor:
    """Text normalizer applied to a corpus of sentences before retrieval.

    Optionally lower-cases each sentence, removes English stopwords, and
    Porter-stems each token, then re-joins the tokens into
    whitespace-separated sentences.
    """

    def __init__(self, remove_stopwords=True, lower_case=True, stem=False):
        # NLTK resources are loaded once at construction time.
        self.sw = stopwords.words("english")
        self.stemmer = PorterStemmer()
        self.remove_stopwords = remove_stopwords
        self.lower_case = lower_case
        self.stem = stem

    def preprocess(self, corpus: List[str]) -> List[str]:
        """Normalize every sentence in *corpus* and return the result."""
        sentences = corpus
        if self.lower_case:
            sentences = [sentence.lower() for sentence in sentences]
        # Tokenize each sentence into word tokens.
        tokens_per_sentence = [word_tokenize(sentence) for sentence in sentences]
        if self.remove_stopwords:
            tokens_per_sentence = [
                [tok for tok in sentence_tokens if tok not in self.sw]
                for sentence_tokens in tokens_per_sentence
            ]
        if self.stem:
            tokens_per_sentence = [
                [self.stemmer.stem(tok) for tok in sentence_tokens]
                for sentence_tokens in tokens_per_sentence
            ]
        return [" ".join(sentence_tokens) for sentence_tokens in tokens_per_sentence]
|
model/query_based/bm25_model.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_query_based_model import QueryBasedSummModel
|
2 |
+
from model.base_model import SummModel
|
3 |
+
from model.single_doc import TextRankModel
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from gensim.summarization.bm25 import BM25
|
7 |
+
from nltk import word_tokenize
|
8 |
+
|
9 |
+
|
10 |
+
class BM25SummModel(QueryBasedSummModel):
    """Query-based summarizer: retrieve relevant sentences with BM25, then
    summarize the retrieved text with the backend model (TextRank by default).
    """

    # static variables
    model_name = "BM25"
    is_extractive = True  # only represents the retrieval part
    is_neural = False  # only represents the retrieval part
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        super(BM25SummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )

    def _retrieve(self, instance: List[str], query: List[str], n_best):
        """Return the *n_best* sentences of *instance* most relevant to
        *query*, in their original document order.
        """
        bm25 = BM25(word_tokenize(s) for s in instance)
        # BUGFIX: BM25.get_scores expects a tokenized query (a list of word
        # tokens). Previously the raw list of query strings was passed, so
        # each whole query sentence was treated as a single "token" that
        # never matched any indexed term and every score degenerated.
        tokenized_query = [tok for q in query for tok in word_tokenize(q)]
        scores = bm25.get_scores(tokenized_query)
        best_sent_ind = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )[:n_best]
        # Restore document order so the summarizer sees coherent text.
        top_n_sent = [instance[ind] for ind in sorted(best_sent_ind)]
        return top_n_sent
|
model/query_based/tf_idf_model.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_query_based_model import QueryBasedSummModel
|
2 |
+
from model.base_model import SummModel
|
3 |
+
from model.single_doc import TextRankModel
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
+
|
9 |
+
|
10 |
+
class TFIDFSummModel(QueryBasedSummModel):
    """Query-based summarizer: retrieve relevant sentences by TF-IDF cosine
    similarity, then summarize the retrieved text with the backend model.
    """

    # static variables
    model_name = "TF-IDF"
    is_extractive = True
    is_neural = False
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        super(TFIDFSummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )
        self.vectorizer = TfidfVectorizer()

    def _retrieve(self, instance: List[str], query: List[str], n_best):
        """Return the *n_best* sentences of *instance* most similar to
        *query*, in their original document order.
        """
        instance_vectors = self.vectorizer.fit_transform(instance)
        query_vector = self.vectorizer.transform(query)

        similarities = cosine_similarity(query_vector, instance_vectors).squeeze()
        # Pick the n_best highest-scoring sentences, then restore document
        # order (consistent with BM25SummModel) so the downstream summarizer
        # receives coherent, in-order text rather than relevance-ordered
        # fragments.
        top_n_index = sorted(similarities.argsort()[::-1][0:n_best])
        top_n_sent = [instance[ind] for ind in top_n_index]  # List[str]
        return top_n_sent
|
model/single_doc/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .bart_model import BartModel
|
2 |
+
from .pegasus_model import PegasusModel
|
3 |
+
from .lexrank_model import LexRankModel
|
4 |
+
from .longformer_model import LongformerModel
|
5 |
+
from .textrank_model import TextRankModel
|
model/single_doc/bart_model.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BartForConditionalGeneration, BartTokenizer
|
2 |
+
from .base_single_doc_model import SingleDocSummModel
|
3 |
+
|
4 |
+
|
5 |
+
class BartModel(SingleDocSummModel):
    """Abstractive single-document summarizer backed by
    ``facebook/bart-large-cnn``.
    """

    # static variables
    model_name = "BART"
    is_extractive = False
    # BUGFIX: BART is a neural (transformer) model; this flag was
    # previously mislabeled as False.
    is_neural = True

    def __init__(self, device="cpu"):
        super(BartModel, self).__init__()

        self.device = device
        model_name = "facebook/bart-large-cnn"
        self.tokenizer = BartTokenizer.from_pretrained(model_name)
        # BUGFIX: move the model to the requested device. Inputs are moved
        # there in summarize(); generation fails if model and inputs live on
        # different devices.
        self.model = BartForConditionalGeneration.from_pretrained(model_name).to(
            device
        )

    def summarize(self, corpus, queries=None):
        """Summarize each document in *corpus*; returns ``List[str]``."""
        self.assert_summ_input_type(corpus, queries)

        batch = self.tokenizer(
            corpus, truncation=True, padding="longest", return_tensors="pt"
        ).to(self.device)
        encoded_summaries = self.model.generate(**batch)
        summaries = self.tokenizer.batch_decode(
            encoded_summaries, skip_special_tokens=True
        )

        return summaries

    @classmethod
    def show_capability(cls) -> None:
        # TODO zhangir: add the show capability function for BART
        print(cls.generate_basic_description())
|
model/single_doc/base_single_doc_model.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from model.base_model import SummModel
|
2 |
+
|
3 |
+
|
4 |
+
class SingleDocSummModel(SummModel):
    """Base class for single-document summarization models.

    Adds input validation: the corpus must be a list of strings, and the
    optional query (for query-based subclasses) must also be a list of
    strings.
    """

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
    ):
        super(SingleDocSummModel, self).__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
        )

    @classmethod
    def assert_summ_input_type(cls, corpus, query):
        """Raise ``TypeError`` unless *corpus* (and *query*, when given) are
        lists of strings. Both failure modes of each argument share a single
        message, so they are checked with one combined condition.
        """
        if not isinstance(corpus, list) or not all(
            isinstance(ins, str) for ins in corpus
        ):
            raise TypeError(
                "Single-document summarization requires corpus of `List[str]`."
            )

        if query is not None:
            if not isinstance(query, list) or not all(
                isinstance(q, str) for q in query
            ):
                raise TypeError(
                    "Query-based single-document summarization requires query of `List[str]`."
                )
|