boris commited on
Commit
eb24dbc
1 Parent(s): 61c93f2

fix(config): set min/max for generation

Browse files
dalle_mini/model/configuration.py CHANGED
@@ -85,8 +85,6 @@ class DalleBartConfig(PretrainedConfig):
85
  self.scale_embedding = (
86
  scale_embedding # scale factor will be sqrt(d_model) if True
87
  )
88
- self.min_length = image_length + 1
89
- self.max_length = image_length + 1
90
 
91
  # remove inferred keys to prevent errors when loading config (passed as kwargs)
92
  for k in [
@@ -94,6 +92,8 @@ class DalleBartConfig(PretrainedConfig):
94
  "bos_token_id",
95
  "eos_token_id",
96
  "decoder_start_token_id",
 
 
97
  ]:
98
  kwargs.pop(k, None)
99
 
@@ -106,6 +106,8 @@ class DalleBartConfig(PretrainedConfig):
106
  decoder_start_token_id=image_vocab_size, # BOS appended to vocab
107
  forced_eos_token_id=forced_eos_token_id,
108
  tie_word_embeddings=tie_word_embeddings,
 
 
109
  **kwargs,
110
  )
111
 
 
85
  self.scale_embedding = (
86
  scale_embedding # scale factor will be sqrt(d_model) if True
87
  )
 
 
88
 
89
  # remove inferred keys to prevent errors when loading config (passed as kwargs)
90
  for k in [
 
92
  "bos_token_id",
93
  "eos_token_id",
94
  "decoder_start_token_id",
95
+ "min_length",
96
+ "max_length",
97
  ]:
98
  kwargs.pop(k, None)
99
 
 
106
  decoder_start_token_id=image_vocab_size, # BOS appended to vocab
107
  forced_eos_token_id=forced_eos_token_id,
108
  tie_word_embeddings=tie_word_embeddings,
109
+ min_length=image_length + 1,
110
+ max_length=image_length + 1,
111
  **kwargs,
112
  )
113
 
dalle_mini/model/modeling.py CHANGED
@@ -46,6 +46,8 @@ from transformers.models.bart.modeling_flax_bart import (
46
  FlaxBartForConditionalGeneration,
47
  )
48
 
 
 
49
  logger = logging.get_logger(__name__)
50
 
51
 
@@ -296,8 +298,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
296
  """
297
  Edits:
298
  - added num_params property
 
299
  """
300
 
 
 
301
  @property
302
  def num_params(self):
303
  num_params = jax.tree_map(
 
46
  FlaxBartForConditionalGeneration,
47
  )
48
 
49
+ from .configuration import DalleBartConfig
50
+
51
  logger = logging.get_logger(__name__)
52
 
53
 
 
298
  """
299
  Edits:
300
  - added num_params property
301
+ - config_class replaced to DalleBartConfig
302
  """
303
 
304
+ config_class = DalleBartConfig
305
+
306
  @property
307
  def num_params(self):
308
  num_params = jax.tree_map(