Gladiator committed on
Commit e348197
1 Parent(s): 4065f3f

cache bert models (extractive sum)

extractive_summarizer/bert_parent.py CHANGED
@@ -1,13 +1,18 @@
 from typing import List, Union
 
-import numpy as np
 import torch
+import streamlit as st
+import numpy as np
 from numpy import ndarray
 from transformers import (AlbertModel, AlbertTokenizer, BertModel,
                           BertTokenizer, DistilBertModel, DistilBertTokenizer,
                           PreTrainedModel, PreTrainedTokenizer, XLMModel,
                           XLMTokenizer, XLNetModel, XLNetTokenizer)
 
+@st.cache()
+def load_hf_model(base_model, model_name, device):
+    model = base_model.from_pretrained(model_name, output_hidden_states=True).to(device)
+    return model
 
 class BertParent(object):
     """
@@ -49,8 +54,9 @@ class BertParent(object):
         if custom_model:
             self.model = custom_model.to(self.device)
         else:
-            self.model = base_model.from_pretrained(
-                model, output_hidden_states=True).to(self.device)
+            # self.model = base_model.from_pretrained(
+            #     model, output_hidden_states=True).to(self.device)
+            self.model = load_hf_model(base_model, model, self.device)
 
         if custom_tokenizer:
             self.tokenizer = custom_tokenizer
@@ -59,6 +65,7 @@ class BertParent(object):
 
         self.model.eval()
 
+
     def tokenize_input(self, text: str) -> torch.tensor:
         """
         Tokenizes the text input.
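
For reference, a minimal sketch of how the cached loader introduced in this commit behaves inside a Streamlit app. The standalone script below, the "bert-base-uncased" model name, and the st.write call are illustrative assumptions, not part of the commit:

import streamlit as st
import torch
from transformers import BertModel

@st.cache()  # memoize on the call arguments so app reruns reuse the loaded weights
def load_hf_model(base_model, model_name, device):
    model = base_model.from_pretrained(model_name, output_hidden_states=True).to(device)
    return model

device = "cuda" if torch.cuda.is_available() else "cpu"

# First call downloads/loads the checkpoint; subsequent Streamlit reruns are served
# from the cache instead of reloading the model on every user interaction.
model = load_hf_model(BertModel, "bert-base-uncased", device)
st.write(f"Loaded {model.__class__.__name__} on {device}")

Note that on some Streamlit versions caching an unhashable return value such as a torch.nn.Module may require @st.cache(allow_output_mutation=True), and newer releases recommend @st.cache_resource for this model-loading pattern.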