yoarkyang committed on
Commit
84ec82e
1 Parent(s): bee3ae9

Use PyTorch to do resize and clip to reduce GPU memory usage.

Browse files

Restore `resize_and_pad` to its original version from the Molmo code, so that images are resized with PyTorch utilities unless `resize_method == "tensorflow"` is requested. Using TensorFlow to resize images leads to a large increase in GPU memory usage, which makes inference impossible on a 4090 with 24 GB of memory and defeats the purpose of this project.

Files changed (1) hide show
  1. image_preprocessing_molmo.py +20 -20
image_preprocessing_molmo.py CHANGED
@@ -85,26 +85,26 @@ def resize_and_pad(
85
  scaled_height = int(np.array(height, np.float32) * image_scale)
86
  scaled_width = int(np.array(width, np.float32) * image_scale)
87
 
88
- # if resize_method == "tensorflow":
89
- # FIXME remove
90
- import tensorflow as tf
91
- image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
92
- image = tf.image.resize(
93
- image,
94
- [scaled_height, scaled_width],
95
- method=tf.image.ResizeMethod.BILINEAR,
96
- antialias=True,
97
- )
98
- image = tf.clip_by_value(image, 0.0, 1.0)
99
- image = image.numpy()
100
- # else:
101
- # image = torch.permute(torch.from_numpy(image), [2, 0, 1])
102
- # image = convert_image_dtype(image) # resize in flaot32
103
- # image = torchvision.transforms.Resize(
104
- # [scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True
105
- # )(image)
106
- # image = torch.clip(image, 0.0, 1.0)
107
- # image = torch.permute(image, [1, 2, 0]).numpy()
108
 
109
  top_pad = (desired_height - scaled_height) // 2
110
  left_pad = (desired_width - scaled_width) // 2
 
85
  scaled_height = int(np.array(height, np.float32) * image_scale)
86
  scaled_width = int(np.array(width, np.float32) * image_scale)
87
 
88
+ if resize_method == "tensorflow":
89
+ # this option leads to large gpu mem increase likely due to how tensorflow handle memory allocation
90
+ import tensorflow as tf
91
+ image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
92
+ image = tf.image.resize(
93
+ image,
94
+ [scaled_height, scaled_width],
95
+ method=tf.image.ResizeMethod.BILINEAR,
96
+ antialias=True,
97
+ )
98
+ image = tf.clip_by_value(image, 0.0, 1.0)
99
+ image = image.numpy()
100
+ else:
101
+ image = torch.permute(torch.from_numpy(image), [2, 0, 1])
102
+ image = convert_image_dtype(image) # resize in flaot32
103
+ image = torchvision.transforms.Resize(
104
+ [scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True
105
+ )(image)
106
+ image = torch.clip(image, 0.0, 1.0)
107
+ image = torch.permute(image, [1, 2, 0]).numpy()
108
 
109
  top_pad = (desired_height - scaled_height) // 2
110
  left_pad = (desired_width - scaled_width) // 2