Wonderplex committed on
Commit 1f90e24 · unverified · 1 Parent(s): 79cc507

sotopia-pi prompt template fix (#66)

sotopia_pi_generate.py → sotopia_generate.py RENAMED
@@ -3,6 +3,7 @@ import os
 from typing import TypeVar
 from functools import cache
 import logging
+import json
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
@@ -22,9 +23,10 @@ from langchain.prompts import (
     PromptTemplate,
 )
 from langchain.schema import BaseOutputParser, OutputParserException
+import spaces
+
 from message_classes import ActionType, AgentAction
 from utils import format_docstring
-
 from langchain_callback_handler import LoggingCallbackHandler
 
 HF_TOKEN_KEY_FILE="./hf_token.key"
@@ -89,7 +91,7 @@ def prepare_model(model_name):
         model = AutoModelForCausalLM.from_pretrained(
             "mistralai/Mistral-7B-Instruct-v0.1",
             cache_dir="./.cache",
-            device_map='cuda'
+            # device_map='cuda'
         )
         model = PeftModel.from_pretrained(model, model_name).to("cuda")
 
@@ -98,7 +100,7 @@ def prepare_model(model_name):
         model = AutoModelForCausalLM.from_pretrained(
             "mistralai/Mistral-7B-Instruct-v0.1",
             cache_dir="./.cache",
-            device_map='cuda',
+            # device_map='cuda',
             quantization_config=BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_use_double_quant=True,
@@ -114,7 +116,7 @@ def prepare_model(model_name):
         model = AutoModelForCausalLM.from_pretrained(
             "mistralai/Mistral-7B-Instruct-v0.1",
             cache_dir="./.cache",
-            device_map='cuda'
+            # device_map='cuda'
         )
 
     else:
@@ -131,7 +133,7 @@ def obtain_chain_hf(
     max_tokens: int = 2700
 ) -> LLMChain:
     human_message_prompt = HumanMessagePromptTemplate(
-        prompt=PromptTemplate(template=template, input_variables=input_variables)
+        prompt=PromptTemplate(template="[INST] " + template + " [/INST]", input_variables=input_variables)
     )
     chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
     model, tokenizer = prepare_model(model_name)
@@ -148,6 +150,7 @@ def obtain_chain_hf(
     chain = LLMChain(llm=hf, prompt=chat_prompt_template)
     return chain
 
+
 def generate(
     model_name: str,
     template: str,
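
For context, a minimal sketch (not part of the commit) of what the new "[INST] ... [/INST]" wrapping in obtain_chain_hf produces once a template is filled in. The template string and input variables below are made up for illustration; only the wrapping itself mirrors the diff.

from langchain.prompts import PromptTemplate

# Hypothetical sotopia-style template; the real templates are built elsewhere in the app.
template = "Imagine you are {agent}. Continue the conversation:\n{history}"

# Same wrapping as the updated obtain_chain_hf: surround the prompt with the
# instruction tags that mistralai/Mistral-7B-Instruct-v0.1 (and adapters
# fine-tuned from it) expect.
prompt = PromptTemplate(
    template="[INST] " + template + " [/INST]",
    input_variables=["agent", "history"],
)

print(prompt.format(agent="Ava", history="Ben said: hi!"))
# -> [INST] Imagine you are Ava. Continue the conversation:
#    Ben said: hi! [/INST]

Without the tags the model receives a bare prompt rather than the chat format it was trained on, which appears to be what the "prompt template fix" in this commit addresses.
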
sotopia_space/chat.py CHANGED
@@ -6,7 +6,7 @@ from typing import Literal
 import json
 from collections import defaultdict
 from utils import Environment, Agent, get_context_prompt, dialogue_history_prompt
-from sotopia_pi_generate import prepare_model, generate_action
+from sotopia_generate import prepare_model, generate_action
 from sotopia_space.constants import MODEL_OPTIONS
 
 DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
sotopia_space/constants.py CHANGED
@@ -5,7 +5,7 @@ MODEL_OPTIONS = [
     "cmu-lti/sotopia-pi-mistral-7b-BC_SR",
     "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
     "mistralai/Mistral-7B-Instruct-v0.1"
-    # "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    # "mistralai/Mixtral-8x7B-Instruct-v0.1",  # TODO: Add these models
     # "togethercomputer/llama-2-7b-chat",
     # "togethercomputer/llama-2-70b-chat",
     # "togethercomputer/mpt-30b-chat",