Spaces:
Runtime error
Runtime error
File size: 2,436 Bytes
d79c24a 686f21e 8d9306e 686f21e d79c24a 8d9306e 686f21e 8d9306e a194253 e755009 8d9306e d79c24a 1a9cd94 d79c24a 686f21e 943681e e755009 8d9306e 943681e 8d9306e 943681e 8d9306e 943681e d28411b 9eed0db 943681e 9a6a97f 943681e 9a6a97f 943681e 9a6a97f 943681e 9a6a97f 8d9306e 9a6a97f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import os, sys, shutil
import numpy as np
from PIL import Image
import jax
from transformers import ViTFeatureExtractor
from transformers import GPT2Tokenizer
from huggingface_hub import hf_hub_download
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_path)
# Main model - ViTGPT2LM
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
# create target model directory
model_dir = './models/'
os.makedirs(model_dir, exist_ok=True)
# copy config file
filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/config.json")
shutil.copyfile(filepath, os.path.join(model_dir, 'config.json'))
# copy model file
filepath = hf_hub_download("flax-community/vit-gpt2", "checkpoints/ckpt_5/flax_model.msgpack")
shutil.copyfile(filepath, os.path.join(model_dir, 'flax_model.msgpack'))
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_dir)
vit_model_name = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
gpt2_model_name = 'asi/gpt-fr-cased-small'
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
max_length = 32
num_beams = 8
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
@jax.jit
def predict_fn(pixel_values):
return flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)
def predict(image):
# batch dim is added automatically
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
pixel_values = encoder_inputs.pixel_values
# generation
generation = predict_fn(pixel_values)
token_ids = np.array(generation.sequences)[0]
caption = tokenizer.decode(token_ids)
caption = caption.replace('<s>', '').replace('</s>', '').replace('<pad>', '')
caption.replace("à l'arrière-plan", '').replace("Une photo noire et blanche d'", '').replace("en arrière-plan", '')
while ' ' in caption:
caption = caption.replace(' ', ' ')
caption = caption.strip()
return caption
def compile():
image_path = 'samples/val_000000039769.jpg'
image = Image.open(image_path)
caption = predict(image)
image.close()
def predict_dummy(image):
return 'dummy caption!'
compile()
sample_dir = './samples/'
sample_fns = tuple([f"{int(f.replace('COCO_val2014_', '').replace('.jpg', ''))}.jpg" for f in os.listdir(sample_dir) if f.startswith('COCO_val2014_')])
|