boris commited on
Commit
d483294
1 Parent(s): 0a691de

fix: style

Browse files
Files changed (1) hide show
  1. tools/train/distributed_shampoo.py +2 -2
tools/train/distributed_shampoo.py CHANGED
@@ -36,13 +36,13 @@ import itertools
36
  from typing import Any, List, NamedTuple
37
 
38
  import chex
39
- from flax import struct
40
  import jax
41
- from jax import lax
42
  import jax.experimental.pjit as pjit
43
  import jax.numpy as jnp
44
  import numpy as np
45
  import optax
 
 
46
 
47
 
48
  # pylint:disable=no-value-for-parameter
 
36
  from typing import Any, List, NamedTuple
37
 
38
  import chex
 
39
  import jax
 
40
  import jax.experimental.pjit as pjit
41
  import jax.numpy as jnp
42
  import numpy as np
43
  import optax
44
+ from flax import struct
45
+ from jax import lax
46
 
47
 
48
  # pylint:disable=no-value-for-parameter