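# Provenance (captured from a web file-viewer page; not part of the script):
# author dh-mc, commit 8f1a330 ("latest code/data"), file size 1.94 kB.
# Viewer labels captured alongside it: "raw", "history", "blame", "No virus".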
import datetime
import os
import shlex
import subprocess
import sys
from timeit import default_timer as timer
# Prepend the current working directory to the import search path.
sys.path.insert(0, os.getcwd())
# Wall-clock start of the whole sweep (timer is timeit.default_timer).
start = timer()

# Every output file of this sweep shares one prefix. RESULT_FILENAME_PREFIX
# lets the caller choose it; when unset or empty, a timestamped
# "Tune_YYYY-MM-DD_HH-MM-SS" prefix is generated from the local time.
result_filename_prefix = os.getenv("RESULT_FILENAME_PREFIX")
if not result_filename_prefix:
    result_filename_prefix = "Tune_{:%Y-%m-%d_%H-%M-%S}".format(
        datetime.datetime.now()
    )
print(f"Result filename prefix: {result_filename_prefix}")

# Combined results CSV path shared by every run of the sweep.
all_results_filename = f"./data/results/{result_filename_prefix}.csv"
def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed as a float, or *default*.

    An unset or empty variable yields *default*.

    Raises:
        ValueError: if the variable is set but is not a valid float; the
            message names the offending variable.
    """
    raw = os.getenv(name)
    if not raw:
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be a number, got {raw!r}") from err


# Repetition-penalty sweep parameters. The step (default 0.02) and the
# inclusive upper bound (default 1.3) can be overridden through
# REPETITION_PENALTY_DELTA and REPETITION_PENALTY_END, mirroring
# REPETITION_PENALTY_START for the starting value (default 1.0).
repetition_penalty_delta = _env_float("REPETITION_PENALTY_DELTA", 0.02)
repetition_penalty_end = _env_float("REPETITION_PENALTY_END", 1.3)
# A non-positive step never advances toward the upper bound, so the sweep
# would not terminate; fail fast instead.
if repetition_penalty_delta <= 0:
    raise ValueError(
        "REPETITION_PENALTY_DELTA must be positive, "
        f"got {repetition_penalty_delta}"
    )
repetition_penalty = _env_float("REPETITION_PENALTY_START", 1.0)
print(f"Starting from RP: {repetition_penalty}")
# Sweep the repetition penalty from its start value up to
# repetition_penalty_end (inclusive, with a small tolerance) in steps of
# repetition_penalty_delta, running qa_chain_test.py once per value. Each run
# gets its own log file and per-RP results CSV; the combined CSV path is
# passed to every run through ALL_RESULTS_CSV_FILE.

# Loop-invariant inputs are computed once, outside the loop.
num_questions = os.getenv("NUM_QUESTIONS") or ""
# An argv list (no shell) means NUM_QUESTIONS cannot inject shell syntax;
# shlex.split keeps the whitespace splitting the old shell command string had.
# sys.executable runs the child under the same interpreter as this script.
test_command = [sys.executable, "qa_chain_test.py", *shlex.split(num_questions)]
base_env = os.environ.copy()
base_env["ALL_RESULTS_CSV_FILE"] = all_results_filename

# open() below fails if the log directory is missing; create the output
# directories up front.
os.makedirs("./data/logs", exist_ok=True)
os.makedirs("./data/results", exist_ok=True)

rp_start = repetition_penalty
# The 1e-5 tolerance keeps the upper bound inclusive despite float rounding.
rp_limit = repetition_penalty_end + 1e-5
step = 0
while True:
    # Derive each value as start + step * delta instead of accumulating
    # `+= delta`, so binary floating-point error does not build up.
    repetition_penalty = rp_start + step * repetition_penalty_delta
    if repetition_penalty > rp_limit:
        break
    repetition_penalty_str = f"{repetition_penalty:.3f}"

    new_env = base_env.copy()
    # The same RP value is exported under each of these variable names;
    # which one the child actually reads is not visible from this script.
    for env_name in ("HFTGI_RP", "HF_RP", "ML_RP", "SL_RP"):
        new_env[env_name] = repetition_penalty_str

    log_file = "./data/logs/{}_RP_{}.txt".format(
        result_filename_prefix, repetition_penalty_str
    )
    test_results_filename = "./data/results/{}_RP_{}.csv".format(
        result_filename_prefix, repetition_penalty_str
    )
    new_env["TEST_RESULTS_CSV_FILE"] = test_results_filename

    # Only the child's stdout goes to the log; stderr still reaches the console.
    with open(log_file, "w") as f_obj:
        completed = subprocess.run(test_command, env=new_env, stdout=f_obj)
    # Keep sweeping after a failed run, but make the failure visible.
    if completed.returncode != 0:
        print(
            f"WARNING: run for RP {repetition_penalty_str} exited with code "
            f"{completed.returncode}; see {log_file}"
        )

    step += 1
# Final summary: where the combined CSV was written and how long the whole
# sweep took, measured with the same timer (timeit.default_timer) as `start`.
print("All results saved to {}".format(all_results_filename))
end = timer()
total_time = end - start
print("Total time used: {:.3f} s".format(total_time))