File size: 1,936 Bytes
8f1a330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import datetime
import os
import subprocess
import sys
from timeit import default_timer as timer

# Put the current working directory at the front of the module search path.
sys.path.insert(0, os.getcwd())

# Reference point for the total run time printed when the script finishes.
start = timer()

# Prefer a non-empty RESULT_FILENAME_PREFIX from the environment; otherwise
# fall back to a prefix stamped with the current local date and time.
result_filename_prefix = os.getenv("RESULT_FILENAME_PREFIX") or (
    datetime.datetime.now().strftime("Tune_%Y-%m-%d_%H-%M-%S")
)
print(f"Result filename prefix: {result_filename_prefix}")

# Path of the combined results CSV, derived from the prefix.
all_results_filename = f"./data/results/{result_filename_prefix}.csv"

# Sweep configuration: the increment applied after each run and the largest
# repetition penalty value to try.
repetition_penalty_delta = 0.02
repetition_penalty_end = 1.3

# The sweep begins at 1.0 unless a non-empty REPETITION_PENALTY_START
# environment variable supplies a different starting value.
repetition_penalty_start = os.getenv("REPETITION_PENALTY_START")
if not repetition_penalty_start:
    repetition_penalty = 1.0
else:
    repetition_penalty = float(repetition_penalty_start)
    print(f"Starting from RP: {repetition_penalty}")

# Optional extra command-line arguments for qa_chain_test.py. The value is
# the same for every run, so it is read once instead of once per iteration.
num_questions = os.getenv("NUM_QUESTIONS") or ""
extra_args = num_questions.split()

# Run qa_chain_test.py once per repetition penalty value, stepping by
# repetition_penalty_delta up to repetition_penalty_end. The small tolerance
# keeps the end value in the sweep despite floating-point accumulation error.
while repetition_penalty <= repetition_penalty_end + 1e-5:
    # Each run gets its own copy of the environment so the overrides below
    # never leak into this process.
    new_env = os.environ.copy()

    # The same penalty value is exported under every *_RP variable.
    repetition_penalty_str = f"{repetition_penalty:.3f}"
    for rp_variable in ("HFTGI_RP", "HF_RP", "ML_RP", "SL_RP"):
        new_env[rp_variable] = repetition_penalty_str

    log_file = "./data/logs/{}_RP_{}.txt".format(
        result_filename_prefix, repetition_penalty_str
    )
    test_results_filename = "./data/results/{}_RP_{}.csv".format(
        result_filename_prefix, repetition_penalty_str
    )
    new_env["TEST_RESULTS_CSV_FILE"] = test_results_filename
    new_env["ALL_RESULTS_CSV_FILE"] = all_results_filename

    # open() does not create missing parent directories.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    with open(log_file, "w") as f_obj:
        # An argument list without a shell avoids interpolating an
        # environment-supplied value into a shell command line, and
        # sys.executable guarantees the child uses the same interpreter as
        # this script. Only stdout goes to the log; stderr is inherited.
        completed = subprocess.run(
            [sys.executable, "qa_chain_test.py", *extra_args],
            env=new_env,
            stdout=f_obj,
        )

    # Report failures but keep sweeping: one bad run should not discard the
    # remaining penalty values.
    if completed.returncode != 0:
        print(
            f"Warning: run with RP {repetition_penalty_str} exited with code "
            f"{completed.returncode}; stdout log: {log_file}",
            file=sys.stderr,
        )

    repetition_penalty += repetition_penalty_delta

# Tell the user where the combined results file is expected to be.
print(f"All results saved to {all_results_filename}")

# Report the time elapsed since `start` was captured.
elapsed_seconds = timer() - start
print(f"Total time used: {elapsed_seconds:.3f} s")