KingNish committed on
Commit
ad83e99
1 Parent(s): 9863688

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +12 -13
chatbot.py CHANGED
@@ -315,7 +315,7 @@ def model_inference(
315
  temperature,
316
  max_new_tokens,
317
  repetition_penalty,
318
- min_p,
319
  web_search,
320
  ):
321
  # Define generation_args at the beginning of the function
@@ -332,7 +332,6 @@ def model_inference(
332
  generate_kwargs = dict(
333
  max_new_tokens=4000,
334
  do_sample=True,
335
- min_p=0.08,
336
  )
337
  # Format the prompt for the language model
338
  formatted_prompt = format_prompt(
@@ -352,7 +351,6 @@ def model_inference(
352
  generate_kwargs = dict(
353
  max_new_tokens=5000,
354
  do_sample=True,
355
- min_p=0.08,
356
  )
357
  # Format the prompt for the language model
358
  formatted_prompt = format_prompt(
@@ -391,15 +389,16 @@ def model_inference(
391
  }
392
  assert decoding_strategy in [
393
  "Greedy",
394
- "Min P Sampling",
395
  ]
396
 
397
  if decoding_strategy == "Greedy":
398
  generation_args["do_sample"] = False
399
- elif decoding_strategy == "Min P Sampling":
400
  generation_args["temperature"] = temperature
401
  generation_args["do_sample"] = True
402
- generation_args["min_p"] = min_p
 
403
  (
404
  resulting_text,
405
  resulting_images,
@@ -441,7 +440,7 @@ FEATURES = datasets.Features(
441
  "temperature": datasets.Value("float32"),
442
  "max_new_tokens": datasets.Value("int32"),
443
  "repetition_penalty": datasets.Value("float32"),
444
- "min_p": datasets.Value("int32"),
445
  }
446
  )
447
 
@@ -466,9 +465,9 @@ repetition_penalty = gr.Slider(
466
  decoding_strategy = gr.Radio(
467
  [
468
  "Greedy",
469
- "Min P Sampling",
470
  ],
471
- value="Min P Sampling",
472
  label="Decoding strategy",
473
  interactive=True,
474
  info="Higher values are equivalent to sampling more low-probability tokens.",
@@ -483,14 +482,14 @@ temperature = gr.Slider(
483
  label="Sampling temperature",
484
  info="Higher values will produce more diverse outputs.",
485
  )
486
- min_p = gr.Slider(
487
  minimum=0.01,
488
- maximum=0.49,
489
- value=0.08,
490
  step=0.01,
491
  visible=True,
492
  interactive=True,
493
- label="Min P",
494
  info="Higher values are equivalent to sampling more low-probability tokens.",
495
  )
496
 
 
315
  temperature,
316
  max_new_tokens,
317
  repetition_penalty,
318
+ top_p,
319
  web_search,
320
  ):
321
  # Define generation_args at the beginning of the function
 
332
  generate_kwargs = dict(
333
  max_new_tokens=4000,
334
  do_sample=True,
 
335
  )
336
  # Format the prompt for the language model
337
  formatted_prompt = format_prompt(
 
351
  generate_kwargs = dict(
352
  max_new_tokens=5000,
353
  do_sample=True,
 
354
  )
355
  # Format the prompt for the language model
356
  formatted_prompt = format_prompt(
 
389
  }
390
  assert decoding_strategy in [
391
  "Greedy",
392
+ "Top P Sampling",
393
  ]
394
 
395
  if decoding_strategy == "Greedy":
396
  generation_args["do_sample"] = False
397
+ elif decoding_strategy == "Top P Sampling":
398
  generation_args["temperature"] = temperature
399
  generation_args["do_sample"] = True
400
+ generation_args["top_p"] = top_p
401
+ # Creating model inputs
402
  (
403
  resulting_text,
404
  resulting_images,
 
440
  "temperature": datasets.Value("float32"),
441
  "max_new_tokens": datasets.Value("int32"),
442
  "repetition_penalty": datasets.Value("float32"),
443
+ "top_p": datasets.Value("int32"),
444
  }
445
  )
446
 
 
465
  decoding_strategy = gr.Radio(
466
  [
467
  "Greedy",
468
+ "Top P Sampling",
469
  ],
470
+ value="Top P Sampling",
471
  label="Decoding strategy",
472
  interactive=True,
473
  info="Higher values are equivalent to sampling more low-probability tokens.",
 
482
  label="Sampling temperature",
483
  info="Higher values will produce more diverse outputs.",
484
  )
485
+ top_p = gr.Slider(
486
  minimum=0.01,
487
+ maximum=0.99,
488
+ value=0.9,
489
  step=0.01,
490
  visible=True,
491
  interactive=True,
492
+ label="Top P",
493
  info="Higher values are equivalent to sampling more low-probability tokens.",
494
  )
495