"""An abstraction layer for prompting different models."""

from __future__ import annotations

import enum

from fastchat.model.model_adapter import get_conversation_template


class Task(enum.Enum):
    """Different system prompt styles."""

    CHAT = "chat"
    CHAT_CONCISE = "chat-concise"
    INSTRUCT = "instruct"
    INSTRUCT_CONCISE = "instruct-concise"


SYSTEM_PROMPTS = {
    Task.CHAT: (
        "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    ),
    Task.CHAT_CONCISE: (
        "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        "The assistant's answers are very concise. "
    ),
    Task.INSTRUCT: (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request. "
    ),
    Task.INSTRUCT_CONCISE: (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request. "
        "The response should be very concise. "
    ),
}


def get_system_prompt(task: Task | str) -> str:
    """Get the system prompt for a given task."""
    if isinstance(task, str):
        task = Task(task)
    return SYSTEM_PROMPTS[task]


def apply_model_characteristics(
    system_prompt: str,
    prompt: str,
    model_name: str,
) -> tuple[str, str | None, list[int]]:
    """Apply and return model-specific differences."""
    conv = get_conversation_template(model_name)

    if "llama-2" in model_name.lower():
        conv.system = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    elif "stablelm" in model_name.lower():
        conv.system = f"""<|SYSTEM|># {system_prompt}\n"""
    else:
        conv.system = system_prompt
    conv.messages = []
    conv.offset = 0

    conv.append_message(conv.roles[0], prompt)
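    # An empty assistant turn makes get_prompt() end with the assistant role prefix,
    # ready for the model to generate a completion.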
    conv.append_message(conv.roles[1], "")

    # Normalize an empty or missing stop string to None.
    stop_str = conv.stop_str or None

    return conv.get_prompt(), stop_str, (conv.stop_token_ids or [])
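

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module proper. Assumes fastchat is
    # installed and that "vicuna-7b-v1.5" resolves to a conversation template;
    # substitute any model name your setup supports.
    full_prompt, stop, stop_ids = apply_model_characteristics(
        system_prompt=get_system_prompt(Task.CHAT_CONCISE),
        prompt="Summarize the plot of Hamlet in two sentences.",
        model_name="vicuna-7b-v1.5",
    )
    print(full_prompt)
    print("stop_str:", stop, "stop_token_ids:", stop_ids)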