wsntxxn committed on
Commit
15c646a
1 Parent(s): 9f7d1f8

Upload model

Browse files
Files changed (1) hide show
  1. hf_modeling_grounding.py +3 -33
hf_modeling_grounding.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
7
  import torch.nn.functional as F
8
  from torchaudio import transforms
9
  from transformers import PreTrainedModel, PretrainedConfig
 
10
 
11
 
12
  def sum_with_lens(features, lens):
@@ -256,37 +257,6 @@ class DotProduct(nn.Module):
256
  return score
257
 
258
 
259
- class Vocabulary(object):
260
- """Simple vocabulary wrapper."""
261
-
262
- def __init__(self):
263
- self.word2idx = {}
264
- self.idx2word = {}
265
- self.idx = 0
266
-
267
- def add_word(self, word):
268
- if not word in self.word2idx:
269
- self.word2idx[word] = self.idx
270
- self.idx2word[self.idx] = word
271
- self.idx += 1
272
-
273
- def __call__(self, word):
274
- if not word in self.word2idx:
275
- return self.word2idx["<unk>"]
276
- return self.word2idx[word]
277
-
278
- def __len__(self):
279
- return len(self.word2idx)
280
-
281
- def state_dict(self):
282
- return self.word2idx
283
-
284
- def load_state_dict(self, state_dict):
285
- self.word2idx = state_dict
286
- self.idx2word = {idx: word for word, idx in self.word2idx.items()}
287
- self.idx = len(self.word2idx)
288
-
289
-
290
  class BiEncoder(nn.Module):
291
 
292
  def __init__(self,
@@ -425,6 +395,6 @@ class Cnn8RnnW2vMeanGroundingModel(PreTrainedModel):
425
  **kwargs):
426
  model = super().from_pretrained(pretrained_model_name_or_path,
427
  *model_args, **kwargs)
428
- model.vocab_mapping = json.load(
429
- open(os.path.join(pretrained_model_name_or_path, "vocab.json")))
430
  return model
 
7
  import torch.nn.functional as F
8
  from torchaudio import transforms
9
  from transformers import PreTrainedModel, PretrainedConfig
10
+ from transformers.utils.hub import cached_file
11
 
12
 
13
  def sum_with_lens(features, lens):
 
257
  return score
258
 
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  class BiEncoder(nn.Module):
261
 
262
  def __init__(self,
 
395
  **kwargs):
396
  model = super().from_pretrained(pretrained_model_name_or_path,
397
  *model_args, **kwargs)
398
+ vocab_path = cached_file(pretrained_model_name_or_path, "vocab.json")
399
+ model.vocab_mapping = json.load(open(vocab_path))
400
  return model