boris commited on
Commit
0fe3e72
1 Parent(s): 85c1b8e

fix(data): minor bugs

Browse files
dalle_mini/data.py CHANGED
@@ -1,5 +1,6 @@
1
  from dataclasses import dataclass, field
2
- from datasets import load_dataset
 
3
  import numpy as np
4
  import jax
5
  import jax.numpy as jnp
@@ -25,9 +26,9 @@ class Dataset:
25
  do_train: bool = False
26
  do_eval: bool = True
27
  seed_dataset: int = None
28
- train_dataset = field(init=False)
29
- eval_dataset = field(init=False)
30
- rng_dataset = field(init=False)
31
 
32
  def __post_init__(self):
33
  # define data_files
@@ -81,26 +82,21 @@ class Dataset:
81
  # normalize text
82
  if normalize_text:
83
  text_normalizer = TextNormalizer()
 
 
 
 
 
84
  for ds in ["train_dataset", "eval_dataset"]:
85
  if hasattr(self, ds):
86
  setattr(
87
  self,
88
  ds,
89
  (
90
- getattr(self, ds).map(
91
- normalize_text,
92
- fn_kwargs={
93
- "text_column": self.text_column,
94
- "text_normalizer": text_normalizer,
95
- },
96
- )
97
  if self.streaming
98
  else getattr(self, ds).map(
99
- normalize_text,
100
- fn_kwargs={
101
- "text_column": self.text_column,
102
- "text_normalizer": text_normalizer,
103
- },
104
  num_proc=self.preprocessing_num_workers,
105
  load_from_cache_file=not self.overwrite_cache,
106
  desc="Normalizing datasets",
@@ -109,6 +105,14 @@ class Dataset:
109
  )
110
 
111
  # preprocess
 
 
 
 
 
 
 
 
112
  for ds in ["train_dataset", "eval_dataset"]:
113
  if hasattr(self, ds):
114
  setattr(
@@ -116,27 +120,13 @@ class Dataset:
116
  ds,
117
  (
118
  getattr(self, ds).map(
119
- preprocess_function,
120
  batched=True,
121
- fn_kwargs={
122
- "tokenizer": tokenizer,
123
- "text_column": self.text_column,
124
- "encoding_column": self.encoding_column,
125
- "max_source_length": self.max_source_length,
126
- "decoder_start_token_id": decoder_start_token_id,
127
- },
128
  )
129
  if self.streaming
130
  else getattr(self, ds).map(
131
- preprocess_function,
132
  batched=True,
133
- fn_kwargs={
134
- "tokenizer": tokenizer,
135
- "text_column": self.text_column,
136
- "encoding_column": self.encoding_column,
137
- "max_source_length": self.max_source_length,
138
- "decoder_start_token_id": decoder_start_token_id,
139
- },
140
  remove_columns=getattr(ds, "column_names"),
141
  num_proc=self.preprocessing_num_workers,
142
  load_from_cache_file=not self.overwrite_cache,
@@ -230,7 +220,7 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
230
  return shifted_input_ids
231
 
232
 
233
- def normalize_text(example, text_column, text_normalizer):
234
  example[text_column] = text_normalizer(example[text_column])
235
  return example
236
 
 
1
  from dataclasses import dataclass, field
2
+ from datasets import load_dataset, Dataset
3
+ from functools import partial
4
  import numpy as np
5
  import jax
6
  import jax.numpy as jnp
 
26
  do_train: bool = False
27
  do_eval: bool = True
28
  seed_dataset: int = None
29
+ train_dataset: Dataset = field(init=False)
30
+ eval_dataset: Dataset = field(init=False)
31
+ rng_dataset: jnp.ndarray = field(init=False)
32
 
33
  def __post_init__(self):
34
  # define data_files
 
82
  # normalize text
83
  if normalize_text:
84
  text_normalizer = TextNormalizer()
85
+ partial_normalize_function = partial(
86
+ normalize_function,
87
+ text_column=self.text_column,
88
+ text_normalizer=text_normalizer,
89
+ )
90
  for ds in ["train_dataset", "eval_dataset"]:
91
  if hasattr(self, ds):
92
  setattr(
93
  self,
94
  ds,
95
  (
96
+ getattr(self, ds).map(partial_normalize_function)
 
 
 
 
 
 
97
  if self.streaming
98
  else getattr(self, ds).map(
99
+ partial_normalize_function,
 
 
 
 
100
  num_proc=self.preprocessing_num_workers,
101
  load_from_cache_file=not self.overwrite_cache,
102
  desc="Normalizing datasets",
 
105
  )
106
 
107
  # preprocess
108
+ partial_preprocess_function = partial(
109
+ preprocess_function,
110
+ tokenizer=tokenizer,
111
+ text_column=self.text_column,
112
+ encoding_column=self.encoding_column,
113
+ max_source_length=self.max_source_length,
114
+ decoder_start_token_id=decoder_start_token_id,
115
+ )
116
  for ds in ["train_dataset", "eval_dataset"]:
117
  if hasattr(self, ds):
118
  setattr(
 
120
  ds,
121
  (
122
  getattr(self, ds).map(
123
+ partial_preprocess_function,
124
  batched=True,
 
 
 
 
 
 
 
125
  )
126
  if self.streaming
127
  else getattr(self, ds).map(
128
+ partial_preprocess_function,
129
  batched=True,
 
 
 
 
 
 
 
130
  remove_columns=getattr(ds, "column_names"),
131
  num_proc=self.preprocessing_num_workers,
132
  load_from_cache_file=not self.overwrite_cache,
 
220
  return shifted_input_ids
221
 
222
 
223
+ def normalize_function(example, text_column, text_normalizer):
224
  example[text_column] = text_normalizer(example[text_column])
225
  return example
226
 
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -30,6 +30,7 @@ import json
30
  import datasets
31
  from datasets import Dataset
32
  from tqdm import tqdm
 
33
 
34
  import jax
35
  import jax.numpy as jnp
@@ -411,7 +412,9 @@ def main():
411
 
412
  # Load dataset
413
  dataset = Dataset(
414
- **data_args, do_train=training_args.do_train, do_eval=training_args.do_eval
 
 
415
  )
416
 
417
  # Set up wandb run
@@ -511,7 +514,7 @@ def main():
511
  # Preprocessing the datasets.
512
  # We need to normalize and tokenize inputs and targets.
513
 
514
- dataset = dataset.preprocess(
515
  tokenizer=tokenizer,
516
  decoder_start_token_id=model.config.decoder_start_token_id,
517
  normalize_text=model.config.normalize_text,
 
30
  import datasets
31
  from datasets import Dataset
32
  from tqdm import tqdm
33
+ from dataclasses import asdict
34
 
35
  import jax
36
  import jax.numpy as jnp
 
412
 
413
  # Load dataset
414
  dataset = Dataset(
415
+ **asdict(data_args),
416
+ do_train=training_args.do_train,
417
+ do_eval=training_args.do_eval,
418
  )
419
 
420
  # Set up wandb run
 
514
  # Preprocessing the datasets.
515
  # We need to normalize and tokenize inputs and targets.
516
 
517
+ dataset.preprocess(
518
  tokenizer=tokenizer,
519
  decoder_start_token_id=model.config.decoder_start_token_id,
520
  normalize_text=model.config.normalize_text,