Files changed (2) hide show
  1. custom_st.py +24 -6
  2. modules.json +3 -3
custom_st.py CHANGED
@@ -51,8 +51,8 @@ class Transformer(nn.Module):
51
  if config_args is None:
52
  config_args = {}
53
 
 
54
  self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
55
- self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
56
 
57
  self._lora_adaptations = self.config.lora_adaptations
58
  if (
@@ -66,6 +66,10 @@ class Transformer(nn.Module):
66
  name: idx for idx, name in enumerate(self._lora_adaptations)
67
  }
68
 
 
 
 
 
69
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
70
  tokenizer_args["model_max_length"] = max_seq_length
71
  self.tokenizer = AutoTokenizer.from_pretrained(
@@ -88,17 +92,31 @@ class Transformer(nn.Module):
88
  if tokenizer_name_or_path is not None:
89
  self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
90
 
91
- def forward(
92
- self, features: Dict[str, torch.Tensor], task: Optional[str] = None
93
- ) -> Dict[str, torch.Tensor]:
94
- """Returns token_embeddings, cls_token"""
 
 
 
 
 
 
 
 
95
  if task and task not in self._lora_adaptations:
96
  raise ValueError(
97
  f"Unsupported task '{task}'. "
98
- f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
99
  f"Alternatively, don't pass the `task` argument to disable LoRA."
100
  )
101
 
 
 
 
 
 
 
102
  adapter_mask = None
103
  if task:
104
  task_id = self._adaptation_map[task]
 
51
  if config_args is None:
52
  config_args = {}
53
 
54
+
55
  self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
 
56
 
57
  self._lora_adaptations = self.config.lora_adaptations
58
  if (
 
66
  name: idx for idx, name in enumerate(self._lora_adaptations)
67
  }
68
 
69
+ self.default_task = model_args.pop('default_task', None)
70
+
71
+ self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)
72
+
73
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
74
  tokenizer_args["model_max_length"] = max_seq_length
75
  self.tokenizer = AutoTokenizer.from_pretrained(
 
92
  if tokenizer_name_or_path is not None:
93
  self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
94
 
95
+
96
+ @property
97
+ def default_task(self):
98
+ return self._default_task
99
+
100
+ @default_task.setter
101
+ def default_task(self, task: Union[None, str]):
102
+ self._validate_task(task)
103
+ self._default_task = task
104
+
105
+
106
+ def _validate_task(self, task: str):
107
  if task and task not in self._lora_adaptations:
108
  raise ValueError(
109
  f"Unsupported task '{task}'. "
110
+ f"Supported tasks are: {', '.join(self.config.lora_adaptations)}. "
111
  f"Alternatively, don't pass the `task` argument to disable LoRA."
112
  )
113
 
114
+ def forward(
115
+ self, features: Dict[str, torch.Tensor], task: Optional[str] = None
116
+ ) -> Dict[str, torch.Tensor]:
117
+ """Returns token_embeddings, cls_token"""
118
+ self._validate_task(task)
119
+ task = task or self.default_task
120
  adapter_mask = None
121
  if task:
122
  task_id = self._adaptation_map[task]
modules.json CHANGED
@@ -1,20 +1,20 @@
1
  [
2
  {
3
  "idx": 0,
4
- "name": "0",
5
  "path": "",
6
  "type": "custom_st.Transformer",
7
  "kwargs": ["task"]
8
  },
9
  {
10
  "idx": 1,
11
- "name": "1",
12
  "path": "1_Pooling",
13
  "type": "sentence_transformers.models.Pooling"
14
  },
15
  {
16
  "idx": 2,
17
- "name": "2",
18
  "path": "2_Normalize",
19
  "type": "sentence_transformers.models.Normalize"
20
  }
 
1
  [
2
  {
3
  "idx": 0,
4
+ "name": "transformer",
5
  "path": "",
6
  "type": "custom_st.Transformer",
7
  "kwargs": ["task"]
8
  },
9
  {
10
  "idx": 1,
11
+ "name": "pooler",
12
  "path": "1_Pooling",
13
  "type": "sentence_transformers.models.Pooling"
14
  },
15
  {
16
  "idx": 2,
17
+ "name": "normalizer",
18
  "path": "2_Normalize",
19
  "type": "sentence_transformers.models.Normalize"
20
  }