# Anole / interleaved_generation.py
# Author: xuefengli — commit 7362797 ("update"), 5.27 kB
import json
import os
import torch
import argparse
from PIL import Image
from chameleon.inference.chameleon import ChameleonInferenceModel, Options
from constants import (
MODEL_7B_PATH,
TOKENIZER_TEXT_PATH,
TOKENIZER_IMAGE_CFG_PATH,
TOKENIZER_IMAGE_PATH,
)
from typing import List, Tuple
import logging
# Set up the logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def split_token_sequence(
    tokens: torch.LongTensor,
    boi: int,
    eoi: int
) -> List[Tuple[str, torch.LongTensor]]:
    """
    Split a sequence of tokens into text and image segments.

    Tokens between a ``boi`` and the following ``eoi`` form an image segment;
    everything else forms text segments. The ``boi``/``eoi`` markers are not
    included in any segment. An ``eoi`` seen outside an image segment is kept
    as an ordinary text token. A trailing image segment that was never closed
    by ``eoi`` is still returned as ``"image_seg"``.

    Args:
        tokens (torch.LongTensor): Token sequence of shape (1, seq_len).
        boi (int): Begin of image token.
        eoi (int): End of image token.

    Returns:
        List[Tuple[str, torch.LongTensor]]: ``("text_seg" | "image_seg", ids)``
        tuples in sequence order. Each ``ids`` tensor has shape (1, n), uses the
        input's dtype and device, and may be empty for a ``boi`` directly
        followed by ``eoi``.
    """
    batch_size, _ = tokens.shape
    assert batch_size == 1, "Batch size must be 1"
    device = tokens.device
    dtype = tokens.dtype

    def _to_segment(kind: str, ids: List[int]) -> Tuple[str, torch.LongTensor]:
        # Re-materialise the collected ids as a (1, n) tensor matching the input.
        return kind, torch.tensor(ids, dtype=dtype, device=device).reshape(1, -1)

    segments: List[Tuple[str, torch.LongTensor]] = []
    current_segment: List[int] = []
    in_image_seg = False
    # Convert to plain Python ints once: comparing individual tensor elements
    # creates a 0-d tensor per token (and forces a sync per token on GPU).
    for token in tokens[0].tolist():
        if token == boi:
            # Flush whatever was being collected before entering an image.
            # Label it by the current mode so a repeated boi inside an image
            # does not mislabel image tokens as text.
            if current_segment:
                kind = "image_seg" if in_image_seg else "text_seg"
                segments.append(_to_segment(kind, current_segment))
                current_segment = []
            in_image_seg = True
        elif token == eoi and in_image_seg:
            # Closing an image: always emit it, even if empty.
            segments.append(_to_segment("image_seg", current_segment))
            current_segment = []
            in_image_seg = False
        else:
            current_segment.append(token)

    # Save any remaining tokens (possibly an unterminated image segment).
    if current_segment:
        kind = "image_seg" if in_image_seg else "text_seg"
        segments.append(_to_segment(kind, current_segment))
    return segments
def main(args: argparse.Namespace):
    """Generate interleaved text/image output for one instruction and export it.

    Loads the Chameleon model from the paths in ``constants`` and runs
    generation on ``args.instruction`` with default ``Options``. The resulting
    token stream is split into text and image segments. Each image segment is
    saved as ``<save_dir>/<segment index>.png``. One JSON object per segment
    (``{"type": "text"|"image", "content": ...}``) is then written to
    ``./segments.jsonl``.

    Args:
        args: Parsed CLI namespace; reads ``instruction`` and ``save_dir``.

    Raises:
        ValueError: If an image segment does not contain exactly 1024 tokens
            (for example, generation stopped mid-image), or if an unknown
            segment type is produced.
    """
    # Token count every image segment must have before it can be decoded
    # (value from the original check; presumably the image tokenizer's fixed
    # code count per image — TODO confirm against the tokenizer config).
    expected_image_tokens = 1024

    # Load Chameleon model
    model = ChameleonInferenceModel(
        MODEL_7B_PATH.as_posix(),
        TOKENIZER_TEXT_PATH.as_posix(),
        TOKENIZER_IMAGE_CFG_PATH.as_posix(),
        TOKENIZER_IMAGE_PATH.as_posix(),
    )
    # Print model configuration (lazy %-formatting; same output as before).
    logging.info("Model path: %s", MODEL_7B_PATH)
    logging.info("Text tokenizer path: %s", TOKENIZER_TEXT_PATH)
    logging.info("Image tokenizer config path: %s", TOKENIZER_IMAGE_CFG_PATH)
    logging.info("Image tokenizer path: %s", TOKENIZER_IMAGE_PATH)

    # Generate options (library defaults)
    options = Options()

    # Prepare prompt. An instruction is either plain text or a
    # (text, image_path) tuple. The CLI only ever supplies plain text.
    instructions = [args.instruction]
    batch_prompt_ui = []
    for instruction in instructions:
        if isinstance(instruction, tuple):
            inst, image_path = instruction
            batch_prompt_ui.append([
                {"type": "image", "value": f"file:{image_path}"},
                {"type": "text", "value": inst},
            ])
        else:
            batch_prompt_ui.append([{"type": "text", "value": instruction}])

    # generate
    tokens: torch.LongTensor = model.generate(
        batch_prompt_ui=batch_prompt_ui,
        options=options,
    )

    # split
    # Original note: 8197 (boi), 8196 (eoi) — depends on the loaded vocab.
    boi, eoi = model.vocab.begin_image, model.vocab.end_image
    segments = split_token_sequence(tokens, boi, eoi)

    # decode
    os.makedirs(args.save_dir, exist_ok=True)
    segments_data = []
    for seg_id, (seg_type, seg_tokens) in enumerate(segments):
        if seg_type == "image_seg":
            n_tokens = seg_tokens.shape[1]
            # Explicit check instead of `assert`, which is stripped under -O.
            if n_tokens != expected_image_tokens:
                raise ValueError(
                    f"Image segment {seg_id} has {n_tokens} tokens; "
                    f"expected {expected_image_tokens}."
                )
            img = model.decode_image(seg_tokens)[0]
            image_path = os.path.join(args.save_dir, f"{seg_id}.png")
            img.save(image_path)
            segments_data.append({"type": "image", "content": image_path})
        elif seg_type == "text_seg":
            decoded_text = model.decode_text(seg_tokens)[0]
            segments_data.append({"type": "text", "content": decoded_text})
        else:
            raise ValueError(f"Unknown segment type: {seg_type!r}")

    # NOTE(review): the JSONL goes to the current working directory, not
    # save_dir. The image paths recorded in it are relative to the same cwd.
    jsonl_path = "./segments.jsonl"
    with open(jsonl_path, "w", encoding="utf-8") as jsonl_file:
        jsonl_file.writelines(json.dumps(segment) + "\n" for segment in segments_data)
def parse_arguments() -> argparse.Namespace:
    """Build the command-line parser and return the parsed options.

    Returns:
        argparse.Namespace: ``instruction`` (required prompt text) and
        ``save_dir`` (directory for generated images, with a default).
    """
    cli = argparse.ArgumentParser(
        description="Generate interleaved image-text content based on text instructions."
    )
    cli.add_argument(
        "-i", "--instruction",
        type=str,
        required=True,
        help="The instruction for interleaved image-text generation.",
    )
    cli.add_argument(
        "-s", "--save_dir",
        type=str,
        default="./outputs/interleaved/",
        help="The directory to save the generated images.",
    )
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: read CLI options, then generate and export.
    main(args=parse_arguments())