RayCappola commited on
Commit
81d13c0
1 Parent(s): 7226095

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -18,9 +18,6 @@ class Net(nn.Module):
18
  return self.layer(x)
19
 
20
  def get_hidden_states(encoded, model):
21
- """Push input IDs through model. Stack and sum `layers` (last four by default).
22
- Select only those subword token outputs that belong to our word of interest
23
- and average them."""
24
  with torch.no_grad():
25
  output = model(decoder_input_ids=encoded['input_ids'], output_hidden_states=True, **encoded)
26
 
@@ -63,12 +60,12 @@ sum = 0
63
 
64
  res = ''
65
 
 
 
66
  for tag in best_tags:
67
  if sum > 0.95:
68
  break
69
  sum += logits[tag.item()]
70
  # print(tag.item())
71
  new_tag = labels_articles[tag.item() + 1]
72
- res += new_tag + '\n'
73
-
74
- st.write('best tags = \n', res)
 
18
  return self.layer(x)
19
 
20
  def get_hidden_states(encoded, model):
 
 
 
21
  with torch.no_grad():
22
  output = model(decoder_input_ids=encoded['input_ids'], output_hidden_states=True, **encoded)
23
 
 
60
 
61
  res = ''
62
 
63
+ st.write('best tags:')
64
+
65
  for tag in best_tags:
66
  if sum > 0.95:
67
  break
68
  sum += logits[tag.item()]
69
  # print(tag.item())
70
  new_tag = labels_articles[tag.item() + 1]
71
+ st.write(new_tag)