JustinLin610
commited on
Commit
•
0f35145
1
Parent(s):
c2b28d8
Update README.md
Browse filesreformat code of generator
README.md
CHANGED
@@ -41,8 +41,13 @@ After, refer the path to OFA-medium to `ckpt_dir`, and prepare an image for the
|
|
41 |
|
42 |
>>> # using the generator of fairseq version
|
43 |
>>> model = OFAModel.from_pretrained(ckpt_dir, use_cache=True)
|
44 |
-
>>> generator = sequence_generator.SequenceGenerator(
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
46 |
>>> data = {}
|
47 |
>>> data["net_input"] = {"input_ids": inputs, 'patch_images': patch_img, 'patch_masks':torch.tensor([True])}
|
48 |
>>> gen_output = generator.generate([model], data)
|
|
|
41 |
|
42 |
>>> # using the generator of fairseq version
|
43 |
>>> model = OFAModel.from_pretrained(ckpt_dir, use_cache=True)
|
44 |
+
>>> generator = sequence_generator.SequenceGenerator(
|
45 |
+
tokenizer=tokenizer,
|
46 |
+
beam_size=5,
|
47 |
+
max_len_b=16,
|
48 |
+
min_len=0,
|
49 |
+
no_repeat_ngram_size=3,
|
50 |
+
)
|
51 |
>>> data = {}
|
52 |
>>> data["net_input"] = {"input_ids": inputs, 'patch_images': patch_img, 'patch_masks':torch.tensor([True])}
|
53 |
>>> gen_output = generator.generate([model], data)
|