greco committed on
Commit c596ec3
1 Parent(s): ff5349b
.gitattributes CHANGED
@@ -29,3 +29,4 @@ models/bertopic_model_tokyo_olympics_tweets filter=lfs diff=lfs merge=lfs -text
  models/bertopic_model_tokyo_olympics_tweets_unclean filter=lfs diff=lfs merge=lfs -text
  models/distilbart-mnli-12-1/flax_model.msgpack filter=lfs diff=lfs merge=lfs -text
  models/distilbart-mnli-12-1/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
+ models/distilbert-base-uncased-finetuned-sst-2-english/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,3 +1,8 @@
+ # custom
+ survey_analytics.ipynb
+ embeddings_unclean.pickle
+ embeddings.pickle
+
  # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[cod]
@@ -127,3 +132,5 @@ dmypy.json

  # Pyre type checker
  .pyre/
+ survey_analytics.ipynb
+ survey_analytics.ipynb
app.py CHANGED
@@ -22,14 +22,14 @@ from bertopic import BERTopic
  # custom
  import survey_analytics_library as LIB

- st.set_page_config(layout='wide')
+ # st.set_page_config(layout='wide')

  # define data file path
  data_path = 'data' + os.sep
  # define model file path
  model_path = 'models' + os.sep

- # load all data and models
+ # load and cache all data and models to improve app performance
  @st.cache
  def read_survey_data():
  data_survey = pd.read_csv(data_path+'bfi_sample_answers.csv')
@@ -77,7 +77,7 @@ st.write('''
  st.write('\n')
  st.write('\n')

- # copy daya
+ # copy data
  df_factor_analysis = data_survey.copy()

  st.subheader('Sample Survey Data')
@@ -87,7 +87,7 @@ st.write('''
  ''')

  # split page into two columns
- # display survey questions and responses as dataframes
+ # display survey questions and answers as dataframes side by side
  col1, col2 = st.columns(2)
  with col1:
  st.write('Survey Questions')
@@ -106,16 +106,18 @@ st.write('''

  # interactive button to run statistical test to determine suitability for factor analysis
  if st.button('Run Tests'):
- # Test with the null hypothesis that the correlation matrix is an identity matrix
+ # test with the null hypothesis that the correlation matrix is an identity matrix
  bartlett_sphericity_stat, p_value = calculate_bartlett_sphericity(x=df_factor_analysis)
- # Test how predictable of a variable by others
+ # test how predictable of a variable by others
  kmo_per_variable, kmo_total = calculate_kmo(x=df_factor_analysis)
+ # print test results
  st.write(f'''
  The P Value from Bartlett\'s Test (suitability is less than 0.05): **{round(p_value, 2)}**
  The Value from KMO Test (suitability is more than 0.60): **{round(kmo_total, 2)}**
  ''')
+ # set default status to 'Failed'
  fa_stat_test = 'Failed'
-
+ # check if data passes both tests
  if p_value < 0.05 and kmo_total >= 0.6:
  fa_stat_test = 'Passed'

@@ -160,9 +162,13 @@ st.write(f'''
  Kaiser criterion is one of many guides to determine the number of factors, ultimately the decision on the number of factors to use is best decided by the user based on their use case.
  ''')

+ # interactive form for user to enter different number of factors for analysis
  with st.form('num_factor_form'):
+ # define number input
  user_num_factors = st.number_input('Enter desired number of factors:', min_value=1, max_value=10, value=6)
+ # set factors to user input
  optimal_factors = user_num_factors
+ # submit button for form to rerun app when user is ready
  submit = st.form_submit_button('Run Factor Analysis')

  st.write('\n')
@@ -175,14 +181,16 @@ fa.fit(df_factor_analysis)
  # generate factor loadings
  loads_df = pd.DataFrame(fa.loadings_, index=df_factor_analysis.columns)

- transformed_df = fa.fit_transform(df_factor_analysis)
- transformed_df = pd.DataFrame(transformed_df)
- transformed_df.columns = ['factor_'+str(col) for col in list(transformed_df)]
-
- responder_factors = transformed_df.copy()
+ # fit and transform data
+ responder_factors = fa.fit_transform(df_factor_analysis)
+ # store results as df
+ responder_factors = pd.DataFrame(responder_factors)
+ # rename columns to 'factor_n'
+ responder_factors.columns = ['factor_'+str(col) for col in list(responder_factors)]
+ # use the max loading across all factors to determine a responder's cluster
  responder_factors['cluster'] = responder_factors.apply(lambda s: s.argmax(), axis=1)

- # list of factor columns
+ # define list of factor columns
  list_of_factor_cols = [col for col in responder_factors.columns if 'factor_' in col]
  st.subheader('Fator Analysis Results')
  st.write('''
@@ -193,30 +201,36 @@ st.write('''
  st.dataframe(responder_factors.style.highlight_max(axis=1, subset=list_of_factor_cols, props='color:white; background-color:green;').format(precision=2))
  st.write('\n')

+ # count number of responders in each cluster
  fa_clusters = df_factor_analysis.copy().reset_index(drop=True)
  fa_clusters['cluster'] = responder_factors['cluster']
+ cluster_counts = fa_clusters['cluster'].value_counts().reset_index()
+ cluster_counts = cluster_counts.rename(columns={'index':'Cluster', 'cluster':'Count'})
+
+ # calculate z-scores for each cluster
  fa_z_scores = df_factor_analysis.copy().reset_index(drop=True)
  fa_z_scores = fa_z_scores.apply(zscore)
  fa_z_scores['cluster'] = responder_factors['cluster']
  fa_z_scores = fa_z_scores.groupby('cluster').mean().reset_index()
  fa_z_scores = fa_z_scores.apply(lambda x: round(x, 2))

- cm = sns.light_palette('green', as_cmap=True)
- list_of_question_cols = list(fa_z_scores.iloc[:,1:])
  st.write('''
  Aggregating the scores of the clusters gives us detail insights to the personality traits of the responders.
  The scores here have been normalised to Z-scores, a measure of how many standard deviations (SD) is the score away from the mean.
  E.g. A Z-score of 0 indicates the score is identical to the mean, while a Z-score of 1 indicates the score is 1 SD away from the mean.
  ''')
+ # define colour map for highlighting cells
+ cm = sns.light_palette('green', as_cmap=True)
+ # define list of question columns
+ list_of_question_cols = list(fa_z_scores.iloc[:,1:])
+ # display z-scores of clusters with conditional formatting
  st.dataframe(fa_z_scores.style.background_gradient(cmap=cm, subset=list_of_question_cols).format(precision=2))
  st.write('\n')

- cluster_counts = fa_clusters['cluster'].value_counts().reset_index()
- cluster_counts = cluster_counts.rename(columns={'index':'Cluster', 'cluster':'Count'})
-
  st.write('''
  Lastly, we can visualise the distribution of responders in each cluster.
  ''')
+ # plot percentage of responders in each cluster
  fig = px.pie(
  cluster_counts,
  values='Count',
@@ -243,9 +257,12 @@ st.write('''
  ''')
  st.write('\n')

- st.write('''
- Here we have 10,000 tweets from the Tokyo Olympics, going through them manually and coming up with topics would not be practical.
+ st.write(f'''
+ Here we have {len(tokyo):,} tweets from the Tokyo Olympics, going through them manually and coming up with topics would not be practical.
  ''')
+ # rename column
+ tokyo = tokyo.rename(columns={'text':'Tweet'})
+ # display raw tweets
  st.dataframe(tokyo)
  st.write('\n')
  st.write('\n')
@@ -255,6 +272,7 @@ st.write('''
  ''')
  st.write('\n')

+ # plot topics using unclean data
  fig = LIB.visualize_barchart_titles(
  topic_model=topic_model_unclean,
  subplot_titles=None,
@@ -265,23 +283,25 @@ fig = LIB.visualize_barchart_titles(
  st.plotly_chart(fig, use_container_width=True)

  st.write('''
- From the chart above, we can see that 'Topic 1' and 'Topic 3' have some words that are not as meaningful.
- For 'Topic 1', we already know that the tweets are about the Tokyo 2020 Olympics, having a topic for that isn't helpful.
- 'Tokyo', '2020', etc., we refer to these as *stopwords*, and lets remove them and regenerate the topics.
+ From the chart above, we can see that 'Topic 0' and 'Topic 5' have some words that are not as meaningful.
+ For 'Topic 0', we already know that the tweets are about the Tokyo 2020 Olympics, having a topic for that isn't helpful.
+ 'Tokyo', '2020', 'Olympics', etc., we refer to these as *stopwords*, and lets remove them and regenerate the topics.
  ''')
  st.write('\n')

+ # define manually created topic labels
  labelled_topics = [
- 'Mirabai Chanu (Indian Weightlifter)',
- 'Hockey',
- 'Barbra Banda (Zambian Football Player)',
+ 'Barbra Banda (Zambian Footballer)',
+ 'Indian Pride',
  'Sutirtha Mukherjee (Indian Table Tennis Player)',
- 'Vikas Krishan (Indian Boxer)',
+ 'Mirabai Chanu (Indian Weightlifter)',
  'Road Race',
- 'Brendon Smith (Australian Swimmer)',
+ 'Japan Volleyball',
  'Sam Kerr (Australian Footballer)',
+ 'Vikas Krishan (Indian Boxer)',
  ]

+ # plot topics using clean data with stopwords removed
  fig = LIB.visualize_barchart_titles(
  topic_model=topic_model,
  subplot_titles=labelled_topics,
@@ -290,6 +310,7 @@ fig = LIB.visualize_barchart_titles(
  height=300
  )
  st.plotly_chart(fig, use_container_width=True)
+
  st.write('''
  Now we can see that the topics have improved.
  We can make use of the top words in each topic to come up with a meaningful name.
@@ -297,37 +318,57 @@ st.write('''
  st.write('\n')
  st.write('\n')

+ # store topic info as dataframe
  topics_df = topic_model.get_topic_info()

  st.write(f'''
  Next, we can also review the total number of topics and how many tweets are in each topic, to give us a sense of importance or priority.
  There are a total of **{len(topics_df)-1}** topics, and the larget topic contains **{topics_df['Count'][1]}** tweets.
- {topics_df['Count'][0]} tweets have also been assigned as Topic -1 or outliers.
- These tweets are more unique and there are enough of them to form a topic.
+ {topics_df['Count'][0]} tweets have also been assigned as Topic -1 or outliers. These tweets are unique compared to the others and there aren't enough of them to form a topic.
+ If there are too many or too few topics, there is also the option to further tune the model to refine the results.
  ''')
+ # display topic info
  st.dataframe(topics_df)
  st.write('\n')

  st.write('''
- As there are many topics generated, we can also visualise how closely related they are to one another.
- Depending on the business case, we may want to merge these topics together or keep them separate.
- If there are too many or too few topics, there is also the option to tune the parameters of the model to refine the results.
- ''')
- fig = topic_model.visualize_topics()
- st.plotly_chart(fig, use_container_width=True)
- st.write('\n')
+ One point to also note is that the machine is not only picking out keywords in a tweet to determine its topic.
+ The model has an understanding of the relationship between words, e.g. 'Andy Murray' is related to 'tennis'.
+ For example:
+ *'Cilic vs Menezes, after more than 3 hours and millions of unconverted match points, is one of the worst quality ten…'*
+ This tweet is in the Topic 9 - Tennis without the word 'tennis' in it.

- st.write('''
- Lastly, we can inspect the individual tweets within each topic.
+ Here we can inspect the individual tweets within each topic.
  ''')

+ # define the first and last topic number
+ first_topic = topics_df['Topic'].iloc[0]
+ last_topic = topics_df['Topic'].iloc[-1]
+
+ # interative form for user to select a topic and inspect its top words and tweets
  with st.form('inspect_tweets'):
- inspect_topic = st.number_input('Enter Topic (from -1 to 63) to Inspect:', min_value=-1, max_value=63, value=8)
+ inspect_topic = st.number_input(f'Enter Topic (from {first_topic} to {last_topic}) to Inspect:', min_value=first_topic, max_value=last_topic, value=8)
  submit = st.form_submit_button('Inspect Topic')

+ # get top five words from list of tuples
  inspect_topic_words = [i[0] for i in topic_model.get_topic(inspect_topic)[:5]]

  st.write(f'''
  The top five words for Topic {inspect_topic} are: {inspect_topic_words}
  ''')
+ # display tweets from selected topic
  st.dataframe(topic_results.loc[(topic_results['Topic'] == inspect_topic)])
+ st.markdown('''---''')
+
+
+
+
+
+
+ st.header('Classifiying Text Responses and Sentiment Analysis')
+ st.write('''
+ With survey responses, sometimes as a business user, we already have an general idea of what responders are talking about and we want to categorise or classify the responses accordingly.
+ E.g.
+
+ ''')
+ st.write('\n')
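Editor's note: for readers following the app.py changes above, the block below is a minimal, self-contained sketch of the factor-analysis flow the app implements (suitability tests, factor extraction, argmax cluster assignment, per-cluster z-scores). It assumes the `factor_analyzer`, `pandas`, and `scipy` packages and an all-numeric survey dataframe; it is an illustration, not code from the commit.

```python
# Minimal sketch of the factor-analysis flow in app.py (assumed packages: pandas, scipy, factor_analyzer).
import pandas as pd
from scipy.stats import zscore
from factor_analyzer import FactorAnalyzer
from factor_analyzer.factor_analyzer import calculate_bartlett_sphericity, calculate_kmo

# survey responses: rows = responders, columns = numeric question scores
df = pd.read_csv('data/bfi_sample_answers.csv')

# suitability tests: Bartlett's sphericity (want p < 0.05) and KMO (want >= 0.6)
_, p_value = calculate_bartlett_sphericity(x=df)
_, kmo_total = calculate_kmo(x=df)
print(f"Bartlett p-value: {p_value:.4f}, KMO: {kmo_total:.2f}")

# extract factors and assign each responder to the factor with the highest loading
fa = FactorAnalyzer(n_factors=6, rotation='varimax')  # rotation choice is an assumption
factors = pd.DataFrame(fa.fit_transform(df))
factors.columns = ['factor_' + str(col) for col in factors.columns]
factors['cluster'] = factors.apply(lambda s: s.argmax(), axis=1)

# mean z-score per question within each cluster, as in the app's highlighted table
z = df.apply(zscore)
z['cluster'] = factors['cluster']
print(z.groupby('cluster').mean().round(2))
```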
data/ag_news.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/topic_results.csv CHANGED
The diff for this file is too large to render. See raw diff
 
models/bertopic_model_tokyo_olympics_tweets CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:2c6e3d236886eee8537b9259d44805006961088fe470253f4ab10f55836f87fa
- size 162358999
+ oid sha256:875ef8aa63fb501456e693214267492335a90db95b1e9cae092fd57c8ca787db
+ size 71005660
models/bertopic_model_tokyo_olympics_tweets_unclean CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:87c601adcf7e65d38fda982061a5c2e31e752126f9dd66f7abda6734a4ee2c84
- size 163381647
+ oid sha256:0de856ed231c12e7baeaff15eb3159e1a5ef7c5512b459f915f46712f6d203a3
+ size 71961846
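Editor's note: the two LFS pointers above are the saved BERTopic models the app loads. As a minimal sketch, assuming the `bertopic` package is installed, a saved model can be loaded and inspected like this; the topic number used is only an example.

```python
# Sketch: load a saved BERTopic model and inspect its topics (assumes the bertopic package).
from bertopic import BERTopic

# path matches the LFS-tracked model file updated in this commit
topic_model = BERTopic.load('models/bertopic_model_tokyo_olympics_tweets')

# one row per topic with its size; Topic -1 collects outlier documents
topics_df = topic_model.get_topic_info()
print(topics_df.head())

# top (word, score) pairs for a single topic, as used in the app's "Inspect Topic" form
print(topic_model.get_topic(8)[:5])
```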
models/distilbart-mnli-12-1/README.md ADDED
@@ -0,0 +1,59 @@
+ ---
+ datasets:
+ - mnli
+ tags:
+ - distilbart
+ - distilbart-mnli
+ pipeline_tag: zero-shot-classification
+ ---
+
+ # DistilBart-MNLI
+
+ distilbart-mnli is the distilled version of bart-large-mnli created using the **No Teacher Distillation** technique proposed for BART summarisation by Huggingface, [here](https://github.com/huggingface/transformers/tree/master/examples/seq2seq#distilbart).
+
+ We just copy alternating layers from `bart-large-mnli` and finetune more on the same data.
+
+
+ | | matched acc | mismatched acc |
+ | ------------------------------------------------------------------------------------ | ----------- | -------------- |
+ | [bart-large-mnli](https://huggingface.co/facebook/bart-large-mnli) (baseline, 12-12) | 89.9 | 90.01 |
+ | [distilbart-mnli-12-1](https://huggingface.co/valhalla/distilbart-mnli-12-1) | 87.08 | 87.5 |
+ | [distilbart-mnli-12-3](https://huggingface.co/valhalla/distilbart-mnli-12-3) | 88.1 | 88.19 |
+ | [distilbart-mnli-12-6](https://huggingface.co/valhalla/distilbart-mnli-12-6) | 89.19 | 89.01 |
+ | [distilbart-mnli-12-9](https://huggingface.co/valhalla/distilbart-mnli-12-9) | 89.56 | 89.52 |
+
+
+ This is a very simple and effective technique, as we can see the performance drop is very little.
+
+ Detailed performace trade-offs will be posted in this [sheet](https://docs.google.com/spreadsheets/d/1dQeUvAKpScLuhDV1afaPJRRAE55s2LpIzDVA5xfqxvk/edit?usp=sharing).
+
+
+ ## Fine-tuning
+ If you want to train these models yourself, clone the [distillbart-mnli repo](https://github.com/patil-suraj/distillbart-mnli) and follow the steps below
+
+ Clone and install transformers from source
+ ```bash
+ git clone https://github.com/huggingface/transformers.git
+ pip install -qqq -U ./transformers
+ ```
+
+ Download MNLI data
+ ```bash
+ python transformers/utils/download_glue_data.py --data_dir glue_data --tasks MNLI
+ ```
+
+ Create student model
+ ```bash
+ python create_student.py \
+ --teacher_model_name_or_path facebook/bart-large-mnli \
+ --student_encoder_layers 12 \
+ --student_decoder_layers 6 \
+ --save_path student-bart-mnli-12-6 \
+ ```
+
+ Start fine-tuning
+ ```bash
+ python run_glue.py args.json
+ ```
+
+ You can find the logs of these trained models in this [wandb project](https://wandb.ai/psuraj/distilbart-mnli).
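Editor's note: the model card above covers fine-tuning but not inference. As a minimal sketch, assuming the `transformers` package (with a PyTorch backend) is installed, the local copy of the model added in this commit can be used for zero-shot classification like this; the example text and candidate labels are made up.

```python
# Sketch: zero-shot classification with the local distilbart-mnli-12-1 copy (assumes transformers + torch).
from transformers import pipeline

classifier = pipeline(
    'zero-shot-classification',
    model='models/distilbart-mnli-12-1',  # local model directory added in this commit
)

result = classifier(
    'The vaccination process was quick and the staff were helpful.',  # hypothetical survey response
    candidate_labels=['healthcare', 'sports', 'politics'],            # hypothetical labels
)
# highest-scoring label first
print(result['labels'][0], result['scores'][0])
```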
models/distilbart-mnli-12-1/config.json ADDED
@@ -0,0 +1,56 @@
+ {
+ "_num_labels": 3,
+ "activation_dropout": 0.0,
+ "activation_function": "gelu",
+ "add_bias_logits": false,
+ "add_final_layer_norm": false,
+ "architectures": [
+ "BartForSequenceClassification"
+ ],
+ "attention_dropout": 0.1,
+ "bos_token_id": 0,
+ "classif_dropout": 0.0,
+ "classifier_dropout": 0.0,
+ "d_model": 1024,
+ "decoder_attention_heads": 16,
+ "decoder_ffn_dim": 4096,
+ "decoder_layerdrop": 0.0,
+ "decoder_layers": 1,
+ "decoder_start_token_id": 2,
+ "dropout": 0.1,
+ "encoder_attention_heads": 16,
+ "encoder_ffn_dim": 4096,
+ "encoder_layerdrop": 0.0,
+ "encoder_layers": 12,
+ "eos_token_id": 2,
+ "extra_pos_embeddings": 2,
+ "finetuning_task": "mnli",
+ "force_bos_token_to_be_generated": false,
+ "forced_eos_token_id": 2,
+ "gradient_checkpointing": false,
+ "id2label": {
+ "0": "contradiction",
+ "1": "neutral",
+ "2": "entailment"
+ },
+ "init_std": 0.02,
+ "is_encoder_decoder": true,
+ "label2id": {
+ "contradiction": 0,
+ "entailment": 2,
+ "neutral": 1
+ },
+ "max_position_embeddings": 1024,
+ "model_type": "bart",
+ "normalize_before": false,
+ "normalize_embedding": true,
+ "num_hidden_layers": 12,
+ "output_past": false,
+ "pad_token_id": 1,
+ "scale_embedding": false,
+ "static_position_embeddings": false,
+ "total_flos": 153130534133111808,
+ "transformers_version": "4.7.0.dev0",
+ "use_cache": true,
+ "vocab_size": 50265
+ }
models/distilbart-mnli-12-1/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
models/distilbart-mnli-12-1/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa79ff59084a5036b07a9cffeaa1b1b7c1aa5edeb1885416a734c001a09aa046
+ size 890410947
models/distilbart-mnli-12-1/special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "sep_token": {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": {"content": "<pad>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "cls_token": {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true}}
models/distilbart-mnli-12-1/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"model_max_length": 1024}
models/distilbart-mnli-12-1/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
models/distilbert-base-uncased-finetuned-sst-2-english/README.md ADDED
@@ -0,0 +1,31 @@
+ ---
+ language: en
+ license: apache-2.0
+ datasets:
+ - sst-2
+ ---
+
+ # DistilBERT base uncased finetuned SST-2
+
+ This model is a fine-tune checkpoint of [DistilBERT-base-uncased](https://huggingface.co/distilbert-base-uncased), fine-tuned on SST-2.
+ This model reaches an accuracy of 91.3 on the dev set (for comparison, Bert bert-base-uncased version reaches an accuracy of 92.7).
+
+ For more details about DistilBERT, we encourage users to check out [this model card](https://huggingface.co/distilbert-base-uncased).
+
+ # Fine-tuning hyper-parameters
+
+ - learning_rate = 1e-5
+ - batch_size = 32
+ - warmup = 600
+ - max_seq_length = 128
+ - num_train_epochs = 3.0
+
+ # Bias
+
+ Based on a few experimentations, we observed that this model could produce biased predictions that target underrepresented populations.
+
+ For instance, for sentences like `This film was filmed in COUNTRY`, this binary classification model will give radically different probabilities for the positive label depending on the country (0.89 if the country is France, but 0.08 if the country is Afghanistan) when nothing in the input indicates such a strong semantic shift. In this [colab](https://colab.research.google.com/gist/ageron/fb2f64fb145b4bc7c49efc97e5f114d3/biasmap.ipynb), [Aurélien Géron](https://twitter.com/aureliengeron) made an interesting map plotting these probabilities for each country.
+
+ <img src="https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english/resolve/main/map.jpeg" alt="Map of positive probabilities per country." width="500"/>
+
+ We strongly advise users to thoroughly probe these aspects on their use-cases in order to evaluate the risks of this model. We recommend looking at the following bias evaluation datasets as a place to start: [WinoBias](https://huggingface.co/datasets/wino_bias), [WinoGender](https://huggingface.co/datasets/super_glue), [Stereoset](https://huggingface.co/datasets/stereoset).
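Editor's note: as a minimal usage sketch, assuming the `transformers` package (with a PyTorch backend) is installed, the local copy of this model added in the commit can be used for sentiment analysis like this; the example sentence is made up, and the POSITIVE/NEGATIVE labels come from the config.json shown below.

```python
# Sketch: sentiment analysis with the local DistilBERT SST-2 copy (assumes transformers + torch).
from transformers import pipeline

sentiment = pipeline(
    'sentiment-analysis',
    model='models/distilbert-base-uncased-finetuned-sst-2-english',  # local model directory added in this commit
)

print(sentiment('The survey was easy to complete and the questions were clear.'))
# expected shape of output: [{'label': 'POSITIVE', 'score': ...}]
```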
models/distilbert-base-uncased-finetuned-sst-2-english/config.json ADDED
@@ -0,0 +1,31 @@
+ {
+ "activation": "gelu",
+ "architectures": [
+ "DistilBertForSequenceClassification"
+ ],
+ "attention_dropout": 0.1,
+ "dim": 768,
+ "dropout": 0.1,
+ "finetuning_task": "sst-2",
+ "hidden_dim": 3072,
+ "id2label": {
+ "0": "NEGATIVE",
+ "1": "POSITIVE"
+ },
+ "initializer_range": 0.02,
+ "label2id": {
+ "NEGATIVE": 0,
+ "POSITIVE": 1
+ },
+ "max_position_embeddings": 512,
+ "model_type": "distilbert",
+ "n_heads": 12,
+ "n_layers": 6,
+ "output_past": true,
+ "pad_token_id": 0,
+ "qa_dropout": 0.1,
+ "seq_classif_dropout": 0.2,
+ "sinusoidal_pos_embds": false,
+ "tie_weights_": true,
+ "vocab_size": 30522
+ }
models/distilbert-base-uncased-finetuned-sst-2-english/map.jpeg ADDED
models/distilbert-base-uncased-finetuned-sst-2-english/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60554cbd7781b09d87f1ececbea8c064b94e49a7f03fd88e8775bfe6cc3d9f88
+ size 267844284
models/distilbert-base-uncased-finetuned-sst-2-english/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"model_max_length": 512, "do_lower_case": true}
models/distilbert-base-uncased-finetuned-sst-2-english/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
survey_analytics.ipynb DELETED
The diff for this file is too large to render. See raw diff