|
import io |
|
from PIL import Image |
|
import torch |
|
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer, AutoConfig |
|
|
|
|
|
model_name = "colt12/maxcushion"

try:
    print("Loading model configuration...")
    config = AutoConfig.from_pretrained(model_name)

    print("Loading model...")
    # BUG FIX: AutoConfig.from_pretrained returns a *config* object, never a
    # VisionEncoderDecoderModel instance, so the original
    # `isinstance(config, VisionEncoderDecoderModel)` check could never be
    # True and the direct-load branch was unreachable. Detect an
    # encoder-decoder checkpoint via its declared model_type instead.
    if config.model_type == "vision-encoder-decoder":
        model = VisionEncoderDecoderModel.from_pretrained(model_name, config=config)
    else:
        # Fallback: assemble a fresh ViT encoder + GPT-2 decoder, then load
        # the fine-tuned weights on top.
        # NOTE(review): f"{model_name}/pytorch_model.bin" is treated as a
        # *local* filesystem path; a bare hub id will not resolve locally —
        # confirm the deployment layout provides this file.
        encoder_config = AutoConfig.from_pretrained("google/vit-base-patch16-224-in21k")
        decoder_config = AutoConfig.from_pretrained("gpt2")
        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            "google/vit-base-patch16-224-in21k",
            "gpt2",
            encoder_config=encoder_config,
            decoder_config=decoder_config
        )
        # map_location keeps the load CPU-safe on hosts without a GPU.
        model.load_state_dict(
            torch.load(f"{model_name}/pytorch_model.bin", map_location="cpu")
        )

    # Inference-only service: disable dropout / switch norm layers to eval mode.
    model.eval()
    print("Model loaded successfully.")

    print("Loading image processor...")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    print("Image processor loaded successfully.")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    # Surface the failure in logs, then re-raise so the service does not
    # start with a half-initialized model.
    print(f"Error loading model or processors: {str(e)}")
    raise
|
|
|
def predict(image_bytes):
    """Generate a caption for an image given as a raw byte string.

    Relies on the module-level ``image_processor``, ``model`` and
    ``tokenizer`` initialized at import time.

    Args:
        image_bytes: Encoded image data (e.g. PNG/JPEG bytes).

    Returns:
        The decoded caption string, with special tokens stripped.
    """
    # Decode the bytes into a PIL image; force RGB so the processor always
    # receives three channels regardless of the source format.
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Preprocess into the model's expected tensor layout.
    pixel_tensor = image_processor(images=pil_image, return_tensors="pt").pixel_values

    # Inference only — no gradients needed.
    with torch.no_grad():
        generated_ids = model.generate(
            pixel_tensor, max_length=50, num_return_sequences=1
        )

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
def inference(inputs):
    """Service entry point: extract image bytes from *inputs* and caption them.

    Args:
        inputs: Mapping with either a file-like object under ``"file"`` or
            raw bytes under ``"bytes"``.

    Returns:
        A dict of the form ``{"caption": <generated caption>}``.

    Raises:
        ValueError: If neither ``"file"`` nor ``"bytes"`` is present.
    """
    # Prefer the file-like form; fall back to raw bytes.
    if "file" in inputs:
        image_bytes = inputs["file"].read()
    elif "bytes" in inputs:
        image_bytes = inputs["bytes"]
    else:
        raise ValueError("No valid input found. Expected 'file' or 'bytes'.")

    return {"caption": predict(image_bytes)}