Spaces:
Running
Running
fix(config): set min/max for generation
Browse files
dalle_mini/model/configuration.py
CHANGED
@@ -85,8 +85,6 @@ class DalleBartConfig(PretrainedConfig):
|
|
85 |
self.scale_embedding = (
|
86 |
scale_embedding # scale factor will be sqrt(d_model) if True
|
87 |
)
|
88 |
-
self.min_length = image_length + 1
|
89 |
-
self.max_length = image_length + 1
|
90 |
|
91 |
# remove inferred keys to prevent errors when loading config (passed as kwargs)
|
92 |
for k in [
|
@@ -94,6 +92,8 @@ class DalleBartConfig(PretrainedConfig):
|
|
94 |
"bos_token_id",
|
95 |
"eos_token_id",
|
96 |
"decoder_start_token_id",
|
|
|
|
|
97 |
]:
|
98 |
kwargs.pop(k, None)
|
99 |
|
@@ -106,6 +106,8 @@ class DalleBartConfig(PretrainedConfig):
|
|
106 |
decoder_start_token_id=image_vocab_size, # BOS appended to vocab
|
107 |
forced_eos_token_id=forced_eos_token_id,
|
108 |
tie_word_embeddings=tie_word_embeddings,
|
|
|
|
|
109 |
**kwargs,
|
110 |
)
|
111 |
|
|
|
85 |
self.scale_embedding = (
|
86 |
scale_embedding # scale factor will be sqrt(d_model) if True
|
87 |
)
|
|
|
|
|
88 |
|
89 |
# remove inferred keys to prevent errors when loading config (passed as kwargs)
|
90 |
for k in [
|
|
|
92 |
"bos_token_id",
|
93 |
"eos_token_id",
|
94 |
"decoder_start_token_id",
|
95 |
+
"min_length",
|
96 |
+
"max_length",
|
97 |
]:
|
98 |
kwargs.pop(k, None)
|
99 |
|
|
|
106 |
decoder_start_token_id=image_vocab_size, # BOS appended to vocab
|
107 |
forced_eos_token_id=forced_eos_token_id,
|
108 |
tie_word_embeddings=tie_word_embeddings,
|
109 |
+
min_length=image_length + 1,
|
110 |
+
max_length=image_length + 1,
|
111 |
**kwargs,
|
112 |
)
|
113 |
|
dalle_mini/model/modeling.py
CHANGED
@@ -46,6 +46,8 @@ from transformers.models.bart.modeling_flax_bart import (
|
|
46 |
FlaxBartForConditionalGeneration,
|
47 |
)
|
48 |
|
|
|
|
|
49 |
logger = logging.get_logger(__name__)
|
50 |
|
51 |
|
@@ -296,8 +298,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
296 |
"""
|
297 |
Edits:
|
298 |
- added num_params property
|
|
|
299 |
"""
|
300 |
|
|
|
|
|
301 |
@property
|
302 |
def num_params(self):
|
303 |
num_params = jax.tree_map(
|
|
|
46 |
FlaxBartForConditionalGeneration,
|
47 |
)
|
48 |
|
49 |
+
from .configuration import DalleBartConfig
|
50 |
+
|
51 |
logger = logging.get_logger(__name__)
|
52 |
|
53 |
|
|
|
298 |
"""
|
299 |
Edits:
|
300 |
- added num_params property
|
301 |
+
- config_class replaced to DalleBartConfig
|
302 |
"""
|
303 |
|
304 |
+
config_class = DalleBartConfig
|
305 |
+
|
306 |
@property
|
307 |
def num_params(self):
|
308 |
num_params = jax.tree_map(
|