SkalskiP commited on
Commit
42be940
1 Parent(s): 2c0f2ed

Added `top_k` & `top_p`

Browse files
Files changed (1) hide show
  1. app.py +39 -2
app.py CHANGED
@@ -4,7 +4,8 @@ import google.generativeai as genai
4
  import gradio as gr
5
  from PIL import Image
6
 
7
- TITLE = """<h1 align="center">Gemini Pro and Pro Vision via API 🚀</h1>"""
 
8
  DUPLICATE = """
9
  <div style="text-align: center; display: flex; justify-content: center; align-items: center;">
10
  <a href="https://huggingface.co/spaces/SkalskiP/ChatGemini?duplicate=true">
@@ -32,6 +33,8 @@ def predict(
32
  temperature: float,
33
  max_output_tokens: int,
34
  stop_sequences: str,
 
 
35
  chatbot: List[Tuple[str, str]]
36
  ) -> Tuple[str, List[Tuple[str, str]]]:
37
  if not google_key:
@@ -43,7 +46,9 @@ def predict(
43
  generation_config = genai.types.GenerationConfig(
44
  temperature=temperature,
45
  max_output_tokens=max_output_tokens,
46
- stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences))
 
 
47
 
48
  if image_prompt is None:
49
  model = genai.GenerativeModel('gemini-pro')
@@ -110,6 +115,32 @@ stop_sequences_component = gr.Textbox(
110
  "response generation if the model encounters it. The sequence is not included "
111
  "as part of the response. You can add up to five stop sequences."
112
  ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  inputs = [
115
  google_key_component,
@@ -118,11 +149,14 @@ inputs = [
118
  temperature_component,
119
  max_output_tokens_component,
120
  stop_sequences_component,
 
 
121
  chatbot_component
122
  ]
123
 
124
  with gr.Blocks() as demo:
125
  gr.HTML(TITLE)
 
126
  gr.HTML(DUPLICATE)
127
  with gr.Column():
128
  google_key_component.render()
@@ -135,6 +169,9 @@ with gr.Blocks() as demo:
135
  temperature_component.render()
136
  max_output_tokens_component.render()
137
  stop_sequences_component.render()
 
 
 
138
 
139
  run_button_component.click(
140
  fn=predict,
 
4
  import gradio as gr
5
  from PIL import Image
6
 
7
+ TITLE = """<h1 align="center">Gemini Playground 💬</h1>"""
8
+ SUBTITLE = """<h2 align="center">Play with Gemini Pro and Gemini Pro Vision API</h2>"""
9
  DUPLICATE = """
10
  <div style="text-align: center; display: flex; justify-content: center; align-items: center;">
11
  <a href="https://huggingface.co/spaces/SkalskiP/ChatGemini?duplicate=true">
 
33
  temperature: float,
34
  max_output_tokens: int,
35
  stop_sequences: str,
36
+ top_k: int,
37
+ top_p: float,
38
  chatbot: List[Tuple[str, str]]
39
  ) -> Tuple[str, List[Tuple[str, str]]]:
40
  if not google_key:
 
46
  generation_config = genai.types.GenerationConfig(
47
  temperature=temperature,
48
  max_output_tokens=max_output_tokens,
49
+ stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences),
50
+ top_k=top_k,
51
+ top_p=top_p)
52
 
53
  if image_prompt is None:
54
  model = genai.GenerativeModel('gemini-pro')
 
115
  "response generation if the model encounters it. The sequence is not included "
116
  "as part of the response. You can add up to five stop sequences."
117
  ))
118
+ top_k_component = gr.Slider(
119
+ minimum=1,
120
+ maximum=40,
121
+ value=32,
122
+ step=1,
123
+ label="Top-K",
124
+ info=(
125
+ "Top-k changes how the model selects tokens for output. A top-k of 1 means the "
126
+ "selected token is the most probable among all tokens in the model’s "
127
+ "vocabulary (also called greedy decoding), while a top-k of 3 means that the "
128
+ "next token is selected from among the 3 most probable tokens (using "
129
+ "temperature)."
130
+ ))
131
+ top_p_component = gr.Slider(
132
+ minimum=0,
133
+ maximum=1,
134
+ value=1,
135
+ step=0.01,
136
+ label="Top-P",
137
+ info=(
138
+ "Top-p changes how the model selects tokens for output. Tokens are selected "
139
+ "from most probable to least until the sum of their probabilities equals the "
140
+ "top-p value. For example, if tokens A, B, and C have a probability of .3, .2, "
141
+ "and .1 and the top-p value is .5, then the model will select either A or B as "
142
+ "the next token (using temperature). "
143
+ ))
144
 
145
  inputs = [
146
  google_key_component,
 
149
  temperature_component,
150
  max_output_tokens_component,
151
  stop_sequences_component,
152
+ top_k_component,
153
+ top_p_component,
154
  chatbot_component
155
  ]
156
 
157
  with gr.Blocks() as demo:
158
  gr.HTML(TITLE)
159
+ gr.HTML(SUBTITLE)
160
  gr.HTML(DUPLICATE)
161
  with gr.Column():
162
  google_key_component.render()
 
169
  temperature_component.render()
170
  max_output_tokens_component.render()
171
  stop_sequences_component.render()
172
+ with gr.Accordion("Advanced", open=False):
173
+ top_k_component.render()
174
+ top_p_component.render()
175
 
176
  run_button_component.click(
177
  fn=predict,