KingNish committed
Commit 9379874
1 Parent(s): 1dd13ff

Update app.py

Files changed (1):
  1. app.py +25 -7

app.py CHANGED
@@ -21,6 +21,18 @@ def transcribe(audio):
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
+def client_fn(model):
+    if "Mixtral" in model:
+        return InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+    elif "Llama" in model:
+        return InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
+    elif "Mistral" in model:
+        return InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
+    elif "Phi" in model:
+        return InferenceClient("microsoft/Phi-3-mini-4k-instruct")
+    else:
+        return InferenceClient("microsoft/Phi-3-mini-4k-instruct")
+
 def randomize_seed_fn(seed: int) -> int:
     seed = random.randint(0, 999999)
     return seed
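
Note on the hunk above: client_fn dispatches on substring matches against the dropdown label, with Phi-3 serving both as the explicit "Phi" branch and as the fallback for anything unrecognized. One wrinkle worth flagging: the dropdown added later in this diff offers "Mistral 7B v0.3", yet client_fn returns the v0.2 checkpoint, while the hardcoded client this commit deletes was v0.3. A quick, hypothetical sanity check, assuming client_fn is importable from app.py (InferenceClient keeps its target repo on the .model attribute):

from app import client_fn  # assumes app.py imports cleanly outside the Space

for label in ["Mixtral 8x7B", "Llama 3 8B", "Mistral 7B v0.3", "Phi 3 mini"]:
    # Print which checkpoint each UI label routes to; note that
    # "Mistral 7B v0.3" comes back as mistralai/Mistral-7B-Instruct-v0.2.
    print(label, "->", client_fn(label).model)
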
@@ -33,18 +45,17 @@ Respond in a normal, conversational manner while being friendly and helpful.
 [USER]
 """
 
-def models(text, seed=42):
+def models(text, model="Mixtral 8x7B", seed=42):
 
     seed = int(randomize_seed_fn(seed))
     generator = torch.Generator().manual_seed(seed)
 
-    client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
+    client = client_fn(model)
 
     generate_kwargs = dict(
         max_new_tokens=300,
         seed=seed
-    )
-
+    )
     formatted_prompt = system_instructions1 + text + "[JARVIS]"
     stream = client.text_generation(
         formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
@@ -52,7 +63,6 @@ def models(text, seed=42):
     for response in stream:
         if not response.token.text == "</s>":
            output += response.token.text
-
     return output
 
 async def respond(audio, model, seed):
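
Note on the two hunks above: models() now resolves its client through client_fn, but the seed plumbing deserves a second look. randomize_seed_fn (visible as context in the first hunk) ignores its argument and returns a fresh random.randint(0, 999999), so neither the slider value nor the seed=42 default actually pins generation, and the local torch.Generator appears unused by the remote call. Below is a self-contained sketch of the streaming path the diff relies on, using huggingface_hub's InferenceClient; the model and prompt are illustrative, and gated repos such as Meta-Llama-3 need a valid HF_TOKEN:

from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
stream = client.text_generation(
    "Hello, who are you?",
    max_new_tokens=300,
    seed=42,
    stream=True,             # yield tokens incrementally
    details=True,            # wrap each chunk so .token.text is available
    return_full_text=False,  # do not echo the prompt back
)

output = ""
for response in stream:
    # Skip the end-of-sequence marker, exactly as the loop in the diff does.
    if response.token.text != "</s>":
        output += response.token.text
print(output)
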
@@ -72,6 +82,14 @@ DESCRIPTION = """ # <center><b>JARVIS⚡</b></center>
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     with gr.Row():
+        select = gr.Dropdown(['Mixtral 8x7B',
+                              'Llama 3 8B',
+                              'Mistral 7B v0.3',
+                              'Phi 3 mini',
+                              ],
+                             value="Mistral 7B v0.3",
+                             label="Model"
+                             )
         seed = gr.Slider(
             label="Seed",
             minimum=0,
@@ -89,8 +107,8 @@ with gr.Blocks(css="style.css") as demo:
         batch=True,
         max_batch_size=10,
         fn=respond,
-        inputs=[input, seed],
+        inputs=[input, select, seed],
         outputs=[output], live=True)
-
+
 if __name__ == "__main__":
     demo.queue(max_size=200).launch()
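
The UI hunks complete the wiring: the new gr.Dropdown supplies the label that client_fn dispatches on, and threading it into the existing gr.Interface via inputs=[input, select, seed] is what lets respond(audio, model, seed) receive the selection. Note the default "Mistral 7B v0.3" routes, via client_fn, to the v0.2 repo. A stripped-down sketch of that wiring follows; respond here is a hypothetical stand-in (the real async respond() in app.py transcribes the audio, calls models(text, model, seed), and returns synthesized speech), and the slider maximum is assumed from randomize_seed_fn's range since the diff truncates the slider arguments:

import gradio as gr

def respond(audio, model, seed):
    # Placeholder for app.py's respond(); just echoes its inputs.
    return f"model={model}, seed={seed}, audio={audio}"

select = gr.Dropdown(
    ["Mixtral 8x7B", "Llama 3 8B", "Mistral 7B v0.3", "Phi 3 mini"],
    value="Mistral 7B v0.3",
    label="Model",
)
# maximum/step are assumptions; the diff only shows label and minimum.
seed = gr.Slider(label="Seed", minimum=0, maximum=999999, step=1, value=0)

demo = gr.Interface(
    fn=respond,
    inputs=[gr.Audio(sources=["microphone"], type="filepath"), select, seed],
    outputs=["text"],
    live=True,
)

if __name__ == "__main__":
    demo.queue(max_size=200).launch()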
 