boris committed
Commit dc5ae57
2 Parent(s): dcbf091 8884d40

Merge pull request #15 from borisdayma/feat-fix-lr

requirements.txt CHANGED
@@ -7,3 +7,6 @@ jax[tpu]>=0.2.16
 -e git+https://github.com/huggingface/datasets.git@master#egg=datasets
 flax
 jupyter
+# for logging
+tensorboard
+tensorflow
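
The two new dependencies back TensorBoard logging in the training script. A minimal sketch of how they would typically be used together, assuming logging goes through flax.metrics.tensorboard.SummaryWriter (which wraps tf.summary, hence tensorflow is needed alongside tensorboard); the log directory and metric names below are made up:

# hypothetical logging sketch; importing flax.metrics.tensorboard requires tensorflow
from flax.metrics.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./runs")
for step in range(3):
    writer.scalar("train/loss", 1.0 / (step + 1), step)  # tag, value, step
writer.flush()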
seq2seq/run_seq2seq_flax.py CHANGED
@@ -19,8 +19,11 @@ Script adapted from run_summarization_flax.py
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import logging as pylogging  # To avoid collision with transformers.utils.logging
 import os
+# set a common huggingface cache folder (used with datasets and transformers)
+os.environ['HF_HOME'] = '/data/huggingface/'  # required before importing transformers & datasets
+
+import logging as pylogging  # To avoid collision with transformers.utils.logging
 import sys
 import time
 from dataclasses import dataclass, field
@@ -673,12 +676,12 @@ def main():
         grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
         grads = jax.lax.pmean(grads, "batch")
         new_state = state.apply_gradients(
-            grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step
+            grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
         )
         return new_state
 
     new_state = jax.lax.cond(
-        state.step % training_args.gradient_accumulation_steps == 0,
+        (state.step + 1) % training_args.gradient_accumulation_steps == 0,
         lambda _: update_fn(),
         lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
         None,
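
The HF_HOME change only takes effect if the variable is set before transformers or datasets are imported, which is why the assignment now sits above those imports. A small sketch of that ordering requirement, reusing the cache path from the diff (datasets.config.HF_DATASETS_CACHE is used here only to show where files would land):

import os
os.environ['HF_HOME'] = '/data/huggingface/'  # must be set before the imports below

from datasets import config  # cache paths are derived from HF_HOME at import time

# datasets should now cache under /data/huggingface/datasets
print(config.HF_DATASETS_CACHE)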
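
The two one-line changes above fix off-by-one behaviour in the gradient-accumulation path: the optimizer step counter now advances only when an update is actually applied, and the update fires on the accumulation boundary counted from one, so the very first micro-batch no longer triggers an immediate update. Below is a self-contained sketch of that pattern, assuming a custom TrainState with grad_accum and optimizer_step fields as in the diff; the toy loss, batch and hyper-parameters are made up:

from typing import Any

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


class TrainState(train_state.TrainState):
    grad_accum: Any = None    # running sum of per-micro-batch gradients
    optimizer_step: int = 0   # number of optimizer updates actually applied


def train_step(state, batch, accum_steps=8):
    def loss_fn(params):
        pred = batch["x"] @ params["w"]
        return jnp.mean((pred - batch["y"]) ** 2)

    # add this micro-batch's gradients to the accumulator
    grads = jax.tree_map(
        lambda a, g: a + g, state.grad_accum, jax.grad(loss_fn)(state.params)
    )

    def update_fn():
        # average over the accumulation window, apply, reset the accumulator
        return state.apply_gradients(
            grads=jax.tree_map(lambda g: g / accum_steps, grads),
            grad_accum=jax.tree_map(jnp.zeros_like, grads),
            optimizer_step=state.optimizer_step + 1,  # the fix: count the applied update
        )

    # the fix: (step + 1) so the update fires on the accum_steps-th micro-batch,
    # not on the very first one
    return jax.lax.cond(
        (state.step + 1) % accum_steps == 0,
        lambda _: update_fn(),
        lambda _: state.replace(grad_accum=grads, step=state.step + 1),
        None,
    )


# toy usage: 8 micro-batches trigger exactly one optimizer update
params = {"w": jnp.zeros((4, 1))}
state = TrainState.create(
    apply_fn=None,
    params=params,
    tx=optax.adam(1e-3),
    grad_accum=jax.tree_map(jnp.zeros_like, params),
    optimizer_step=0,
)
batch = {"x": jnp.ones((2, 4)), "y": jnp.ones((2, 1))}
for _ in range(8):
    state = train_step(state, batch)
print(state.optimizer_step)  # 1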
seq2seq/sweep.yaml CHANGED
@@ -8,9 +8,9 @@ metric:
 parameters:
   learning_rate:
     distribution: log_uniform
-    # from exp(min) to exp(max), ie 1e-5 to 1e-3 on log scale
-    min: -11.5
-    max: -6.9
+    # from exp(min) to exp(max), ie 1e-4 to 5e-3 on log scale
+    min: -9.2
+    max: -5.3
   gradient_accumulation_steps:
     value: 8
   warmup_steps:
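
For reference, wandb's log_uniform distribution draws exp(uniform(min, max)), so min and max are natural-log values of the learning rate. A quick sanity check of the new bounds (plain Python, nothing repo-specific):

import math

# new sweep range: 1e-4 to 5e-3
print(math.log(1e-4), math.log(5e-3))  # -9.21..., -5.29...  -> rounded to -9.2 / -5.3
print(math.exp(-9.2), math.exp(-5.3))  # ~1.0e-4, ~5.0e-3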