feat: model config not hardcoded
Former-commit-id: 8cc773f8dfaee95469a926d907c006873922e1c6
seq2seq/run_seq2seq_flax.py (+12 -5)
@@ -271,6 +271,10 @@ class TrainState(train_state.TrainState):
 
 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+        self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
+
         # we keep shared to easily load pre-trained weights
         self.shared = nn.Embed(
             self.config.vocab_size,
@@ -280,7 +284,7 @@ class CustomFlaxBartModule(FlaxBartModule):
         )
         # a separate embedding is used for the decoder
         self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
+            self.config.vocab_size_output,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
             dtype=self.dtype,
@@ -289,20 +293,23 @@ class CustomFlaxBartModule(FlaxBartModule):
 
         # the decoder has a different config
         decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
+        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
+        decoder_config.vocab_size = self.config.vocab_size_output
         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
 class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+
         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
+            self.config.vocab_size_output,
             use_bias=False,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
         )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
+        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
 
 class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
     module_class = CustomFlaxBartForConditionalGenerationModule
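The getattr(config, name, default) pattern above makes the new attributes optional: a config that already defines vocab_size_output or max_position_embeddings_decoder wins, and anything else falls back to the script's OUTPUT_VOCAB_SIZE / OUTPUT_LENGTH constants. A minimal usage sketch, assuming the classes above are in scope; the checkpoint name and the override values below are illustrative, not taken from this commit:

from transformers import BartConfig

# Hypothetical example: load a pretrained config and set the (now optional)
# custom fields before building the model. If they were left unset, setup()
# would fall back to OUTPUT_VOCAB_SIZE / OUTPUT_LENGTH via getattr.
config = BartConfig.from_pretrained("facebook/bart-large-cnn")
config.vocab_size_output = 16384              # illustrative decoder vocab size
config.max_position_embeddings_decoder = 256  # illustrative decoder length

model = CustomFlaxBartForConditionalGeneration(config)

A side effect of this fallback design is that configs saved before this change remain loadable: any missing attribute simply picks up the previously hardcoded default.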