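"""Argilla-based ground-truth estimator.

This module connects to an Argilla server and collects human annotations for dataset
batches using the text classification task.
"""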
import argilla as rg
import time
import pandas as pd
from argilla.client.singleton import active_client
from utils.config import Color
from dataset.base_dataset import DatasetBase
import json
import webbrowser
import base64

class ArgillaEstimator:
    """
    The ArgillaEstimator class generates the ground truth (GT) for the dataset through the Argilla
    annotation interface, using the text classification task.
    """
    def __init__(self, opt):
        """
        Initialize a new instance of the ArgillaEstimator class.
        :param opt: Configuration object; expected to expose api_url, api_key, workspace and time_interval
        """
        try:
            self.opt = opt
            rg.init(
                api_url=opt.api_url,
                api_key=opt.api_key,
                workspace=opt.workspace
            )
            self.time_interval = opt.time_interval
        except Exception as e:
            raise Exception("Failed to connect to Argilla, check the connection details") from e

    @staticmethod
    def initialize_dataset(dataset_name: str, label_schema: set[str]):
        """
        Initialize a new dataset in the Argilla system
        :param dataset_name: The name of the dataset
        :param label_schema: The set of label classes
        """
        try:
            settings = rg.TextClassificationSettings(label_schema=label_schema)
            rg.configure_dataset_settings(name=dataset_name, settings=settings)
        except Exception as e:
            raise Exception("Failed to create the dataset") from e

    @staticmethod
    def upload_missing_records(dataset_name: str, batch_id: int, batch_records: pd.DataFrame):
        """
        Update the Argilla dataset by adding the records of batch_id that appear in batch_records
        but are missing from the Argilla dataset
        :param dataset_name: The dataset name
        :param batch_id: The batch id
        :param batch_records: A dataframe of the batch records
        """
        # TODO: sort visualization according to batch_id descending
        query = f"metadata.batch_id:{batch_id}"
        result = rg.load(name=dataset_name, query=query)
        df = result.to_pandas()
        if len(df) == len(batch_records):
            return
        if df.empty:
            upload_df = batch_records
        else:
            merged_df = pd.merge(batch_records, df['text'], on='text', how='left', indicator=True)
            upload_df = merged_df[merged_df['_merge'] == 'left_only'].drop(columns=['_merge'])
        record_list = []
        for index, row in upload_df.iterrows():
            config = {'text': row['text'], 'metadata': {"batch_id": row['batch_id'], 'id': row['id']}, "id": row['id']}
            # if not (row[['prediction']].isnull().any()):
            #     config['prediction'] = row['prediction']  # TODO: fix it incorrect type!!!
            if not (row[['annotation']].isnull().any()):  # TODO: fix it incorrect type!!!
                config['annotation'] = row['annotation']
            record_list.append(rg.TextClassificationRecord(**config))
        rg.log(records=record_list, name=dataset_name)

    def calc_usage(self):
        """
        Dummy function; human annotation through Argilla incurs no LLM cost, so the usage is always 0
        """
        return 0

    def apply(self, dataset: DatasetBase, batch_id: int):
        """
        Apply the estimator to the dataset. The function polls Argilla until all the records of the
        batch are annotated, then returns a dataframe with all the annotations
        :param dataset: DatasetBase object, contains all the processed records
        :param batch_id: The batch id to annotate
        """
        current_api = active_client()
        try:
            rg_dataset = current_api.datasets.find_by_name(dataset.name)
        except Exception:  # the dataset does not exist yet in Argilla, so create it
            self.initialize_dataset(dataset.name, dataset.label_schema)
            rg_dataset = current_api.datasets.find_by_name(dataset.name)
        batch_records = dataset[batch_id]
        if batch_records.empty:
            return []
        self.upload_missing_records(dataset.name, batch_id, batch_records)
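        # Build a link to the dataset in the Argilla UI; the query is passed as a
        # base64-encoded JSON string so the view is pre-filtered to the current batch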
        data = {'metadata': {'batch_id': [str(batch_id)]}}
        json_data = json.dumps(data)
        encoded_bytes = base64.b64encode(json_data.encode('utf-8'))
        encoded_string = str(encoded_bytes, "utf-8")
        url_link = f"{self.opt.api_url}/datasets/{self.opt.workspace}/{dataset.name}?query={encoded_string}"
        print(f"{Color.GREEN}Waiting for annotations from batch {batch_id}:\n{url_link}{Color.END}")
        webbrowser.open(url_link)
        while True:
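            # Poll until every record of this batch has been Validated or Discarded by the annotator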
            query = "(status:Validated OR status:Discarded) AND metadata.batch_id:{}".format(batch_id)
            search_results = current_api.search.search_records(
                name=dataset.name,
                task=rg_dataset.task,
                size=0,
                query_text=query,
            )
            if search_results.total == len(batch_records):
                result = rg.load(name=dataset.name, query=query)
                df = result.to_pandas()[['text', 'annotation', 'metadata', 'status']]
                df["annotation"] = df.apply(lambda x: 'Discarded' if x['status']=='Discarded' else x['annotation'], axis=1)
                df = df.drop(columns=['status'])
                df['id'] = df.apply(lambda x: x['metadata']['id'], axis=1)
                return df
            time.sleep(self.time_interval)
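
# Example usage (a minimal sketch; `opt` is assumed to be a config object exposing
# api_url, api_key, workspace and time_interval, and `dataset` a DatasetBase instance
# whose batch 0 has already been populated):
#
#   estimator = ArgillaEstimator(opt)
#   annotations_df = estimator.apply(dataset, batch_id=0)
#   print(annotations_df[['id', 'text', 'annotation']])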