John6666 committed
Commit
c566e3a
1 Parent(s): ee4c79e

Upload 3 files

scripts/README.md ADDED
@@ -0,0 +1,80 @@
Documentation for the scripts in the `scripts` directory, starting with `batch-caption.py`, which runs JoyCaption in bulk. Other scripts may be added in the future.

# batch-caption.py

## Basic Command

To run the script, use the following command:

```sh
./batch-caption.py --glob "path/to/images/*.jpg" --prompt "Write a descriptive caption for this image in a formal tone."
```

This command captions all `.jpg` images in the specified directory using the provided prompt, writing a `.txt` file alongside each image.

## Command-Line Arguments

**Note**: You must provide images via `--glob`, `--filelist`, or `--input`. A prompt can be supplied with `--prompt` or `--prompt-file`; if neither is given, the script falls back to its built-in default prompt.

| Argument           | Description                                                     | Default |
| ------------------ | --------------------------------------------------------------- | ------- |
| `--glob`           | Glob pattern to find images                                     | N/A     |
| `--filelist`       | File containing a list of images                                | N/A     |
| `--input`          | Single input image                                              | N/A     |
| `--prompt`         | Prompt to use for caption generation                            | N/A     |
| `--prompt-file`    | JSON file containing prompts                                    | N/A     |
| `--batch-size`     | Batch size for image processing                                 | 1       |
| `--greedy`         | Use greedy decoding instead of sampling                         | False   |
| `--temperature`    | Sampling temperature (used when not using greedy decoding)     | 0.6     |
| `--top-p`          | Top-p (nucleus) sampling value                                  | 0.9     |
| `--top-k`          | Top-k sampling value                                            | None    |
| `--max-new-tokens` | Maximum length of the generated caption (in tokens)            | 256     |
| `--num-workers`    | Number of workers loading images in parallel                    | 4       |
| `--model`          | Pre-trained model to use                                        | [fancyfeast/llama-joycaption-alpha-two-hf-llava](https://huggingface.co/fancyfeast/llama-joycaption-alpha-two-hf-llava) |
| `--nf4`            | Load the model with NF4 4-bit quantization (default: bfloat16) | False   |

### Examples

1. **Caption images with a specific prompt**

```sh
./batch-caption.py --glob "images/*.png" --prompt "Write a descriptive caption for this image in a formal tone."
```

2. **Use a JSON file for prompts**

```sh
python batch-caption.py --filelist "image_paths.txt" --prompt-file "prompts.json"
```

3. **Use greedy decoding**

```sh
python batch-caption.py --glob "images/*.jpg" --prompt "Write a descriptive caption for this image in a formal tone." --greedy
```
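
4. **Adjust the sampling parameters**

Sampling can be tuned with the flags documented in the table above; the values below are purely illustrative:

```sh
python batch-caption.py --glob "images/*.jpg" --prompt "Write a descriptive caption for this image in a formal tone." --temperature 0.8 --top-p 0.95 --top-k 40
```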

## Prompt Handling

- For a list of prompts that the model understands, please refer to the project's root README.

- You can specify a prompt directly using the `--prompt` argument, or use a JSON file containing a list of prompts with weights via `--prompt-file`.

- If multiple prompts are specified in the prompt file, the prompt used for each image is selected at random.

- **Prompt File Format**: The JSON file should contain either strings or objects with `prompt` and `weight` fields.

- **Weighting**: The `weight` field sets the probability of selecting a particular prompt during caption generation. Higher weights make a prompt more likely to be chosen; for example, a prompt with a weight of 2.0 is twice as likely to be used as one with a weight of 1.0 (see the sketch after the example below).

Example `prompts.json`:

```json
[
    { "prompt": "Describe the scene in detail.", "weight": 2.0 },
    { "prompt": "Summarize the main elements of the image.", "weight": 1.0 }
]
```
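
Weighted selection is done per image with `random.choices`, the same call `batch-caption.py` uses internally; a minimal sketch:

```python
import random

# Prompts and weights as they would appear in prompts.json
prompts = [
    ("Describe the scene in detail.", 2.0),
    ("Summarize the main elements of the image.", 1.0),
]

texts = [text for text, _ in prompts]
weights = [weight for _, weight in prompts]

# Weights are relative, so 2.0 vs. 1.0 gives the first prompt a 2:1 chance of being picked
chosen = random.choices(texts, weights=weights, k=1)[0]
print(chosen)
```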

## Output

- Captions are saved as `.txt` files in the same directory as the corresponding image.
- If a `.txt` caption file already exists for an image, the script skips that image.
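
The skip check is a plain existence test on the sibling `.txt` path; a minimal sketch of the same filter the script applies before captioning (the `images` directory is just a placeholder):

```python
from pathlib import Path

image_paths = list(Path("images").glob("*.jpg"))

# Keep only images that do not yet have a caption file next to them
to_caption = [p for p in image_paths if not p.with_suffix(".txt").exists()]
print(f"{len(to_caption)} of {len(image_paths)} images still need captions")
```
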
scripts/batch-caption.py ADDED
@@ -0,0 +1,339 @@
#!/usr/bin/env python3
"""
Use JoyCaption to caption images.
"""
import argparse
import dataclasses
import json
import logging
import os
import random
from pathlib import Path

import PIL.Image
import torch
import torch.amp
import torchvision.transforms.functional as TVF
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    LlavaForConditionalGeneration,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
from typing import Union

def none_or_type(value, desired_type):
    if value == "None":
        return None
    return desired_type(value)

DEFAULT_PROMPT = "Write a descriptive caption for this image in a formal tone."

parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", type=str, help="Input image")
parser.add_argument("--glob", type=str, help="Glob pattern to find images")
parser.add_argument("--filelist", type=str, help="File containing list of images")
parser.add_argument("--prompt", type=str, help="Prompt to use")
parser.add_argument("--prompt-file", type=str, help="JSON file containing prompts to use")
parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
parser.add_argument("--greedy", action="store_true", help="Use greedy decoding instead of sampling")
parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
parser.add_argument("--top-p", type=lambda x: none_or_type(x, float), default=0.9, help="Top-p sampling")
parser.add_argument("--top-k", type=lambda x: none_or_type(x, int), default=None, help="Top-k sampling")
parser.add_argument("--max-new-tokens", type=int, default=256, help="Maximum length of the generated caption (in tokens)")
parser.add_argument("--num-workers", type=int, default=4, help="Number of workers loading images in parallel")
parser.add_argument("--model", type=str, default="fancyfeast/llama-joycaption-alpha-two-hf-llava", help="Model to use")
#parser.add_argument("--model", type=str, default="John6666/llama-joycaption-alpha-two-hf-llava-nf4", help="Model to use")
parser.add_argument("--nf4", action="store_true", default=False, help="Use NF4 (default: bfloat16)")

PIL.Image.MAX_IMAGE_PIXELS = 933120000  # Quiets Pillow from giving warnings on really large images (WARNING: Exposes a risk of DoS from malicious images)
device = "cuda:0" if torch.cuda.is_available() else "cpu"


@dataclasses.dataclass
class Prompt:
    prompt: str
    weight: float


#@torch.no_grad()
@torch.inference_mode()
def main():
    # Logging
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")

    # Parse arguments
    args = parser.parse_args()
    logging.info(f"Arguments: {args}")
    IS_NF4 = args.nf4

    # Make sure we have a prompt or a prompt file
    prompts = parse_prompts(args.prompt, args.prompt_file)

    # Find the images
    image_paths = find_images(args.glob, args.filelist, args.input)
    if len(image_paths) == 0:
        logging.warning("No images found")
        return
    logging.info(f"Found {len(image_paths)} images")

    # Ignore all images that already have captions
    image_paths = [path for path in image_paths if not Path(path).with_suffix(".txt").exists()]

    # Load JoyCaption
    from transformers import BitsAndBytesConfig
    nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_quant_storage=torch.bfloat16,
                                    bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
    if IS_NF4:
        llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, quantization_config=nf4_config, torch_dtype="bfloat16", device_map=device)
    else:
        llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", device_map=device)
    assert isinstance(llava_model, LlavaForConditionalGeneration)

    dataset = ImageDataset(prompts, image_paths, tokenizer, llava_model.config.image_token_index, llava_model.config.image_seq_length)
    dataloader = DataLoader(dataset, collate_fn=dataset.collate_fn, num_workers=args.num_workers, shuffle=False, drop_last=False, batch_size=args.batch_size)
    end_of_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    assert isinstance(end_of_header_id, int) and isinstance(end_of_turn_id, int)

    pbar = tqdm(total=len(image_paths), desc="Captioning images...", dynamic_ncols=True)
    for batch in dataloader:
        vision_dtype = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
        vision_device = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight.device
        language_device = llava_model.language_model.get_input_embeddings().weight.device

        # Move to GPU
        pixel_values = batch['pixel_values'].to(vision_device, non_blocking=True)
        input_ids = batch['input_ids'].to(language_device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(language_device, non_blocking=True)

        # Normalize the image
        pixel_values = pixel_values / 255.0
        pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
        pixel_values = pixel_values.to(vision_dtype)

        # Generate the captions
        generate_ids = llava_model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            max_new_tokens=args.max_new_tokens,
            do_sample=not args.greedy,
            suppress_tokens=None,
            use_cache=True,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )

        # Trim off the prompts
        assert isinstance(generate_ids, torch.Tensor)
        generate_ids = generate_ids.tolist()
        generate_ids = [trim_off_prompt(ids, end_of_header_id, end_of_turn_id) for ids in generate_ids]

        # Decode the captions
        captions = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        captions = [c.strip() for c in captions]

        for path, caption in zip(batch['paths'], captions):
            write_caption(Path(path), caption)

        pbar.update(len(captions))


def trim_off_prompt(input_ids: list[int], eoh_id: int, eot_id: int) -> list[int]:
    # Trim off the prompt
    while True:
        try:
            i = input_ids.index(eoh_id)
        except ValueError:
            break

        input_ids = input_ids[i + 1:]

    # Trim off the end
    try:
        i = input_ids.index(eot_id)
    except ValueError:
        return input_ids

    return input_ids[:i]


def write_caption(image_path: Path, caption: str):
    caption_path = image_path.with_suffix(".txt")

    try:
        f = os.open(caption_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL)  # Write-only, create if not exist, fail if exists
    except FileExistsError:
        logging.warning(f"Caption file '{caption_path}' already exists")
        return
    except Exception as e:
        logging.error(f"Failed to open caption file '{caption_path}': {e}")
        return

    try:
        os.write(f, caption.encode("utf-8"))
        os.close(f)
    except Exception as e:
        logging.error(f"Failed to write caption to '{caption_path}': {e}")
        return


class ImageDataset(Dataset):
    def __init__(self, prompts: list[Prompt], paths: list[Path], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], image_token_id: int, image_seq_length: int):
        self.prompts = prompts
        self.paths = paths
        self.tokenizer = tokenizer
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx: int) -> dict:
        path = self.paths[idx]

        # Pick a prompt
        prompt_str = random.choices(self.prompts, weights=[p.weight for p in self.prompts])[0].prompt

        # Preprocess image
        # NOTE: I don't use the Processor here and instead do it manually.
        # This is because in my testing a simple resize in Pillow yields higher quality results than the Processor,
        # and the Processor had some buggy behavior on some images.
        # And yes, with the so400m model, the model expects the image to be squished into a square, not padded.
        try:
            image = Image.open(path)
            if image.size != (384, 384):
                image = image.resize((384, 384), Image.LANCZOS)
            image = image.convert("RGB")
            pixel_values = TVF.pil_to_tensor(image)
        except Exception as e:
            logging.error(f"Failed to load image '{path}': {e}")
            pixel_values = None  # Will be filtered out later

        # Build the conversation
        convo = [
            {
                "role": "system",
                "content": "You are a helpful image captioner.",
            },
            {
                "role": "user",
                "content": prompt_str,
            },
        ]

        # Format the conversation
        convo_string = self.tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        assert isinstance(convo_string, str)

        # Tokenize the conversation
        convo_tokens = self.tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

        # Repeat the image tokens
        input_tokens = []
        for token in convo_tokens:
            if token == self.image_token_id:
                input_tokens.extend([self.image_token_id] * self.image_seq_length)
            else:
                input_tokens.append(token)

        input_ids = torch.tensor(input_tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        return {
            'path': path,
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }

    def collate_fn(self, batch: list[dict]) -> dict:
        # Filter out images that failed to load
        batch = [item for item in batch if item['pixel_values'] is not None]

        # Pad input_ids and attention_mask
        # Have to use left padding because HF's generate can't handle right padding it seems
        max_length = max(item['input_ids'].shape[0] for item in batch)
        n_pad = [max_length - item['input_ids'].shape[0] for item in batch]
        input_ids = torch.stack([torch.nn.functional.pad(item['input_ids'], (n, 0), value=self.pad_token_id) for item, n in zip(batch, n_pad)])
        attention_mask = torch.stack([torch.nn.functional.pad(item['attention_mask'], (n, 0), value=0) for item, n in zip(batch, n_pad)])

        # Stack pixel values
        pixel_values = torch.stack([item['pixel_values'] for item in batch])

        # Paths
        paths = [item['path'] for item in batch]

        return {
            'paths': paths,
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }


def parse_prompts(prompt_str: Union[str, None], prompt_file: Union[str, None]) -> list[Prompt]:
    if prompt_str is not None and prompt_file is not None:
        raise ValueError("Cannot specify both --prompt and --prompt-file")

    if prompt_str is not None:
        return [Prompt(prompt=prompt_str, weight=1.0)]

    if prompt_file is None:
        return [Prompt(prompt=DEFAULT_PROMPT, weight=1.0)]
        #raise ValueError("Must specify either --prompt or --prompt-file")

    data = json.loads(Path(prompt_file).read_text())

    if not isinstance(data, list):
        raise ValueError("Expected JSON file to contain a list of prompts")

    prompts = []

    for item in data:
        if isinstance(item, str):
            prompts.append(Prompt(prompt=item, weight=1.0))
        elif isinstance(item, dict) and "prompt" in item and "weight" in item and isinstance(item["prompt"], str) and isinstance(item["weight"], (int, float)):
            prompts.append(Prompt(prompt=item["prompt"], weight=item["weight"]))
        else:
            raise ValueError(f"Invalid prompt in JSON file. Should be either a string or an object with 'prompt' and 'weight' fields: {item}")

    if len(prompts) == 0:
        raise ValueError("No prompts found in JSON file")

    if sum(p.weight for p in prompts) <= 0.0:
        raise ValueError("Prompt weights must sum to a positive number")

    return prompts


def find_images(glob: Union[str, None], filelist: Union[str, Path, None], input: Union[str, None]) -> list[Path]:
    if glob is None and filelist is None and input is None:
        raise ValueError("Must specify either --glob or --filelist or --input")

    paths = []

    if glob is not None:
        paths.extend(Path(".").glob(glob))

    if filelist is not None:
        paths.extend((Path(line.strip()) for line in Path(filelist).read_text().strip().splitlines() if line.strip() != ""))

    if input is not None:
        paths.append(Path(input))  # Wrap in Path so the return type matches list[Path]

    return paths


if __name__ == "__main__":
    main()

# https://github.com/huggingface/peft/issues/156
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1331
# https://github.com/huggingface/peft/issues/1831

scripts/prompt-examples.json ADDED
@@ -0,0 +1,7 @@
[
    "Write a descriptive caption for this image in a formal tone.",
    {
        "prompt": "Write a medium-length stable diffusion prompt for this image.",
        "weight": 0.5
    }
]