"""Flask app serving an image-captioning model: a frozen ResNet-50
encoder feeding an LSTM decoder over a word-level vocabulary."""

import json

import numpy as np
import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template
from PIL import Image
from torchvision import transforms, models


class Tokenizer:
    """Word-level tokenizer with a frequency-thresholded vocabulary."""

    def __init__(self, threshold=5):
        self.word2idx = {}
        self.idx2word = {}
        self.threshold = threshold
        self.word2count = {}

    def build_vocab(self, corpus):
        print('building vocab...')
        tokens = corpus.lower().split()
        for token in tokens:
            self.word2count[token] = self.word2count.get(token, 0) + 1
        # Keep only words that occur at least `threshold` times.
        idx = 0
        for word, count in self.word2count.items():
            if count >= self.threshold:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1
        print(f'Vocab size: {len(self.idx2word)}')

    def encode(self, sentence):
        tokens = sentence.lower().split()
        return [self.word2idx.get(token, self.word2idx['<unk>']) for token in tokens]

    def decode(self, indices):
        return ' '.join(self.idx2word.get(idx, '<unk>') for idx in indices)

    def save_vocab(self, filepath):
        with open(filepath, 'w') as f:
            json.dump({'word2idx': self.word2idx, 'idx2word': self.idx2word}, f)

    def load_vocab(self, filepath):
        with open(filepath, 'r') as f:
            data = json.load(f)
        self.word2idx = data['word2idx']
        # JSON stores dict keys as strings; restore the integer indices.
        self.idx2word = {int(k): v for k, v in data['idx2word'].items()}

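# A minimal usage sketch for Tokenizer (the corpus string below is
# hypothetical; the saved vocab file is assumed to contain <start>,
# <end>, and <unk> entries):
#
#   tok = Tokenizer(threshold=1)
#   tok.build_vocab('<start> a dog runs <end> <unk>')
#   ids = tok.encode('a dog runs')
#   tok.decode(ids)  # -> 'a dog runs'

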
class CNNEncoder(nn.Module):
    def __init__(self, embed_size, num_groups=32):
        super(CNNEncoder, self).__init__()
        # Frozen ResNet-50 backbone; only the projection head below is trainable.
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad = False
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])  # drop the fc layer
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.gn = nn.GroupNorm(num_groups, embed_size)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.gn(self.linear(features))
        return features

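# Shape sketch for CNNEncoder.forward:
#   images (B, 3, 224, 224) -> backbone (B, 2048, 1, 1)
#   -> flattened (B, 2048) -> linear + GroupNorm -> (B, embed_size)

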
class RNNDecoder(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(RNNDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        # Prepend the image feature as the first step of the input sequence.
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens[:, 1:, :])
        return outputs

    def sample(self, features, states=None, max_len=20):
        # Greedy decoding. Relies on the module-level `vocab` and `device`
        # globals defined below, so it must run after they are created.
        sampled_ids = [vocab.word2idx['<start>']]
        start_token = torch.tensor([vocab.word2idx['<start>']], device=device).unsqueeze(0)
        inputs = torch.cat((features.unsqueeze(1), self.embed(start_token)), 1)
        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens[:, -1, :])
            _, predicted = outputs.max(1)
            sampled_ids.append(predicted.item())
            if predicted.item() == vocab.word2idx['<end>']:
                break
            inputs = self.embed(predicted).unsqueeze(1)
        return sampled_ids

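# Shape sketch for RNNDecoder.forward (teacher forcing):
#   features (B, embed_size), captions (B, T) token ids
#   -> embeddings after cat: (B, T+1, embed_size)
#   -> LSTM hiddens: (B, T+1, hidden_size)
#   -> outputs: (B, T, vocab_size), one distribution per caption position

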
# Name kept lowercase so it matches the class name baked into the pickled
# checkpoint loaded below with torch.load().
class im2text_model(nn.Module):
    def __init__(self, cnn_encoder, rnn_decoder):
        super(im2text_model, self).__init__()
        self.encoder = cnn_encoder
        self.decoder = rnn_decoder

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def sample(self, images, states=None):
        features = self.encoder(images)
        sampled_ids = self.decoder.sample(features, states)
        return sampled_ids


app = Flask(__name__)

# Vocabulary and model are loaded once at startup.
vocab = Tokenizer()
vocab.load_vocab('vocab_full.json')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The checkpoint is a fully pickled model, so the classes above must be
# defined under the same names when it is loaded. (On PyTorch >= 2.6,
# torch.load may additionally need weights_only=False.)
model = torch.load('im2text_model_full.pt', map_location='cpu')
model.to(device)
model.eval()

# Inputs are already float tensors in [0, 1], so only resizing and
# ImageNet normalization are applied here.
transform = transforms.Compose([
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/upload', methods=['POST'])
def upload_image():
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    # Normalize any upload (grayscale, palette GIF, RGBA PNG/WebP) to RGB;
    # checking the mode covers the cases the old format whitelist missed.
    image = Image.open(file.stream)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # HWC uint8 -> CHW float tensor in [0, 1], then resize + normalize.
    array = np.asarray(image, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(array).permute(2, 0, 1)
    tensor = transform(tensor).unsqueeze(0).to(device)

    with torch.no_grad():
        sampled_ids = model.sample(tensor)
    caption = vocab.decode(sampled_ids)
    # Strip the <start>/<end> markers that frame the generated caption.
    caption = caption.replace('<start>', '').replace('<end>', '').strip()

    return jsonify({'caption': caption})

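# Example request once the server is running (hypothetical image file):
#   curl -F "file=@example.jpg" http://localhost:5000/upload
# returns JSON of the form {"caption": "..."}.

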
if __name__ == '__main__':
    app.run(debug=True)