Merge pull request #42 from borisdayma/chore-clean
chore: cleanup repo
Former-commit-id: 9977d1dc821ac8be7eef928e1aa6e2aaacd2c5f7
- README.md +7 -3
- dev/seq2seq/run_seq2seq_flax.py +6 -52
- img/logo.png +0 -0
README.md
CHANGED

```diff
@@ -1,6 +1,6 @@
 ---
 title: Dalle Mini
-emoji:
+emoji: 🥑
 colorFrom: red
 colorTo: blue
 sdk: gradio
@@ -12,13 +12,17 @@ pinned: false
 
 _Generate images from a text prompt_
 
-
+<img src="img/logo.png" width="200">
+
+Our logo was generated with DALL-E mini by typing "logo of an armchair in the shape of an avocado".
+
+You can also create your own pictures with the demo (TODO: add link).
 
 ## Create my own images with the demo → Coming soon
 
 ## How does it work?
 
-Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)
+Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).
 
 ## Development
 
```
dev/seq2seq/run_seq2seq_flax.py
CHANGED

```diff
@@ -83,6 +83,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
 
 # Model hyperparameters, for convenience
+# TODO: the model has now it's own definition file and should be imported
 OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
 OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
 BOS_TOKEN_ID = 16384
@@ -217,7 +218,7 @@ class DataTrainingArguments:
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
     predict_with_generate: bool = field(
-        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+        default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
     )
     num_beams: Optional[int] = field(
         default=None,
@@ -376,9 +377,6 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    logger.warning(f"WARNING: eval_steps has been manually hardcoded")  # TODO: remove it later, convenient for now
-    training_args.eval_steps = 400
-
     if (
         os.path.exists(training_args.output_dir)
         and os.listdir(training_args.output_dir)
@@ -425,11 +423,10 @@ def main():
     # (the dataset will be downloaded automatically from the datasets Hub).
     #
     data_files = {}
-    logger.warning(f"WARNING: Datasets path have been manually hardcoded")  # TODO: remove it later, convenient for now
     if data_args.train_file is not None:
-        data_files["train"] =
+        data_files["train"] = data_args.train_file
     if data_args.validation_file is not None:
-        data_files["validation"] =
+        data_files["validation"] = data_args.validation_file
     if data_args.test_file is not None:
         data_files["test"] = data_args.test_file
     dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
@@ -608,35 +605,6 @@ def main():
             desc="Running tokenizer on prediction dataset",
         )
 
-    # Metric
-    #metric = load_metric("rouge")
-
-    def postprocess_text(preds, labels):
-        preds = [pred.strip() for pred in preds]
-        labels = [label.strip() for label in labels]
-
-        # rougeLSum expects newline after each sentence
-        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
-        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
-
-        return preds, labels
-
-    def compute_metrics(preds, labels):
-        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-
-        # Some simple post-processing
-        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
-
-        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
-        # Extract a few results from ROUGE
-        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
-        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
-        result["gen_len"] = np.mean(prediction_lens)
-        result = {k: round(v, 4) for k, v in result.items()}
-        return result
-
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
     rng, dropout_rng = jax.random.split(rng)
@@ -822,15 +790,8 @@ def main():
             # log metrics
             wandb_log(eval_metrics, step=global_step, prefix='eval')
 
-            # compute ROUGE metrics
-            rouge_desc = ""
-            # if data_args.predict_with_generate:
-            #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
-            #     eval_metrics.update(rouge_metrics)
-            #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
-
             # Print metrics and update progress bar
-            desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
+            desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
             epochs.write(desc)
             epochs.desc = desc
 
@@ -955,15 +916,8 @@ def main():
     pred_metrics = get_metrics(pred_metrics)
     pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
 
-    # compute ROUGE metrics
-    rouge_desc = ""
-    if data_args.predict_with_generate:
-        rouge_metrics = compute_metrics(pred_generations, pred_labels)
-        pred_metrics.update(rouge_metrics)
-        rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
-
     # Print metrics
-    desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
+    desc = f"Predict Loss: {pred_metrics['loss']})"
    logger.info(desc)
 
 
```
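With the hardcoded dataset paths removed, whatever is passed as `--train_file` / `--validation_file` now flows straight into `load_dataset`. A minimal sketch of the resulting loading path, assuming two placeholder TSV files (the real paths come from `data_args`):

```python
from datasets import load_dataset

# Placeholder paths; in the script these arrive via the
# --train_file / --validation_file CLI arguments.
data_files = {
    "train": "train.tsv",
    "validation": "validation.tsv",
}

# Mirrors the script's call: the generic "csv" loader handles
# tab-separated files once the delimiter is overridden.
dataset = load_dataset("csv", data_files=data_files, delimiter="\t")
print(dataset["train"].column_names)
```

Likewise, with the hardcoded `training_args.eval_steps = 400` gone, the evaluation interval must now be supplied on the command line (e.g. `--eval_steps 400`).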
img/logo.png
ADDED