colt12 committed on
Commit
eb40866
·
verified ·
1 Parent(s): 1cc2181

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +19 -29
handler.py CHANGED
@@ -1,60 +1,50 @@
1
  import io
2
  from PIL import Image
3
  import torch
4
- from flask import jsonify
5
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
 
6
  # Load the model and processors
7
  model_name = "colt12/maxcushion"
8
  try:
9
  print("Loading model...")
10
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
11
  print("Model loaded successfully.")
12
-
13
  print("Loading image processor...")
14
  image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
15
  print("Image processor loaded successfully.")
16
-
17
  print("Loading tokenizer...")
18
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
19
  print("Tokenizer loaded successfully.")
20
  except Exception as e:
21
  print(f"Error loading model or processors: {str(e)}")
22
  raise
 
23
  def predict(image_bytes):
24
  # Open the image using PIL
25
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
26
 
27
  # Preprocess the image
28
  pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
 
29
  # Generate the caption
30
  with torch.no_grad():
31
  output_ids = model.generate(pixel_values, max_length=50, num_return_sequences=1)
32
  generated_caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
33
 
34
  return generated_caption
35
def _json_response(status_code, payload):
    """Build a response dict whose body is *payload* serialized as a JSON string."""
    import json

    return {'statusCode': status_code, 'body': json.dumps(payload)}


def handler(event, context):
    """Caption the image carried in ``event['body']``.

    The body is serialized with ``json.dumps`` rather than ``flask.jsonify``:
    ``jsonify`` produces a Flask response object and needs an application
    context, so it cannot be used to build a plain dict response here.

    Args:
        event: Mapping with a ``'body'`` entry holding the image, either as
            raw bytes or as a base64-encoded string.
        context: Unused; kept so the call signature stays the same.

    Returns:
        A dict with ``'statusCode'`` and a JSON-string ``'body'``:
        200 with ``{"caption": ...}`` on success, 400 with ``{"error": ...}``
        when the body is missing or is not valid base64, and 500 with
        ``{"error": ...}`` for any other failure.
    """
    import base64
    import binascii

    try:
        if 'body' not in event or not event['body']:
            return _json_response(400, {"error": "No image provided"})

        body = event['body']
        if isinstance(body, (bytes, bytearray)):
            image_bytes = bytes(body)
        else:
            # A text body cannot hold raw image bytes; UTF-8 encoding it (as
            # before) never yields a decodable image, so expect base64 text.
            try:
                image_bytes = base64.b64decode(body, validate=True)
            except (binascii.Error, ValueError):
                return _json_response(
                    400, {"error": "Image body must be base64-encoded"}
                )

        result = predict(image_bytes)
        return _json_response(200, {"caption": result})
    except Exception as e:
        # Top-level boundary: report the failure to the caller instead of crashing.
        return _json_response(500, {"error": str(e)})


# For local testing
if __name__ == "__main__":
    import base64

    # Replace with an actual image path for testing
    test_image_path = 'path/to/your/test/image.jpg'

    # Simulate an event: the handler expects the image content (base64 text),
    # not the path to the file.
    with open(test_image_path, 'rb') as image_file:
        test_event = {
            'body': base64.b64encode(image_file.read()).decode('ascii'),
        }
    print(handler(test_event, None))
 
1
  import io
2
  from PIL import Image
3
  import torch
 
4
  from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
5
+
6
  # Load the model and processors
7
  model_name = "colt12/maxcushion"
8
  try:
9
  print("Loading model...")
10
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
11
  print("Model loaded successfully.")
 
12
  print("Loading image processor...")
13
  image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
14
  print("Image processor loaded successfully.")
 
15
  print("Loading tokenizer...")
16
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
17
  print("Tokenizer loaded successfully.")
18
  except Exception as e:
19
  print(f"Error loading model or processors: {str(e)}")
20
  raise
21
+
22
  def predict(image_bytes):
23
  # Open the image using PIL
24
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
25
 
26
  # Preprocess the image
27
  pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
28
+
29
  # Generate the caption
30
  with torch.no_grad():
31
  output_ids = model.generate(pixel_values, max_length=50, num_return_sequences=1)
32
  generated_caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
33
 
34
  return generated_caption
35
+
36
def inference(inputs):
    """Caption the image supplied in *inputs*.

    Args:
        inputs: Mapping containing either ``"file"`` (an object whose
            ``read()`` returns the image data) or ``"bytes"`` (the image data
            itself). ``"file"`` takes precedence when both are present.

    Returns:
        ``{"caption": <text>}`` where the text comes from ``predict``.

    Raises:
        ValueError: If neither key is present, or the image data is empty.
    """
    # Check if the input is a file or raw bytes
    if "file" in inputs:
        image_bytes = inputs["file"].read()
    elif "bytes" in inputs:
        image_bytes = inputs["bytes"]
    else:
        raise ValueError("No valid input found. Expected 'file' or 'bytes'.")

    # Fail fast on an empty payload: otherwise the failure only surfaces later
    # as an opaque image-decoding error that does not mention the request.
    if image_bytes is None or len(image_bytes) == 0:
        raise ValueError("Image input is empty.")

    # Generate the caption and return the result
    return {"caption": predict(image_bytes)}