File size: 2,197 Bytes
a5ed3da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch

import numpy as np
import random
import os
from typing import Tuple, Optional


def get_prompt_template(mode: str = 'default') -> Tuple[str, int, int]:
    '''
    Build a prompt template string together with its token bookkeeping.

    Args:
        mode (str, optional): 'random' picks a template at random from a
            predefined list; any other value uses the fixed default template.

    Returns:
        Tuple[str, int, int]: A tuple of (template string, position of the
        '{}' placeholder counted after the start-of-sequence token, prompt
        length counting sos/eos tokens but not the placeholder word itself).
    '''
    candidates = [
        'a photo of a {}', 'a photograph of a {}', 'an image of a {}', '{}',
        'a cropped photo of a {}', 'a good photo of a {}', 'a photo of one {}',
        'a bad photo of a {}', 'a photo of the {}', 'a photo of {}', 'a blurry photo of a {}',
        'a picture of a {}', 'a photo of a scene where {}'
    ]
    prompt_template = random.choice(candidates) if mode == 'random' else 'A photo of {}'

    words = prompt_template.split(' ')
    # sos and eos add two tokens; the '{}' placeholder word is excluded,
    # so the net adjustment over the word count is +1.
    prompt_length = len(words) + 1
    # Placeholder position, offset by one to account for the sos token.
    text_pos_at_prompt = words.index('{}') + 1

    return prompt_template, text_pos_at_prompt, prompt_length


# Reproducibility
def fix_seed(seed: int = 0) -> None:
    '''
    Seed every relevant random number generator so runs are reproducible.

    Args:
        seed (int, optional): The seed value applied to all generators.
            Default is 0.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds every visible GPU
    # Trade cuDNN autotuning speed for deterministic kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Pin Python's hash randomization for child processes as well.
    os.environ['PYTHONHASHSEED'] = str(seed)


def seed_worker(worker_id: int) -> None:
    '''
    DataLoader worker_init_fn: re-seed numpy and random inside a worker
    process from torch's per-worker initial seed, for reproducibility.

    Args:
        worker_id (int): The ID of the worker process (unused here; torch
            already folds it into torch.initial_seed()).
    '''
    # Reduce torch's 64-bit seed to the 32-bit range the other RNGs accept.
    derived_seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)