Spaces:
Running
Running
fix: style
Browse files
tools/train/distributed_shampoo.py
CHANGED
@@ -36,13 +36,13 @@ import itertools
|
|
36 |
from typing import Any, List, NamedTuple
|
37 |
|
38 |
import chex
|
39 |
-
from flax import struct
|
40 |
import jax
|
41 |
-
from jax import lax
|
42 |
import jax.experimental.pjit as pjit
|
43 |
import jax.numpy as jnp
|
44 |
import numpy as np
|
45 |
import optax
|
|
|
|
|
46 |
|
47 |
|
48 |
# pylint:disable=no-value-for-parameter
|
|
|
36 |
from typing import Any, List, NamedTuple
|
37 |
|
38 |
import chex
|
|
|
39 |
import jax
|
|
|
40 |
import jax.experimental.pjit as pjit
|
41 |
import jax.numpy as jnp
|
42 |
import numpy as np
|
43 |
import optax
|
44 |
+
from flax import struct
|
45 |
+
from jax import lax
|
46 |
|
47 |
|
48 |
# pylint:disable=no-value-for-parameter
|