# load_data.py — last change: commit 9bfc13e "Update load_data.py" (dvilasuero, HF staff)
import sys
import time
import os
import pandas as pd
import requests
from datasets import load_dataset, concatenate_datasets
import argilla as rg
from argilla.listeners import listener
# --- Configuration ----------------------------------------------------------

# Hugging Face token; required by `save_validated_to_hub` to push the
# validated records to HUB_DATASET_NAME.
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub dataset with the translated Alpaca examples to review.
SOURCE_DATASET = "LEL-A/translated_german_alpaca"

# Name of the dataset created in Argilla.
RG_DATASET_NAME = "translated-german-alpaca"

# Hub dataset the validations are pushed to (every 20 min by the listener) and
# re-read from at startup to keep Argilla in sync; overridable via env var.
HUB_DATASET_NAME = os.getenv("HUB_DATASET_NAME", SOURCE_DATASET + "_validation")

# Label schema for the text-classification task (may be extended if needed).
LABELS = [
    "BAD INSTRUCTION",
    "INAPPROPRIATE",
    "ALL GOOD",
    "NOT SURE",
    "WRONG LANGUAGE",
]
@listener(
    dataset=RG_DATASET_NAME,
    query="status:Validated",
    execution_interval_in_seconds=1200,  # how often the listener runs this job (20 min)
)
def save_validated_to_hub(records, ctx):
    """Push the records matched by the listener query to the Hugging Face Hub.

    Registered as an Argilla listener on ``RG_DATASET_NAME`` with the query
    ``status:Validated``. The records are converted with
    ``rg.DatasetForTextClassification(...).to_datasets()`` and pushed to
    ``HUB_DATASET_NAME`` only when ``HF_TOKEN`` is set.

    Args:
        records: records returned by the listener query.
        ctx: presumably the listener context; not used here.
    """
    if len(records) == 0:
        print("NO RECORDS found")
        return

    # Conversion happens before the token check (as it always has), so a
    # conversion error surfaces even when pushing is disabled.
    validated = rg.DatasetForTextClassification(records=records).to_datasets()

    if not HF_TOKEN:
        print("SET HF_TOKEN and HUB_DATASET_NAME TO SYNC YOUR DATASET!!!")
        return

    print("Pushing the dataset")
    print(validated)
    validated.push_to_hub(HUB_DATASET_NAME, token=HF_TOKEN)
class LoadDatasets:
    """Load the Alpaca translation dataset into Argilla and start the Hub sync."""

    def __init__(self, api_key, workspace="team"):
        """Initialise the Argilla client.

        Args:
            api_key: Argilla API key.
            workspace: Argilla workspace passed to ``rg.init``.
        """
        rg.init(api_key=api_key, workspace=workspace)

    @staticmethod
    def _load_previous_validations():
        """Return the validations previously pushed to HUB_DATASET_NAME, or None.

        Loading is best effort: any error (presumably e.g. the Hub dataset not
        existing yet) is printed and treated as "nothing to sync".
        """
        print(f"Trying to sync with {HUB_DATASET_NAME}")
        try:
            return load_dataset(HUB_DATASET_NAME, split="train")
        except Exception as e:  # deliberate best effort; startup must continue
            print(f"Not possible to sync with {HUB_DATASET_NAME}")
            print(e)
            return None

    @staticmethod
    def _load_source_with_validations():
        """Load SOURCE_DATASET and append any previously validated rows to it."""
        old_ds = LoadDatasets._load_previous_validations()

        print(f"Loading dataset: {SOURCE_DATASET}")
        dataset = load_dataset(SOURCE_DATASET, split="train")

        # An empty or missing Hub dataset leaves the source untouched.
        if old_ds is not None and len(old_ds) > 0:
            print("Concatenating datasets")
            # NOTE(review): validated rows are appended after the source rows;
            # presumably rg.log then overwrites records sharing an id — confirm.
            dataset = concatenate_datasets([dataset, old_ds])
            print("Concatenated dataset is:")
            print(dataset)
        return dataset

    @staticmethod
    def load_somos(workspace="team"):
        """Log the (synced) dataset into Argilla and start the Hub listener.

        Steps: load the source dataset merged with earlier Hub validations,
        convert it to Argilla text-classification records, configure the
        label schema, log the records and start ``save_validated_to_hub``.

        Args:
            workspace: workspace in which the dataset settings are configured.
                Should match the workspace given to ``__init__``; defaults to
                ``"team"`` (the previously hard-coded value).
        """
        dataset = LoadDatasets._load_source_with_validations()

        # "metrics" is dropped before conversion to Argilla records; only drop
        # it when present, since remove_columns raises on a missing column.
        if "metrics" in dataset.column_names:
            dataset = dataset.remove_columns("metrics")
        records = rg.DatasetForTextClassification.from_datasets(dataset)

        settings = rg.TextClassificationSettings(label_schema=LABELS)
        print(f"Configuring dataset: {RG_DATASET_NAME}")
        rg.configure_dataset(name=RG_DATASET_NAME, settings=settings, workspace=workspace)

        print(f"Logging dataset: {RG_DATASET_NAME}")
        rg.log(
            records,
            name=RG_DATASET_NAME,
            tags={"description": "Alpaca dataset to clean up"},
            batch_size=200,
        )

        # Start the periodic Hub sync only once the records are in Argilla.
        save_validated_to_hub.start()
if __name__ == "__main__":
    # CLI: <script> API_KEY LOAD_DATASETS
    #   LOAD_DATASETS equal to "none" (any case) skips loading; anything else loads.
    if len(sys.argv) < 3:
        sys.exit(f"Usage: {sys.argv[0]} API_KEY LOAD_DATASETS")
    API_KEY = sys.argv[1]
    LOAD_DATASETS = sys.argv[2]

    if LOAD_DATASETS.lower() == "none":
        print("No datasets being loaded")
    else:
        # Poll the local Argilla server until it answers 200, then load once.
        while True:
            try:
                # Without a timeout a stalled connection would block this loop forever.
                response = requests.get("http://0.0.0.0:6900/", timeout=10)
                if response.status_code == 200:
                    ld = LoadDatasets(API_KEY)
                    ld.load_somos()
                    break
            except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
                pass  # server not reachable yet; retry after the sleep below
            except Exception as e:
                print(e)
                time.sleep(10)
            time.sleep(5)

    # Keep the process alive; presumably the listener started by load_somos
    # runs in the background and would stop if the process exited.
    while True:
        time.sleep(60)