jaketae commited on
Commit
f1d50b1
1 Parent(s): 696f287

feature: add streamlit backbone

Browse files
Files changed (8) hide show
  1. .gitignore +135 -0
  2. app.py +13 -0
  3. image2text.py +12 -0
  4. koclip/__init__.py +1 -0
  5. koclip/config.py +109 -0
  6. koclip/model.py +471 -0
  7. text2image.py +14 -0
  8. utils.py +21 -0
.gitignore ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # macOS
2
+ .DS_Store
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ pip-wheel-metadata/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # pyenv
80
+ # For a library or package, you might want to ignore these files since the code is
81
+ # intended to run in multiple environments; otherwise, check them in:
82
+ # .python-version
83
+
84
+ # pipenv
85
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
86
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
87
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
88
+ # install all needed dependencies.
89
+ Pipfile
90
+ Pipfile.lock
91
+
92
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
93
+ __pypackages__/
94
+
95
+ # Environments
96
+ .env
97
+ .venv
98
+ env/
99
+ venv/
100
+ ENV/
101
+ env.bak/
102
+ venv.bak/
103
+
104
+ # Spyder project settings
105
+ .spyderproject
106
+ .spyproject
107
+
108
+ # Intellij project settings
109
+ .idea/
110
+ .iml
111
+
112
+ # Rope project settings
113
+ .ropeproject
114
+
115
+ # mkdocs documentation
116
+ /site
117
+
118
+ # mypy
119
+ .mypy_cache/
120
+ .dmypy.json
121
+ dmypy.json
122
+
123
+ # Pyre type checker
124
+ .pyre/
125
+
126
+ # pytype static type analyzer
127
+ .pytype/
128
+
129
+ # Cython debug symbols
130
+ cython_debug/
131
+
132
+ # static files generated from Django application
133
+ media
134
+ staticfiles
135
+ /tags
app.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import image2text
4
+ import text2image
5
+
6
+
7
+ PAGES = {"Text to Image": text2image, "Image to Text": image2text}
8
+
9
+ st.sidebar.title("Navigation")
10
+ model = st.sidebar.radio("Model", ["koclip/koclip", "koclip/koclip-large"])
11
+ page = st.sidebar.radio("Go to", list(PAGES.keys()))
12
+
13
+ PAGES[page].app(model)
image2text.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from utils import load_model
4
+
5
+
6
+ def app(model_name):
7
+ model, processor = load_model(model_name)
8
+
9
+ st.title("Text to Image Retrieval")
10
+ st.markdown("""
11
+ Some text goes in here.
12
+ """)
koclip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import FlaxHybridCLIP
koclip/config.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+
9
+ class HybridCLIPConfig(PretrainedConfig):
10
+ r"""
11
+ :class:`HybridCLIPConfig` is the configuration class to store the configuration of a
12
+ :class:`~HybridCLIPModel`. It is used to instantiate HybridCLIPModel model according to the specified arguments,
13
+ defining the text model and vision model configs.
14
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
15
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
16
+ Args:
17
+ text_config_dict (:obj:`dict`):
18
+ Dictionary of configuration options that defines text model config.
19
+ vision_config_dict (:obj:`dict`):
20
+ Dictionary of configuration options that defines vison model config.
21
+ projection_dim (:obj:`int`, `optional`, defaults to 512):
22
+ Dimentionality of text and vision projection layers.
23
+ kwargs (`optional`):
24
+ Dictionary of keyword arguments.
25
+ Examples::
26
+ >>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
27
+ >>> # Initializing a BERT and CLIP configuration
28
+ >>> config_text = BertConfig()
29
+ >>> config_vision = CLIPConfig()
30
+ >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
31
+ >>> # Initializing a BERT and CLIPVision model
32
+ >>> model = EncoderDecoderModel(config=config)
33
+ >>> # Accessing the model configuration
34
+ >>> config_text = model.config.text_config
35
+ >>> config_vision = model.config.vision_config
36
+ >>> # Saving the model, including its configuration
37
+ >>> model.save_pretrained('my-model')
38
+ >>> # loading model and config from pretrained folder
39
+ >>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
40
+ >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
41
+ """
42
+
43
+ model_type = "hybrid-clip"
44
+ is_composition = True
45
+
46
+ def __init__(self, projection_dim=512, **kwargs):
47
+ super().__init__(**kwargs)
48
+
49
+ if "text_config" not in kwargs:
50
+ raise ValueError("`text_config` can not be `None`.")
51
+
52
+ if "vision_config" not in kwargs:
53
+ raise ValueError("`vision_config` can not be `None`.")
54
+
55
+ text_config = kwargs.pop("text_config")
56
+ vision_config = kwargs.pop("vision_config")
57
+
58
+ text_model_type = text_config.pop("model_type")
59
+ vision_model_type = vision_config.pop("model_type")
60
+
61
+ from transformers import AutoConfig
62
+
63
+ self.text_config = AutoConfig.for_model(text_model_type, **text_config)
64
+
65
+ if vision_model_type == "clip":
66
+ self.vision_config = AutoConfig.for_model(
67
+ vision_model_type, **vision_config
68
+ ).vision_config
69
+ elif vision_model_type == "clip_vision_model":
70
+ from transformers import CLIPVisionConfig
71
+
72
+ self.vision_config = CLIPVisionConfig(**vision_config)
73
+ else:
74
+ self.vision_config = AutoConfig.for_model(
75
+ vision_model_type, **vision_config
76
+ )
77
+
78
+ self.projection_dim = projection_dim
79
+ self.initializer_factor = 1.0
80
+
81
+ @classmethod
82
+ def from_text_vision_configs(
83
+ cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs
84
+ ):
85
+ r"""
86
+ Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
87
+ vision model configuration.
88
+ Returns:
89
+ :class:`HybridCLIPConfig`: An instance of a configuration object
90
+ """
91
+
92
+ return cls(
93
+ text_config=text_config.to_dict(),
94
+ vision_config=vision_config.to_dict(),
95
+ **kwargs
96
+ )
97
+
98
+ def to_dict(self):
99
+ """
100
+ Serializes this instance to a Python dictionary. Override the default
101
+ :meth:`~transformers.PretrainedConfig.to_dict`.
102
+ Returns:
103
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
104
+ """
105
+ output = copy.deepcopy(self.__dict__)
106
+ output["text_config"] = self.text_config.to_dict()
107
+ output["vision_config"] = self.vision_config.to_dict()
108
+ output["model_type"] = self.__class__.model_type
109
+ return output
koclip/model.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from flax.core.frozen_dict import FrozenDict
22
+ from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
23
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
24
+ from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
25
+ from transformers.utils import logging
26
+
27
+ from .config import HybridCLIPConfig
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class FlaxHybridCLIPModule(nn.Module):
33
+ config: HybridCLIPConfig
34
+ dtype: jnp.dtype = jnp.float32
35
+
36
+ def setup(self):
37
+ text_config = self.config.text_config
38
+ vision_config = self.config.vision_config
39
+
40
+ self.projection_dim = self.config.projection_dim
41
+ self.text_embed_dim = text_config.hidden_size
42
+ self.vision_embed_dim = vision_config.hidden_size
43
+
44
+ text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
45
+ vision_module = FLAX_MODEL_MAPPING.get(
46
+ self.config.vision_config.__class__, FlaxCLIPVisionModel
47
+ ).module_class
48
+
49
+ self.text_model = text_module(text_config, dtype=self.dtype)
50
+ self.vision_model = vision_module(vision_config, dtype=self.dtype)
51
+
52
+ self.visual_projection = nn.Dense(
53
+ self.projection_dim,
54
+ dtype=self.dtype,
55
+ kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
56
+ use_bias=False,
57
+ )
58
+ self.text_projection = nn.Dense(
59
+ self.projection_dim,
60
+ dtype=self.dtype,
61
+ kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
62
+ use_bias=False,
63
+ )
64
+ self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
65
+
66
+ def __call__(
67
+ self,
68
+ input_ids=None,
69
+ pixel_values=None,
70
+ attention_mask=None,
71
+ position_ids=None,
72
+ token_type_ids=None,
73
+ deterministic: bool = True,
74
+ output_attentions=None,
75
+ output_hidden_states=None,
76
+ return_dict=None,
77
+ ):
78
+ return_dict = (
79
+ return_dict if return_dict is not None else self.config.return_dict
80
+ )
81
+
82
+ vision_outputs = self.vision_model(
83
+ pixel_values=pixel_values,
84
+ deterministic=deterministic,
85
+ output_attentions=output_attentions,
86
+ output_hidden_states=output_hidden_states,
87
+ return_dict=return_dict,
88
+ )
89
+
90
+ text_outputs = self.text_model(
91
+ input_ids=input_ids,
92
+ attention_mask=attention_mask,
93
+ token_type_ids=token_type_ids,
94
+ position_ids=position_ids,
95
+ deterministic=deterministic,
96
+ output_attentions=output_attentions,
97
+ output_hidden_states=output_hidden_states,
98
+ return_dict=return_dict,
99
+ )
100
+
101
+ image_embeds = vision_outputs[1]
102
+ image_embeds = self.visual_projection(image_embeds)
103
+
104
+ text_embeds = text_outputs[1]
105
+ text_embeds = self.text_projection(text_embeds)
106
+
107
+ # normalized features
108
+ image_embeds = image_embeds / jnp.linalg.norm(
109
+ image_embeds, axis=-1, keepdims=True
110
+ )
111
+ text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
112
+
113
+ # cosine similarity as logits
114
+ logit_scale = jnp.exp(self.logit_scale)
115
+ logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
116
+ logits_per_image = logits_per_text.T
117
+
118
+ if not return_dict:
119
+ return (
120
+ logits_per_image,
121
+ logits_per_text,
122
+ text_embeds,
123
+ image_embeds,
124
+ text_outputs,
125
+ vision_outputs,
126
+ )
127
+
128
+ return FlaxCLIPOutput(
129
+ logits_per_image=logits_per_image,
130
+ logits_per_text=logits_per_text,
131
+ text_embeds=text_embeds,
132
+ image_embeds=image_embeds,
133
+ text_model_output=text_outputs,
134
+ vision_model_output=vision_outputs,
135
+ )
136
+
137
+
138
+ class FlaxHybridCLIP(FlaxPreTrainedModel):
139
+ config_class = HybridCLIPConfig
140
+ module_class = FlaxHybridCLIPModule
141
+
142
+ def __init__(
143
+ self,
144
+ config: HybridCLIPConfig,
145
+ input_shape: Optional[Tuple] = None,
146
+ seed: int = 0,
147
+ dtype: jnp.dtype = jnp.float32,
148
+ **kwargs,
149
+ ):
150
+ if input_shape is None:
151
+ input_shape = (
152
+ (1, 1),
153
+ (
154
+ 1,
155
+ config.vision_config.image_size,
156
+ config.vision_config.image_size,
157
+ 3,
158
+ ),
159
+ )
160
+
161
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
162
+ super().__init__(
163
+ config, module, input_shape=input_shape, seed=seed, dtype=dtype
164
+ )
165
+
166
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
167
+ # init input tensor
168
+ input_ids = jnp.zeros(input_shape[0], dtype="i4")
169
+ position_ids = jnp.broadcast_to(
170
+ jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]
171
+ )
172
+ token_type_ids = jnp.ones_like(input_ids)
173
+ attention_mask = jnp.ones_like(input_ids)
174
+
175
+ pixel_values = jax.random.normal(rng, input_shape[1])
176
+
177
+ params_rng, dropout_rng = jax.random.split(rng)
178
+ rngs = {"params": params_rng, "dropout": dropout_rng}
179
+
180
+ return self.module.init(
181
+ rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids
182
+ )["params"]
183
+
184
+ def __call__(
185
+ self,
186
+ input_ids,
187
+ pixel_values,
188
+ attention_mask=None,
189
+ position_ids=None,
190
+ token_type_ids=None,
191
+ params: dict = None,
192
+ dropout_rng: jax.random.PRNGKey = None,
193
+ train: bool = False,
194
+ output_attentions: Optional[bool] = None,
195
+ output_hidden_states: Optional[bool] = None,
196
+ return_dict: Optional[bool] = None,
197
+ ):
198
+ output_attentions = (
199
+ output_attentions
200
+ if output_attentions is not None
201
+ else self.config.output_attentions
202
+ )
203
+ output_hidden_states = (
204
+ output_hidden_states
205
+ if output_hidden_states is not None
206
+ else self.config.output_hidden_states
207
+ )
208
+ return_dict = (
209
+ return_dict if return_dict is not None else self.config.return_dict
210
+ )
211
+
212
+ if position_ids is None:
213
+ position_ids = jnp.broadcast_to(
214
+ jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
215
+ )
216
+
217
+ if token_type_ids is None:
218
+ token_type_ids = jnp.zeros_like(input_ids)
219
+
220
+ if attention_mask is None:
221
+ attention_mask = jnp.ones_like(input_ids)
222
+
223
+ # Handle any PRNG if needed
224
+ rngs = {}
225
+ if dropout_rng is not None:
226
+ rngs["dropout"] = dropout_rng
227
+
228
+ return self.module.apply(
229
+ {"params": params or self.params},
230
+ jnp.array(input_ids, dtype="i4"),
231
+ jnp.array(pixel_values, dtype=jnp.float32),
232
+ jnp.array(attention_mask, dtype="i4"),
233
+ jnp.array(position_ids, dtype="i4"),
234
+ jnp.array(token_type_ids, dtype="i4"),
235
+ not train,
236
+ output_attentions,
237
+ output_hidden_states,
238
+ return_dict,
239
+ rngs=rngs,
240
+ )
241
+
242
+ def get_text_features(
243
+ self,
244
+ input_ids,
245
+ attention_mask=None,
246
+ position_ids=None,
247
+ token_type_ids=None,
248
+ dropout_rng: jax.random.PRNGKey = None,
249
+ train=False,
250
+ ):
251
+ r"""
252
+ Args:
253
+ input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
254
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
255
+ provide it.
256
+ Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
257
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
258
+ for details.
259
+ `What are input IDs? <../glossary.html#input-ids>`__
260
+ Returns:
261
+ text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
262
+ obtained by applying the projection layer to the pooled output of text model.
263
+ """
264
+ if position_ids is None:
265
+ position_ids = jnp.broadcast_to(
266
+ jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
267
+ )
268
+
269
+ if token_type_ids is None:
270
+ token_type_ids = jnp.zeros_like(input_ids)
271
+
272
+ if attention_mask is None:
273
+ attention_mask = jnp.ones_like(input_ids)
274
+
275
+ # Handle any PRNG if needed
276
+ rngs = {}
277
+ if dropout_rng is not None:
278
+ rngs["dropout"] = dropout_rng
279
+
280
+ def _get_features(
281
+ module,
282
+ input_ids,
283
+ attention_mask,
284
+ position_ids,
285
+ token_type_ids,
286
+ deterministic,
287
+ ):
288
+ text_outputs = module.text_model(
289
+ input_ids=input_ids,
290
+ attention_mask=attention_mask,
291
+ position_ids=position_ids,
292
+ token_type_ids=token_type_ids,
293
+ deterministic=deterministic,
294
+ )
295
+ pooled_output = text_outputs[1]
296
+ text_features = module.text_projection(pooled_output)
297
+ return text_features
298
+
299
+ return self.module.apply(
300
+ {"params": self.params},
301
+ jnp.array(input_ids, dtype="i4"),
302
+ jnp.array(attention_mask, dtype="i4"),
303
+ jnp.array(position_ids, dtype="i4"),
304
+ jnp.array(token_type_ids, dtype="i4"),
305
+ not train,
306
+ method=_get_features,
307
+ rngs=rngs,
308
+ )
309
+
310
+ def get_image_features(
311
+ self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False
312
+ ):
313
+ r"""
314
+ Args:
315
+ pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
316
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
317
+ using :class:`~transformers.ImageFeatureExtractionMixin`. See
318
+ :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
319
+ Returns:
320
+ image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
321
+ obtained by applying the projection layer to the pooled output of vision model.
322
+ """
323
+
324
+ # Handle any PRNG if needed
325
+ rngs = {}
326
+ if dropout_rng is not None:
327
+ rngs["dropout"] = dropout_rng
328
+
329
+ def _get_features(module, pixel_values, deterministic):
330
+ vision_outputs = module.vision_model(
331
+ pixel_values=pixel_values, deterministic=deterministic
332
+ )
333
+ pooled_output = vision_outputs[1] # pooled_output
334
+ image_features = module.visual_projection(pooled_output)
335
+ return image_features
336
+
337
+ return self.module.apply(
338
+ {"params": self.params},
339
+ jnp.array(pixel_values, dtype=jnp.float32),
340
+ not train,
341
+ method=_get_features,
342
+ rngs=rngs,
343
+ )
344
+
345
+ @classmethod
346
+ def from_text_vision_pretrained(
347
+ cls,
348
+ text_model_name_or_path: str = None,
349
+ vision_model_name_or_path: str = None,
350
+ *model_args,
351
+ **kwargs,
352
+ ) -> FlaxPreTrainedModel:
353
+ """
354
+ Params:
355
+ text_model_name_or_path (:obj: `str`, `optional`):
356
+ Information necessary to initiate the text model. Can be either:
357
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
358
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
359
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
360
+ - A path to a `directory` containing model weights saved using
361
+ :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
362
+ - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
363
+ this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
364
+ as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
365
+ a Flax model using the provided conversion scripts and loading the Flax model afterwards.
366
+ vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
367
+ Information necessary to initiate the vision model. Can be either:
368
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
369
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
370
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
371
+ - A path to a `directory` containing model weights saved using
372
+ :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
373
+ - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
374
+ this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
375
+ as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
376
+ a Flax model using the provided conversion scripts and loading the Flax model afterwards.
377
+ model_args (remaining positional arguments, `optional`):
378
+ All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
379
+ kwargs (remaining dictionary of keyword arguments, `optional`):
380
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
381
+ :obj:`output_attentions=True`).
382
+ - To update the text configuration, use the prefix `text_` for each configuration parameter.
383
+ - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
384
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
385
+ Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
386
+ Example::
387
+ >>> from transformers import FlaxHybridCLIP
388
+ >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
389
+ >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
390
+ >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
391
+ >>> # saving model after fine-tuning
392
+ >>> model.save_pretrained("./bert-clip")
393
+ >>> # load fine-tuned model
394
+ >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
395
+ """
396
+
397
+ kwargs_text = {
398
+ argument[len("text_") :]: value
399
+ for argument, value in kwargs.items()
400
+ if argument.startswith("text_")
401
+ }
402
+
403
+ kwargs_vision = {
404
+ argument[len("vision_") :]: value
405
+ for argument, value in kwargs.items()
406
+ if argument.startswith("vision_")
407
+ }
408
+
409
+ # remove text, vision kwargs from kwargs
410
+ for key in kwargs_text.keys():
411
+ del kwargs["text_" + key]
412
+ for key in kwargs_vision.keys():
413
+ del kwargs["vision_" + key]
414
+
415
+ # Load and initialize the text and vision model
416
+ text_model = kwargs_text.pop("model", None)
417
+ if text_model is None:
418
+ assert (
419
+ text_model_name_or_path is not None
420
+ ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
421
+ from transformers import FlaxAutoModel
422
+
423
+ if "config" not in kwargs_text:
424
+ from transformers import AutoConfig
425
+
426
+ text_config = AutoConfig.from_pretrained(text_model_name_or_path)
427
+ kwargs_text["config"] = text_config
428
+
429
+ text_model = FlaxAutoModel.from_pretrained(
430
+ text_model_name_or_path, *model_args, **kwargs_text
431
+ )
432
+
433
+ vision_model = kwargs_vision.pop("model", None)
434
+ if vision_model is None:
435
+ assert (
436
+ vision_model_name_or_path is not None
437
+ ), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
438
+ from transformers import FlaxAutoModel
439
+
440
+ if "config" not in kwargs_vision:
441
+ from transformers import AutoConfig
442
+
443
+ vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
444
+ kwargs_vision["config"] = vision_config
445
+
446
+ vision_model = FlaxAutoModel.from_pretrained(
447
+ vision_model_name_or_path, *model_args, **kwargs_vision
448
+ )
449
+
450
+ # instantiate config with corresponding kwargs
451
+ dtype = kwargs.pop("dtype", jnp.float32)
452
+ config = HybridCLIPConfig.from_text_vision_configs(
453
+ text_model.config, vision_model.config, **kwargs
454
+ )
455
+
456
+ # init model
457
+ model = cls(config, *model_args, dtype=dtype, **kwargs)
458
+
459
+ if vision_config.model_type == "clip":
460
+ model.params["vision_model"]["vision_model"] = vision_model.params[
461
+ "vision_model"
462
+ ]
463
+ model.params["visual_projection"]["kernel"] = vision_model.params[
464
+ "visual_projection"
465
+ ]["kernel"]
466
+ else:
467
+ model.params["vision_model"] = vision_model.params
468
+
469
+ model.params["text_model"] = text_model.params
470
+
471
+ return model
text2image.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from utils import load_model
4
+
5
+
6
+ def app(model_name):
7
+ model, processor = load_model(model_name)
8
+
9
+
10
+ st.title("Text to Image Retrieval")
11
+ st.markdown("""
12
+ Some text goes in here.
13
+ """)
14
+
utils.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
3
+
4
+ from koclip import FlaxHybridCLIP
5
+
6
+
7
+ @st.cache(allow_output_mutation=True)
8
+ def load_model(model_name="koclip/koclip"):
9
+ assert model_name in {"koclip/koclip", "koclip/koclip-large"}
10
+ model = FlaxHybridCLIP.from_pretrained(model_name)
11
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
12
+ processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
13
+ if model_name == "koclip/koclip-large":
14
+ processor.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-large-patch16-224")
15
+ return model, processor
16
+
17
+ @st.cache(allow_output_mutation=True)
18
+ def load_model_v2(model_name="koclip/koclip"):
19
+ model = FlaxHybridCLIP.from_pretrained(model_name)
20
+ processor = CLIPProcessor.from_pretrained(model_name)
21
+ return model, processor