Update kraken_model/modeling_kraken.py
Browse files
kraken_model/modeling_kraken.py
CHANGED
@@ -40,11 +40,6 @@ class KrakenForCausalLM(PreTrainedModel):
|
|
40 |
model_decision_index = self.models_indices[prediction]
|
41 |
model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
|
42 |
return model_keys[model_decision_index]
|
43 |
-
|
44 |
-
def expert_tokenizer(self, text):
    """Return the tokenizer for the expert model selected for *text*.

    Routes *text* through ``determine_model`` to pick an expert key,
    then looks up that expert's tokenizer.
    """
    # Single-expression form: select the expert key and index the
    # tokenizer table in one step.
    return self.tokenizers[self.determine_model(text)]
|
47 |
-
|
48 |
|
49 |
def generate(self, input_ids, **generate_kwargs):
|
50 |
# Tokenize the input_ids
|
@@ -75,8 +70,16 @@ class KrakenForCausalLM(PreTrainedModel):
|
|
75 |
tok_input_ids = tok.input_ids.to(current_device)
|
76 |
tok_attention_mask = tok.attention_mask.to(current_device)
|
77 |
|
78 |
-
# Generate text using the modified model
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
|
82 |
|
|
|
40 |
model_decision_index = self.models_indices[prediction]
|
41 |
model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
|
42 |
return model_keys[model_decision_index]
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
def generate(self, input_ids, **generate_kwargs):
|
45 |
# Tokenize the input_ids
|
|
|
70 |
tok_input_ids = tok.input_ids.to(current_device)
|
71 |
tok_attention_mask = tok.attention_mask.to(current_device)
|
72 |
|
73 |
+
# Generate text using the modified model
|
74 |
+
output_ids = model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
|
75 |
+
|
76 |
+
# Decode the output using the expert tokenizer
|
77 |
+
decoded_text = self.tokenizers[model_key].decode(output_ids[0], skip_special_tokens=True)
|
78 |
+
|
79 |
+
# Retokenize the decoded text using the base tokenizer for external compatibility
|
80 |
+
retokenized_ids = self.tokenizer(decoded_text, return_tensors="pt").input_ids.to(current_device)
|
81 |
+
|
82 |
+
return retokenized_ids
|
83 |
|
84 |
|
85 |
|