import io
from PIL import Image
import torch
from transformers import (
    VisionEncoderDecoderModel,
    VisionEncoderDecoderConfig,
    ViTImageProcessor,
    AutoTokenizer,
    AutoConfig,
)
from huggingface_hub import hf_hub_download
# Load the model and processors
model_name = "colt12/maxcushion"

try:
    print("Loading model configuration...")
    config = AutoConfig.from_pretrained(model_name)

    print("Loading model...")
    if isinstance(config, VisionEncoderDecoderConfig):
        # The checkpoint ships a vision-encoder-decoder config, so it can
        # be loaded directly. (Comparing against VisionEncoderDecoderModel
        # here would never match: config is a config object, not a model.)
        model = VisionEncoderDecoderModel.from_pretrained(model_name, config=config)
    else:
        # Otherwise, construct the encoder-decoder pair manually from its
        # base checkpoints.
        encoder_config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k")
        decoder_config = AutoConfig.from_pretrained("gpt2")
        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            "google/vit-base-patch16-224-in21k",
            "gpt2",
            encoder_config=encoder_config,
            decoder_config=decoder_config,
        )
        # model_name is a Hub repo id, not a local directory, so download
        # the checkpoint file before loading the fine-tuned weights.
        weights_path = hf_hub_download(repo_id=model_name, filename="pytorch_model.bin")
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()
    print("Model loaded successfully.")

    print("Loading image processor...")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    print("Image processor loaded successfully.")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or processors: {e}")
    raise

def predict(image_bytes):
    # Open the image with PIL and normalize it to RGB
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Preprocess the image into pixel values for the ViT encoder
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    # Generate the caption
    with torch.no_grad():
        output_ids = model.generate(pixel_values, max_length=50, num_return_sequences=1)
    generated_caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return generated_caption

def inference(inputs):
    # Accept either a file-like object or raw bytes
    if "file" in inputs:
        image_file = inputs["file"]
        image_bytes = image_file.read()
    elif "bytes" in inputs:
        image_bytes = inputs["bytes"]
    else:
        raise ValueError("No valid input found. Expected 'file' or 'bytes'.")

    # Generate the caption
    result = predict(image_bytes)

    # Return the result
    return {"caption": result}