File size: 3,494 Bytes
9aba307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b2edcb
 
 
 
 
9aba307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""main.py
Main entry point for training
"""

import sys
import tempfile
import warnings
from argparse import Namespace
from pathlib import Path

import mlflow
from datasets import load_dataset

from config import config
from config.config import logger
from yomikata import utils
from yomikata.dbert import dBert


# Point MLflow at the file-based run registry configured for this project
mlflow.set_tracking_uri(f"file://{config.RUN_REGISTRY.absolute()}")


# Install a blanket "ignore" filter so warnings are not shown
warnings.filterwarnings(action="ignore")


def train_model(
    model_name: str,
    dataset_name: str = "",
    experiment_name: str = "baselines",
    run_name: str = "dbert-default",
    training_args: "dict | None" = None,
) -> None:
    """Train a model inside an MLflow run and log its training performance.

    Args:
        model_name (str): name of the model to train. Only "dBert" is supported.
        dataset_name (str): suffix selecting the CSV splits to load, i.e.
            ``train_<dataset_name>.csv``, ``val_<dataset_name>.csv`` and
            ``test_<dataset_name>.csv`` from the configured data directories.
            Defaults to "".
        experiment_name (str): name of experiment.
        run_name (str): name of specific run in experiment.
        training_args (dict, optional): forwarded to the model's ``train``
            method as ``training_args``. Defaults to an empty dict.

    Raises:
        ValueError: if ``model_name`` is not "dBert".
    """

    # Reject unsupported models before touching MLflow, so that a typo in the
    # model name does not create an experiment and a run first.
    if model_name != "dBert":
        raise ValueError("model_name must be dBert for now")

    # Build a fresh dict per call; a `{}` default would be one object shared
    # by every call that omits the argument.
    if training_args is None:
        training_args = {}

    mlflow.set_experiment(experiment_name=experiment_name)
    with mlflow.start_run(run_name=run_name):

        run_id = mlflow.active_run().info.run_id
        logger.info(f"Run ID: {run_id}")

        # Same location the module-level helper resolves for a finished run
        artifacts_dir = get_artifacts_dir_from_run(run_id)

        # Initialize the model
        reader = dBert(reinitialize=True, artifacts_dir=artifacts_dir)

        # Load train val test data
        data_files = {
            "train": str(Path(config.TRAIN_DATA_DIR, f"train_{dataset_name}.csv")),
            "val": str(Path(config.VAL_DATA_DIR, f"val_{dataset_name}.csv")),
            "test": str(Path(config.TEST_DATA_DIR, f"test_{dataset_name}.csv")),
        }
        dataset = load_dataset("csv", data_files=data_files)

        # Train
        training_performance = reader.train(dataset, training_args=training_args)

        # general_performance = evaluate.evaluate(reader, max_evals=20)

        # Write the metrics to a scratch directory and log its contents as
        # artifacts of the active run.
        with tempfile.TemporaryDirectory() as dp:
            # reader.save(dp)
            # utils.save_dict(general_performance, Path(dp, "general_performance.json"))
            utils.save_dict(training_performance, Path(dp, "training_performance.json"))
            mlflow.log_artifacts(dp)


def get_artifacts_dir_from_run(run_id: str):
    """Build the path of the artifacts directory belonging to a run.

    The run is looked up through MLflow to find the experiment it belongs to;
    the result is ``<RUN_REGISTRY>/<experiment_id>/<run_id>/artifacts``.

    Args:
        run_id (str): id of run to load artifacts from.

    Returns:
        Path: path to artifacts directory.

    """

    # The experiment id forms the directory level above the run id
    run_info = mlflow.get_run(run_id=run_id).info
    return Path(config.RUN_REGISTRY, run_info.experiment_id, run_id, "artifacts")


if __name__ == "__main__":

    # The single command-line argument is the path to the args file
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} ARGS_FILE")
    args_fp = sys.argv[1]

    # Load the args file into a fresh dict; the meta variables are popped
    # below and whatever remains is passed on as the training arguments.
    args = dict(utils.load_dict(filepath=args_fp))

    # All four meta variables are required; report every missing one at once
    # instead of failing with a KeyError on the first.
    meta_keys = ("model", "dataset", "experiment", "run")
    missing = [key for key in meta_keys if key not in args]
    if missing:
        sys.exit(f"{args_fp}: missing required key(s): {', '.join(missing)}")

    # pop meta variables
    model_name = args.pop("model")
    dataset_name = args.pop("dataset")
    experiment_name = args.pop("experiment")
    run_name = args.pop("run")

    # Perform training
    train_model(
        model_name=model_name,
        dataset_name=dataset_name,
        experiment_name=experiment_name,
        run_name=run_name,
        training_args=args,
    )