Spaces:
Running
Running
fix style
Browse files- tools/train/train.py +1 -1
tools/train/train.py
CHANGED
@@ -33,6 +33,7 @@ import jax.numpy as jnp
|
|
33 |
import numpy as np
|
34 |
import optax
|
35 |
import transformers
|
|
|
36 |
from datasets import Dataset
|
37 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
38 |
from flax.core.frozen_dict import freeze
|
@@ -44,7 +45,6 @@ from jax.experimental.pjit import pjit
|
|
44 |
from tqdm import tqdm
|
45 |
from transformers import HfArgumentParser
|
46 |
|
47 |
-
import wandb
|
48 |
from dalle_mini.data import Dataset
|
49 |
from dalle_mini.model import (
|
50 |
DalleBart,
|
|
|
33 |
import numpy as np
|
34 |
import optax
|
35 |
import transformers
|
36 |
+
import wandb
|
37 |
from datasets import Dataset
|
38 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
39 |
from flax.core.frozen_dict import freeze
|
|
|
45 |
from tqdm import tqdm
|
46 |
from transformers import HfArgumentParser
|
47 |
|
|
|
48 |
from dalle_mini.data import Dataset
|
49 |
from dalle_mini.model import (
|
50 |
DalleBart,
|