|
import os |
|
from pathlib import Path |
|
import numpy as np |
|
import tempfile |
|
import tensorflow as tf |
|
import mediapy |
|
from PIL import Image |
|
import cog |
|
|
|
from eval import interpolator, util |
|
|
|
_UINT8_MAX_F = float(np.iinfo(np.uint8).max) |
|
|
|
|
|
class Predictor(cog.Predictor):
    """Cog predictor wrapping the FILM ``Interpolator`` for two-frame interpolation.

    ``setup`` loads the pretrained model once. ``predict`` has two modes:

    - It returns a single PNG holding the mid-point frame when
      ``times_to_interpolate == 1``.
    - Otherwise it returns an MP4 (30 fps) of recursively interpolated frames.
    """

    # Accepted input extensions; compared case-insensitively so e.g. "IMG.JPG" is accepted.
    _INPUT_EXT = ('.png', '.jpg', '.jpeg')

    def setup(self):
        """Load the pretrained interpolator and the fixed mid-point timestep batch."""
        # ``tf`` is already imported at module level, so no local re-import is needed.
        print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
        self.interpolator = interpolator.Interpolator("pretrained_models/film_net/Style/saved_model", None)
        # A batch of one timestep at t=0.5, used by the single-mid-frame path of predict().
        self.batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)

    @classmethod
    def _validate_extensions(cls, *frames):
        """Raise ValueError unless every path in *frames* has an accepted image extension."""
        for frame in frames:
            ext = os.path.splitext(str(frame))[-1].lower()
            if ext not in cls._INPUT_EXT:
                # Explicit raise instead of `assert`, which is stripped under `python -O`.
                raise ValueError("Please provide png, jpg or jpeg images.")

    @staticmethod
    def _align_frame_sizes(frame1, frame2, work_dir):
        """Return paths to two frames that have identical pixel dimensions.

        If the inputs already match in size, they are returned unchanged. Otherwise,
        both are cropped from the top-left corner to the common (minimum) width and
        height. The cropped frames are saved as PNGs inside *work_dir*, a per-request
        temporary directory, so concurrent predictions never clobber each other's
        files and nothing is left in the working directory.
        """
        # Context managers close the underlying file handles once we are done.
        with Image.open(str(frame1)) as img1, Image.open(str(frame2)) as img2:
            if img1.size == img2.size:
                return frame1, frame2
            box = (
                0,
                0,
                min(img1.size[0], img2.size[0]),
                min(img1.size[1], img2.size[1]),
            )
            new_frame1 = work_dir / 'new_frame1.png'
            new_frame2 = work_dir / 'new_frame2.png'
            img1.crop(box).save(str(new_frame1))
            img2.crop(box).save(str(new_frame2))
        return new_frame1, new_frame2

    @cog.input(
        "frame1",
        type=Path,
        help="The first input frame",
    )
    @cog.input(
        "frame2",
        type=Path,
        help="The second input frame",
    )
    @cog.input(
        "times_to_interpolate",
        type=int,
        default=1,
        min=1,
        max=8,
        help="Controls the number of times the frame interpolator is invoked. If set to 1, the output will be the "
        "sub-frame at t=0.5; when set to > 1, the output will be the interpolation video with "
        "(2^times_to_interpolate + 1) frames, fps of 30.",
    )
    def predict(self, frame1, frame2, times_to_interpolate):
        """Interpolate between two frames.

        Args:
            frame1: Path to the first image (png/jpg/jpeg).
            frame2: Path to the second image (png/jpg/jpeg).
            times_to_interpolate: 1 for a single mid-frame; >1 for a video.

        Returns:
            Path to ``out.png`` (single mid-frame) or ``out.mp4`` (video) inside a
            fresh temporary directory.

        Raises:
            ValueError: if either input does not have a png/jpg/jpeg extension.
        """
        self._validate_extensions(frame1, frame2)

        # One temp dir per request holds any cropped inputs and the final output.
        work_dir = Path(tempfile.mkdtemp())
        frame1, frame2 = self._align_frame_sizes(frame1, frame2, work_dir)

        if times_to_interpolate == 1:
            # Add a leading batch axis of size 1 to each image before inference.
            image_batch_1 = np.expand_dims(util.read_image(str(frame1)), axis=0)
            image_batch_2 = np.expand_dims(util.read_image(str(frame2)), axis=0)
            mid_frame = self.interpolator.interpolate(image_batch_1, image_batch_2, self.batch_dt)[0]
            out_path = work_dir / "out.png"
            util.write_image(str(out_path), mid_frame)
            return out_path

        input_frames = [str(frame1), str(frame2)]
        frames = list(
            util.interpolate_recursively_from_files(
                input_frames, times_to_interpolate, self.interpolator))
        print('Interpolated frames generated, saving now as output video.')

        mediapy.set_ffmpeg(util.get_ffmpeg_path())
        out_path = work_dir / "out.mp4"
        mediapy.write_video(str(out_path), frames, fps=30)
        return out_path
|
|