Spaces:
Running
Running
style
Browse files- dalle_mini/model/modeling.py +6 -7
- tools/train/train.py +1 -1
dalle_mini/model/modeling.py
CHANGED
@@ -30,21 +30,20 @@ from transformers.modeling_flax_outputs import (
|
|
30 |
FlaxSeq2SeqLMOutput,
|
31 |
)
|
32 |
from transformers.modeling_flax_utils import ACT2FN
|
33 |
-
from transformers.utils import logging
|
34 |
-
|
35 |
from transformers.models.bart.modeling_flax_bart import (
|
36 |
FlaxBartAttention,
|
37 |
-
|
38 |
FlaxBartDecoderLayer,
|
39 |
-
FlaxBartEncoderLayerCollection,
|
40 |
FlaxBartDecoderLayerCollection,
|
41 |
FlaxBartEncoder,
|
42 |
-
|
43 |
-
|
|
|
44 |
FlaxBartForConditionalGenerationModule,
|
|
|
45 |
FlaxBartPreTrainedModel,
|
46 |
-
FlaxBartForConditionalGeneration,
|
47 |
)
|
|
|
48 |
|
49 |
from .configuration import DalleBartConfig
|
50 |
|
|
|
30 |
FlaxSeq2SeqLMOutput,
|
31 |
)
|
32 |
from transformers.modeling_flax_utils import ACT2FN
|
|
|
|
|
33 |
from transformers.models.bart.modeling_flax_bart import (
|
34 |
FlaxBartAttention,
|
35 |
+
FlaxBartDecoder,
|
36 |
FlaxBartDecoderLayer,
|
|
|
37 |
FlaxBartDecoderLayerCollection,
|
38 |
FlaxBartEncoder,
|
39 |
+
FlaxBartEncoderLayer,
|
40 |
+
FlaxBartEncoderLayerCollection,
|
41 |
+
FlaxBartForConditionalGeneration,
|
42 |
FlaxBartForConditionalGenerationModule,
|
43 |
+
FlaxBartModule,
|
44 |
FlaxBartPreTrainedModel,
|
|
|
45 |
)
|
46 |
+
from transformers.utils import logging
|
47 |
|
48 |
from .configuration import DalleBartConfig
|
49 |
|
tools/train/train.py
CHANGED
@@ -43,7 +43,7 @@ from tqdm import tqdm
|
|
43 |
from transformers import AutoTokenizer, HfArgumentParser
|
44 |
|
45 |
from dalle_mini.data import Dataset
|
46 |
-
from dalle_mini.model import
|
47 |
|
48 |
logger = logging.getLogger(__name__)
|
49 |
|
|
|
43 |
from transformers import AutoTokenizer, HfArgumentParser
|
44 |
|
45 |
from dalle_mini.data import Dataset
|
46 |
+
from dalle_mini.model import DalleBart, DalleBartConfig
|
47 |
|
48 |
logger = logging.getLogger(__name__)
|
49 |
|