Upload 3 files
- scripts/README.md +80 -0
- scripts/batch-caption.py +339 -0
- scripts/prompt-examples.json +7 -0
scripts/README.md
ADDED
@@ -0,0 +1,80 @@
Documentation for the scripts in the `scripts` directory, starting with `batch-caption.py`, which is used to run JoyCaption in bulk. Other scripts might be added in the future.

# batch-caption.py

## Basic Command

To run the script, use the following command:

```sh
./batch-caption.py --glob "path/to/images/*.jpg" --prompt "Write a descriptive caption for this image in a formal tone."
```

This command will caption all the `.jpg` images in the specified directory using the provided prompt, writing `.txt` files alongside each image.

## Command-Line Arguments

**Note**: Provide images with `--glob`, `--filelist`, or `--input` (they may be combined). Provide a prompt with `--prompt` or `--prompt-file` (not both); if neither is given, the default prompt "Write a descriptive caption for this image in a formal tone." is used.

| Argument           | Description                                                   | Default |
| ------------------ | ------------------------------------------------------------- | ------- |
| `--glob`           | Glob pattern to find images                                   | N/A |
| `--filelist`       | File containing a list of image paths, one per line           | N/A |
| `-i`, `--input`    | Single image to caption                                       | N/A |
| `--prompt`         | Prompt to use for caption generation                          | N/A |
| `--prompt-file`    | JSON file containing prompts                                  | N/A |
| `--batch-size`     | Batch size for image processing                               | 1 |
| `--greedy`         | Use greedy decoding instead of sampling                       | False |
| `--temperature`    | Sampling temperature (used when not using greedy decoding)    | 0.6 |
| `--top-p`          | Top-p sampling value (nucleus sampling)                       | 0.9 |
| `--top-k`          | Top-k sampling value                                          | None |
| `--max-new-tokens` | Maximum length of the generated caption (in tokens)           | 256 |
| `--num-workers`    | Number of workers loading images in parallel                  | 4 |
| `--model`          | Pre-trained model to use                                      | [fancyfeast/llama-joycaption-alpha-two-hf-llava](https://huggingface.co/fancyfeast/llama-joycaption-alpha-two-hf-llava) |
| `--nf4`            | Load the model with NF4 4-bit quantization (default: bfloat16) | False |

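The file passed to `--filelist` is a plain text file with one image path per line; blank lines are ignored. A hypothetical `image_paths.txt` might look like:

```
images/photo_001.jpg
images/photo_002.png
photos/vacation/IMG_1234.jpg
```
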
### Examples

1. **Caption images with a specific prompt**

   ```sh
   ./batch-caption.py --glob "images/*.png" --prompt "Write a descriptive caption for this image in a formal tone."
   ```

2. **Use a JSON file for prompts**

   ```sh
   python batch-caption.py --filelist "image_paths.txt" --prompt-file "prompts.json"
   ```

3. **Use Greedy Decoding**

   ```sh
   python batch-caption.py --glob "images/*.jpg" --prompt "Write a descriptive caption for this image in a formal tone." --greedy
   ```

## Prompt Handling

- For a list of prompts that the model understands, please refer to the project's root README.

- You can specify a prompt directly using the `--prompt` argument, or use a JSON file containing a list of prompts with weights using `--prompt-file`.

- If multiple prompts are specified in the prompt file, the prompt used for each image will be randomly selected.

- **Prompt File Format**: The JSON file should contain either strings or objects with `prompt` and `weight` fields.

- **Weighting**: The `weight` field indicates the probability of selecting a particular prompt during caption generation. Higher weights make a prompt more likely to be chosen. For example, if one prompt has a weight of 2.0 and another has a weight of 1.0, the first prompt will be twice as likely to be used.

Example `prompts.json`:

```json
[
  { "prompt": "Describe the scene in detail.", "weight": 2.0 },
  { "prompt": "Summarize the main elements of the image.", "weight": 1.0 }
]
```
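
Internally, the script selects one prompt per image with Python's `random.choices`, so weights are relative and do not need to sum to 1. A minimal sketch of the same selection logic, using the prompt strings from the example above:

```python
import random

# (prompt, weight) pairs as they would be parsed from prompts.json
prompts = [
    ("Describe the scene in detail.", 2.0),
    ("Summarize the main elements of the image.", 1.0),
]

# One weighted draw per image; a weight of 2.0 is chosen roughly
# twice as often as a weight of 1.0 over many images.
chosen = random.choices(
    [prompt for prompt, _ in prompts],
    weights=[weight for _, weight in prompts],
)[0]
print(chosen)
```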

## Output

- Captions are saved as `.txt` files in the same directory as the corresponding image.
- If a `.txt` caption file already exists for an image, the script will skip that image.
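
For reference, the caption path is simply the image path with its extension swapped to `.txt`, which is also how the skip check works. A minimal sketch with a hypothetical path:

```python
from pathlib import Path

image_path = Path("images/photo_001.jpg")      # hypothetical image
caption_path = image_path.with_suffix(".txt")  # -> images/photo_001.txt

# Images whose caption file already exists are skipped by the script.
if caption_path.exists():
    print(f"Skipping {image_path}: caption already exists")
```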
scripts/batch-caption.py
ADDED
@@ -0,0 +1,339 @@
#!/usr/bin/env python3
"""
Use JoyCaption to caption images.
"""
import argparse
import dataclasses
import json
import logging
import os
import random
from pathlib import Path

import PIL.Image
import torch
import torch.amp
import torchvision.transforms.functional as TVF
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    LlavaForConditionalGeneration,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
from typing import Union


def none_or_type(value, desired_type):
    if value == "None":
        return None
    return desired_type(value)


DEFAULT_PROMPT = "Write a descriptive caption for this image in a formal tone."

parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", type=str, help="Input image")
parser.add_argument("--glob", type=str, help="Glob pattern to find images")
parser.add_argument("--filelist", type=str, help="File containing list of images")
parser.add_argument("--prompt", type=str, help="Prompt to use")
parser.add_argument("--prompt-file", type=str, help="JSON file containing prompts to use")
parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
parser.add_argument("--greedy", action="store_true", help="Use greedy decoding instead of sampling")
parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
parser.add_argument("--top-p", type=lambda x: none_or_type(x, float), default=0.9, help="Top-p sampling")
parser.add_argument("--top-k", type=lambda x: none_or_type(x, int), default=None, help="Top-k sampling")
parser.add_argument("--max-new-tokens", type=int, default=256, help="Maximum length of the generated caption (in tokens)")
parser.add_argument("--num-workers", type=int, default=4, help="Number of workers loading images in parallel")
parser.add_argument("--model", type=str, default="fancyfeast/llama-joycaption-alpha-two-hf-llava", help="Model to use")
#parser.add_argument("--model", type=str, default="John6666/llama-joycaption-alpha-two-hf-llava-nf4", help="Model to use")
parser.add_argument("--nf4", action="store_true", default=False, help="Use NF4 (default: bfloat16)")

PIL.Image.MAX_IMAGE_PIXELS = 933120000  # Quiets Pillow from giving warnings on really large images (WARNING: Exposes a risk of DoS from malicious images)
device = "cuda:0" if torch.cuda.is_available() else "cpu"


@dataclasses.dataclass
class Prompt:
    prompt: str
    weight: float


#@torch.no_grad()
@torch.inference_mode()
def main():
    # Logging
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")

    # Parse arguments
    args = parser.parse_args()
    logging.info(f"Arguments: {args}")
    IS_NF4 = args.nf4

    # Make sure we have a prompt or a prompt file
    prompts = parse_prompts(args.prompt, args.prompt_file)

    # Find the images
    image_paths = find_images(args.glob, args.filelist, args.input)
    if len(image_paths) == 0:
        logging.warning("No images found")
        return
    logging.info(f"Found {len(image_paths)} images")

    # Ignore all images that already have captions
    image_paths = [path for path in image_paths if not Path(path).with_suffix(".txt").exists()]

    # Load JoyCaption
    from transformers import BitsAndBytesConfig
    nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_quant_storage=torch.bfloat16,
                                    bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
    if IS_NF4:
        llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, quantization_config=nf4_config, torch_dtype="bfloat16", device_map=device)
    else:
        llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", device_map=device)
    assert isinstance(llava_model, LlavaForConditionalGeneration)

    dataset = ImageDataset(prompts, image_paths, tokenizer, llava_model.config.image_token_index, llava_model.config.image_seq_length)
    dataloader = DataLoader(dataset, collate_fn=dataset.collate_fn, num_workers=args.num_workers, shuffle=False, drop_last=False, batch_size=args.batch_size)
    end_of_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    assert isinstance(end_of_header_id, int) and isinstance(end_of_turn_id, int)

    pbar = tqdm(total=len(image_paths), desc="Captioning images...", dynamic_ncols=True)
    for batch in dataloader:
        vision_dtype = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
        vision_device = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight.device
        language_device = llava_model.language_model.get_input_embeddings().weight.device

        # Move to GPU
        pixel_values = batch['pixel_values'].to(vision_device, non_blocking=True)
        input_ids = batch['input_ids'].to(language_device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(language_device, non_blocking=True)

        # Normalize the image
        pixel_values = pixel_values / 255.0
        pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
        pixel_values = pixel_values.to(vision_dtype)

        # Generate the captions
        generate_ids = llava_model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            max_new_tokens=args.max_new_tokens,
            do_sample=not args.greedy,
            suppress_tokens=None,
            use_cache=True,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )

        # Trim off the prompts
        assert isinstance(generate_ids, torch.Tensor)
        generate_ids = generate_ids.tolist()
        generate_ids = [trim_off_prompt(ids, end_of_header_id, end_of_turn_id) for ids in generate_ids]

        # Decode the captions
        captions = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        captions = [c.strip() for c in captions]

        for path, caption in zip(batch['paths'], captions):
            write_caption(Path(path), caption)

        pbar.update(len(captions))


def trim_off_prompt(input_ids: list[int], eoh_id: int, eot_id: int) -> list[int]:
    # Trim off the prompt
    while True:
        try:
            i = input_ids.index(eoh_id)
        except ValueError:
            break

        input_ids = input_ids[i + 1:]

    # Trim off the end
    try:
        i = input_ids.index(eot_id)
    except ValueError:
        return input_ids

    return input_ids[:i]


def write_caption(image_path: Path, caption: str):
    caption_path = image_path.with_suffix(".txt")

    try:
        f = os.open(caption_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL)  # Write-only, create if not exist, fail if exists
    except FileExistsError:
        logging.warning(f"Caption file '{caption_path}' already exists")
        return
    except Exception as e:
        logging.error(f"Failed to open caption file '{caption_path}': {e}")
        return

    try:
        os.write(f, caption.encode("utf-8"))
        os.close(f)
    except Exception as e:
        logging.error(f"Failed to write caption to '{caption_path}': {e}")
        return


class ImageDataset(Dataset):
    def __init__(self, prompts: list[Prompt], paths: list[Path], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], image_token_id: int, image_seq_length: int):
        self.prompts = prompts
        self.paths = paths
        self.tokenizer = tokenizer
        self.image_token_id = image_token_id
        self.image_seq_length = image_seq_length
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx: int) -> dict:
        path = self.paths[idx]

        # Pick a prompt
        prompt_str = random.choices(self.prompts, weights=[p.weight for p in self.prompts])[0].prompt

        # Preprocess image
        # NOTE: I don't use the Processor here and instead do it manually.
        # This is because in my testing a simple resize in Pillow yields higher quality results than the Processor,
        # and the Processor had some buggy behavior on some images.
        # And yes, with the so400m model, the model expects the image to be squished into a square, not padded.
        try:
            image = Image.open(path)
            if image.size != (384, 384):
                image = image.resize((384, 384), Image.LANCZOS)
            image = image.convert("RGB")
            pixel_values = TVF.pil_to_tensor(image)
        except Exception as e:
            logging.error(f"Failed to load image '{path}': {e}")
            pixel_values = None  # Will be filtered out later

        # Build the conversation
        convo = [
            {
                "role": "system",
                "content": "You are a helpful image captioner.",
            },
            {
                "role": "user",
                "content": prompt_str,
            },
        ]

        # Format the conversation
        convo_string = self.tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
        assert isinstance(convo_string, str)

        # Tokenize the conversation
        convo_tokens = self.tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

        # Repeat the image tokens
        input_tokens = []
        for token in convo_tokens:
            if token == self.image_token_id:
                input_tokens.extend([self.image_token_id] * self.image_seq_length)
            else:
                input_tokens.append(token)

        input_ids = torch.tensor(input_tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        return {
            'path': path,
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }

    def collate_fn(self, batch: list[dict]) -> dict:
        # Filter out images that failed to load
        batch = [item for item in batch if item['pixel_values'] is not None]

        # Pad input_ids and attention_mask
        # Have to use left padding because HF's generate can't handle right padding it seems
        max_length = max(item['input_ids'].shape[0] for item in batch)
        n_pad = [max_length - item['input_ids'].shape[0] for item in batch]
        input_ids = torch.stack([torch.nn.functional.pad(item['input_ids'], (n, 0), value=self.pad_token_id) for item, n in zip(batch, n_pad)])
        attention_mask = torch.stack([torch.nn.functional.pad(item['attention_mask'], (n, 0), value=0) for item, n in zip(batch, n_pad)])

        # Stack pixel values
        pixel_values = torch.stack([item['pixel_values'] for item in batch])

        # Paths
        paths = [item['path'] for item in batch]

        return {
            'paths': paths,
            'pixel_values': pixel_values,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
        }


def parse_prompts(prompt_str: Union[str, None], prompt_file: Union[str, None]) -> list[Prompt]:
    if prompt_str is not None and prompt_file is not None:
        raise ValueError("Cannot specify both --prompt and --prompt-file")

    if prompt_str is not None:
        return [Prompt(prompt=prompt_str, weight=1.0)]

    if prompt_file is None:
        return [Prompt(prompt=DEFAULT_PROMPT, weight=1.0)]
        #raise ValueError("Must specify either --prompt or --prompt-file")

    data = json.loads(Path(prompt_file).read_text())

    if not isinstance(data, list):
        raise ValueError("Expected JSON file to contain a list of prompts")

    prompts = []

    for item in data:
        if isinstance(item, str):
            prompts.append(Prompt(prompt=item, weight=1.0))
        elif isinstance(item, dict) and "prompt" in item and "weight" in item and isinstance(item["prompt"], str) and isinstance(item["weight"], (int, float)):
            prompts.append(Prompt(prompt=item["prompt"], weight=item["weight"]))
        else:
            raise ValueError(f"Invalid prompt in JSON file. Should be either a string or an object with 'prompt' and 'weight' fields: {item}")

    if len(prompts) == 0:
        raise ValueError("No prompts found in JSON file")

    if sum(p.weight for p in prompts) <= 0.0:
        raise ValueError("Prompt weights must sum to a positive number")

    return prompts


def find_images(glob: Union[str, None], filelist: Union[str, Path, None], input: Union[str, None]) -> list[Path]:
    if glob is None and filelist is None and input is None:
        raise ValueError("Must specify either --glob or --filelist or --input")

    paths = []

    if glob is not None:
        paths.extend(Path(".").glob(glob))

    if filelist is not None:
        paths.extend((Path(line.strip()) for line in Path(filelist).read_text().strip().splitlines() if line.strip() != ""))

    if input is not None:
        paths.append(Path(input))

    return paths


if __name__ == "__main__":
    main()

# https://github.com/huggingface/peft/issues/156
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1331
# https://github.com/huggingface/peft/issues/1831
scripts/prompt-examples.json
ADDED
@@ -0,0 +1,7 @@
[
    "Write a descriptive caption for this image in a formal tone.",
    {
        "prompt": "Write a medium-length stable diffusion prompt for this image.",
        "weight": 0.5
    }
]