# open_dutch_llm_leaderboard / generate_overview_json.py
# Author: Bram Vanroy
# Commit: 6d7ff83 ("update data collection script")
# (Header reconstructed from Hugging Face page metadata; original lines were
# not valid Python.)
from pathlib import Path
import json
from pprint import pprint
from transformers import AutoModelForCausalLM
def get_num_parameters(model_name: str) -> int:
    """Return the total parameter count of a pretrained causal LM.

    NOTE: this instantiates the full model via ``from_pretrained`` just to
    count parameters, so it downloads the checkpoint and can be slow and
    memory-heavy for large models.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model.num_parameters()
def _parse_model_args(model_args_str: str) -> dict:
    """Parse a comma-separated ``key=value,...`` string into a dict.

    Splits each item on the FIRST ``=`` only, so values that themselves
    contain ``=`` (e.g. revisions or local paths) stay intact instead of
    crashing the ``dict(...)`` construction.
    """
    return dict(item.split("=", 1) for item in model_args_str.split(","))


def main():
    """Collect model metadata from eval result files into ``evals/models.json``.

    Scans every ``*.json`` under the ``evals`` directory next to this script,
    extracts the model name, compute dtype, and quantization flags from each
    file's ``config.model_args`` string, looks up the parameter count, and
    writes the merged overview back to ``evals/models.json``.
    """
    evals_dir = Path(__file__).parent.joinpath("evals")
    pf_overview = evals_dir.joinpath("models.json")
    # Start from the existing overview so already-processed models are kept
    # (and skipped below) instead of being re-downloaded.
    results = json.loads(pf_overview.read_text(encoding="utf-8")) if pf_overview.exists() else {}

    for pfin in evals_dir.rglob("*.json"):
        if pfin.stem == "models":
            continue

        # Result files are assumed to be named like "<x>_<y>_<short-name>..."
        # — TODO confirm against the actual eval file naming scheme.
        name_parts = pfin.stem.split("_")
        if len(name_parts) < 3:
            # Unexpected file name: skip rather than raise IndexError.
            continue
        short_name = name_parts[2]
        if short_name in results:
            continue

        data = json.loads(pfin.read_text(encoding="utf-8"))
        if "config" not in data:
            continue
        config = data["config"]
        if "model_args" not in config:
            continue

        model_args = _parse_model_args(config["model_args"])
        if "pretrained" not in model_args:
            continue

        results[short_name] = {
            "model_name": model_args["pretrained"],
            "compute_dtype": model_args.get("dtype", None),
            "quantization": None,
            # Downloads the full model; the expensive step of this loop.
            "num_parameters": get_num_parameters(model_args["pretrained"]),
        }
        # Presence of the flag (not its value) marks quantization, matching
        # how the eval harness includes these keys only when enabled —
        # NOTE(review): confirm flags are omitted (not "False") when off.
        if "load_in_8bit" in model_args:
            results[short_name]["quantization"] = "8-bit"
        elif "load_in_4bit" in model_args:
            results[short_name]["quantization"] = "4-bit"

    pprint(results)
    pf_overview.write_text(json.dumps(results, indent=4), encoding="utf-8")
# Script entry point: only run the collection when executed directly.
if __name__ == "__main__":
    main()