boris commited on
Commit
a6252c9
1 Parent(s): 5ee6e60
dalle_mini/model/modeling.py CHANGED
@@ -30,21 +30,20 @@ from transformers.modeling_flax_outputs import (
30
  FlaxSeq2SeqLMOutput,
31
  )
32
  from transformers.modeling_flax_utils import ACT2FN
33
- from transformers.utils import logging
34
-
35
  from transformers.models.bart.modeling_flax_bart import (
36
  FlaxBartAttention,
37
- FlaxBartEncoderLayer,
38
  FlaxBartDecoderLayer,
39
- FlaxBartEncoderLayerCollection,
40
  FlaxBartDecoderLayerCollection,
41
  FlaxBartEncoder,
42
- FlaxBartDecoder,
43
- FlaxBartModule,
 
44
  FlaxBartForConditionalGenerationModule,
 
45
  FlaxBartPreTrainedModel,
46
- FlaxBartForConditionalGeneration,
47
  )
 
48
 
49
  from .configuration import DalleBartConfig
50
 
 
30
  FlaxSeq2SeqLMOutput,
31
  )
32
  from transformers.modeling_flax_utils import ACT2FN
 
 
33
  from transformers.models.bart.modeling_flax_bart import (
34
  FlaxBartAttention,
35
+ FlaxBartDecoder,
36
  FlaxBartDecoderLayer,
 
37
  FlaxBartDecoderLayerCollection,
38
  FlaxBartEncoder,
39
+ FlaxBartEncoderLayer,
40
+ FlaxBartEncoderLayerCollection,
41
+ FlaxBartForConditionalGeneration,
42
  FlaxBartForConditionalGenerationModule,
43
+ FlaxBartModule,
44
  FlaxBartPreTrainedModel,
 
45
  )
46
+ from transformers.utils import logging
47
 
48
  from .configuration import DalleBartConfig
49
 
tools/train/train.py CHANGED
@@ -43,7 +43,7 @@ from tqdm import tqdm
43
  from transformers import AutoTokenizer, HfArgumentParser
44
 
45
  from dalle_mini.data import Dataset
46
- from dalle_mini.model import DalleBartConfig, DalleBart
47
 
48
  logger = logging.getLogger(__name__)
49
 
 
43
  from transformers import AutoTokenizer, HfArgumentParser
44
 
45
  from dalle_mini.data import Dataset
46
+ from dalle_mini.model import DalleBart, DalleBartConfig
47
 
48
  logger = logging.getLogger(__name__)
49