# Repository: kaggle-utils / kaggle_service.py
# Author: hahunavth
# Commit message: add cron server
# Commit: c3ece9d
import json
import os
from typing import Callable, List, Union, Dict
# Register placeholder (empty) default credentials before `kaggle` is imported so the
# kaggle package can be loaded without a real default account
# (presumably because importing it authenticates eagerly — confirm for the installed version).
# Real per-account credentials are supplied at runtime by KaggleApiWrapper.
os.environ.update({'KAGGLE_USERNAME': '', 'KAGGLE_KEY': ''})
from kaggle.api.kaggle_api_extended import KaggleApi
from kaggle.rest import ApiException
import shutil
import time
import threading
import copy
from logger import sheet_logger
def get_api():
    """Build a ``KaggleApi`` client and authenticate it right away.

    Credentials are resolved by ``KaggleApi.authenticate`` itself, so any
    credential problem is raised from this call.
    """
    client = KaggleApi()
    client.authenticate()
    return client
class KaggleApiWrapper(KaggleApi):
    """
    KaggleApi bound to an explicit username/key pair.

    Overrides ``KaggleApi.read_config_environment`` so that the username and key come
    from the constructor arguments instead of environment variables.
    """

    def __init__(self, username, secret):
        super().__init__()
        self.username = username
        self.secret = secret

    def read_config_environment(self, config_data=None, quiet=False):
        """
        Return the parent's environment-derived config with this instance's credentials injected.

        :param config_data: optional dict passed through to the parent implementation
        :param quiet: passed through to the parent implementation
        :return: the config dict produced by the parent, with 'username' and 'key' overwritten
        """
        config = super().read_config_environment(config_data, quiet)
        config['username'] = self.username
        config['key'] = self.secret
        # only work for pythonanywhere
        # config['proxy'] = "http://proxy.server:3128"
        # Return the dict the parent built. Returning ``config_data`` instead only works when
        # the parent mutates the argument in place, and yields None when called without one.
        return config

    def __del__(self):
        # TODO: fix bug when deleting the api; intentionally a no-op for now.
        # AccountTransactionManager's release callback calls this explicitly.
        pass
# def get_accelerator_quota_with_http_info(self): # noqa: E501
# """
#
# This method makes a synchronous HTTP request by default. To make an
# asynchronous HTTP request, please pass async_req=True
# >>> thread = api.competitions_list_with_http_info(async_req=True)
# >>> result = thread.get()
#
# :param async_req bool
# :param str group: Filter competitions by a particular group
# :param str category: Filter competitions by a particular category
# :param str sort_by: Sort the results
# :param int page: Page number
# :param str search: Search terms
# :return: Result
# If the method is called asynchronously,
# returns the request thread.
# """
#
# all_params = [] # noqa: E501
# all_params.append('async_req')
# all_params.append('_return_http_data_only')
# all_params.append('_preload_content')
# all_params.append('_request_timeout')
#
# params = locals()
#
# collection_formats = {}
#
# path_params = {}
#
# query_params = []
# # if 'group' in params:
# # query_params.append(('group', params['group'])) # noqa: E501
# # if 'category' in params:
# # query_params.append(('category', params['category'])) # noqa: E501
# # if 'sort_by' in params:
# # query_params.append(('sortBy', params['sort_by'])) # noqa: E501
# # if 'page' in params:
# # query_params.append(('page', params['page'])) # noqa: E501
# # if 'search' in params:
# # query_params.append(('search', params['search'])) # noqa: E501
#
# header_params = {}
#
# form_params = []
# local_var_files = {}
#
# body_params = None
# # HTTP header `Accept`
# header_params['Accept'] = self.api_client.select_header_accept(
# ['application/json']) # noqa: E501
#
# # Authentication setting
# auth_settings = ['basicAuth'] # noqa: E501
#
# return self.api_client.call_api(
# 'i/kernels.KernelsService/GetAcceleratorQuotaStatistics', 'GET',
# # '/competitions/list', 'GET',
# path_params,
# query_params,
# header_params,
# body=body_params,
# post_params=form_params,
# files=local_var_files,
# response_type='Result', # noqa: E501
# auth_settings=auth_settings,
# async_req=params.get('async_req'),
# _return_http_data_only=params.get('_return_http_data_only'),
# _preload_content=params.get('_preload_content', True),
# _request_timeout=params.get('_request_timeout'),
# collection_formats=collection_formats)
#
# if __name__ == "__main__":
# api = KaggleApiWrapper('<KAGGLE_USERNAME>', "<KAGGLE_KEY>")  # credentials redacted: never commit real API keys
# api.authenticate()
# print(api.get_accelerator_quota_with_http_info())
class ValidateException(Exception):
    """Raised when an account cannot access a notebook or one of its datasets."""

    def __init__(self, message: str):
        super().__init__(message)

    @staticmethod
    def from_api_exception(e: "ApiException", kernel_slug: str) -> "ValidateException":
        """Wrap a single failed API call made for ``kernel_slug``."""
        return ValidateException(f"Error: {e.status} {e.reason} with notebook {kernel_slug}")

    @staticmethod
    def from_api_exception_list(el: "List[ApiException]", kernel_slug_list: "List[str]") -> "ValidateException":
        """
        Wrap several failed API calls into one exception, one tab-indented line per failure.

        :param el: the API exceptions
        :param kernel_slug_list: label for each exception, paired positionally with ``el``
        """
        # One entry per line; previously entries were concatenated with no separator.
        lines = [f"\t{e.status} {e.reason} with notebook {k}" for e, k in zip(el, kernel_slug_list)]
        return ValidateException("Error: \n" + "\n".join(lines))
class KaggleNotebook:
    """Helper for one Kaggle notebook (kernel), accessed through a given API client."""

    def __init__(self, api: KaggleApi, kernel_slug: str, container_path: str = "./tmp", id=None):
        """
        :param api: KaggleApi
        :param kernel_slug: Notebook id, you can find it in the url of the notebook.
            For example, `username/notebook-name-123456`
        :param container_path: Path to the local folder where the notebook will be downloaded
        :param id: optional job id; when set, status updates are written through ``sheet_logger``
        """
        self.api = api
        self.kernel_slug = kernel_slug
        self.container_path = container_path
        self.id = id
        if self.id is None:
            print(f"Warn: {self.__class__.__name__}.id is None")

    def status(self) -> Union[str, None]:
        """
        Fetch the notebook's current run status (and log it when ``id`` is set).

        :return:
            "running"
            "cancelAcknowledged"
            "queued": waiting for run
            "error": when raise exception in notebook
            None when the API returns no status
        :raises: ApiException when the request fails (e.g. notebook not found or not shared)
        """
        res = self.api.kernels_status(self.kernel_slug)
        print(f"Status: {res}")
        if res is None:
            if self.id is not None:
                sheet_logger.update_job_status(self.id, notebook_status='None')
            return None
        status = res['status']
        if self.id is not None:
            sheet_logger.update_job_status(self.id, notebook_status=status)
        return status

    def _get_local_nb_path(self) -> str:
        """Default local folder for this notebook: ``<container_path>/<kernel_slug>``."""
        return os.path.join(self.container_path, self.kernel_slug)

    def pull(self, path=None) -> Union[str, None]:
        """
        Download the notebook together with its ``kernel-metadata.json``.

        NOTE(review): only the default local folder is cleaned, even when ``path`` is given.

        :param path: target folder; defaults to ``_get_local_nb_path()``
        :return: the value returned by ``api.kernels_pull``, or None when no metadata file was written
        :raises: ApiException if notebook not found or not share to user
        """
        self._clean()
        path = path or self._get_local_nb_path()
        metadata_path = os.path.join(path, "kernel-metadata.json")
        res = self.api.kernels_pull(self.kernel_slug, path=path, metadata=True, quiet=False)
        if not os.path.exists(metadata_path):
            print(f"Warn: Not found {metadata_path}. Clean {path}")
            self._clean()
            return None
        return res

    def push(self, path=None) -> Union[str, None]:
        """
        Push the local copy to start a new run, unless the notebook is already queued or running.

        :return: the status after pushing, or None when the push was skipped
        """
        status = self.status()
        if status in ['queued', 'running']:
            print("Warn: Notebook is " + status + ". Skip push notebook!")
            return None
        self.api.kernels_push(path or self._get_local_nb_path())
        time.sleep(1)  # short pause before re-reading the status (presumably to let Kaggle register the run)
        return self.status()

    def _clean(self) -> None:
        """Delete the default local folder of this notebook, if present."""
        if os.path.exists(self._get_local_nb_path()):
            shutil.rmtree(self._get_local_nb_path())

    def get_metadata(self, path=None):
        """Return the parsed ``kernel-metadata.json`` from ``path`` (default local folder), or None if absent."""
        path = path or self._get_local_nb_path()
        metadata_path = os.path.join(path, "kernel-metadata.json")
        if not os.path.exists(metadata_path):
            return None
        # `with` closes the handle; the previous bare open() leaked it.
        with open(metadata_path, encoding="utf-8") as f:
            return json.load(f)

    def check_nb_permission(self) -> tuple[bool, Union[str, None]]:
        """
        :return: (False, None) when no status is returned, otherwise (True, status)
        :raises: ApiException when the account cannot read the notebook status
        """
        status = self.status()  # raise ApiException
        if status is None:
            return False, status
        return True, status

    def check_datasets_permission(self) -> bool:
        """
        Check that the account can access every dataset listed in the notebook metadata.

        Requires a prior ``pull()`` so that ``kernel-metadata.json`` exists locally.

        :return: True when every dataset is accessible
        :raises: ValidateException when the metadata is missing, or listing every inaccessible dataset
        """
        meta = self.get_metadata()
        if meta is None:
            print("Warn: cannot get metadata. Pull and try again?")
            # Previously execution continued and crashed with TypeError on None['dataset_sources'].
            raise ValidateException(f"Error: cannot read kernel-metadata.json of notebook {self.kernel_slug}")
        ex_list = []
        slugs = []
        for dataset in meta.get('dataset_sources', []):
            try:
                self.api.dataset_status(dataset)
            except ApiException as e:
                print(f"Error: {e.status} {e.reason} with dataset {dataset} in notebook {self.kernel_slug}")
                ex_list.append(e)
                # Record which dataset failed, not only the notebook slug.
                slugs.append(f"{self.kernel_slug} (dataset {dataset})")
        if ex_list:
            raise ValidateException.from_api_exception_list(ex_list, slugs)
        return True
class AccountTransactionManager:
    """
    Hands out Kaggle API clients so that each account is used by at most one caller at a time.

    Every read/write of ``acc_secret_dict`` / ``lock_dict`` made by the public methods
    happens while holding ``state_lock``.
    """

    def __init__(self, acc_secret_dict: dict = None):
        """
        :param acc_secret_dict: {username: secret}; copied, so the caller's dict is never mutated
        """
        self.acc_secret_dict = dict(acc_secret_dict) if acc_secret_dict is not None else {}
        # self.api_dict = {username: KaggleApiWrapper(username, secret) for username, secret in acc_secret_dict.items()}
        # per-account "in use" flag to avoid concurrent use of the same account
        self.lock_dict = {username: False for username in self.acc_secret_dict}
        self.state_lock = threading.Lock()

    def _get_api(self, username: str) -> KaggleApiWrapper:
        """Create a fresh client for ``username``; raises KeyError if the account is unknown."""
        # return self.api_dict[username]
        return KaggleApiWrapper(username, self.acc_secret_dict[username])

    def _get_lock(self, username: str) -> bool:
        """Caller must hold ``state_lock``; raises KeyError if the account is unknown."""
        return self.lock_dict[username]

    def _set_lock(self, username: str, lock: bool) -> None:
        """Caller must hold ``state_lock``."""
        self.lock_dict[username] = lock

    def add_account(self, username, secret):
        """Register ``username``; an already-registered account keeps its existing secret."""
        # The membership check is done under the lock so two threads cannot both add the account.
        with self.state_lock:
            if username not in self.acc_secret_dict:
                self.acc_secret_dict[username] = secret
                self.lock_dict[username] = False

    def remove_account(self, username):
        """Unregister ``username``; prints a warning if it is not registered."""
        with self.state_lock:
            if username not in self.acc_secret_dict:
                print(f"Warn: try to remove account not in the list: {username}, list: {self.acc_secret_dict.keys()}")
                return
            del self.acc_secret_dict[username]
            del self.lock_dict[username]

    def _make_release(self, username: str, api: KaggleApiWrapper) -> Callable[[], None]:
        """Build the callback that marks ``username`` free again and disposes of ``api``."""
        def release():
            with self.state_lock:
                # If the account was removed while in use, do not resurrect its lock entry.
                if username in self.lock_dict:
                    self._set_lock(username, False)
                api.__del__()
        return release

    def get_unlocked_api_unblocking(self, username_list: List[str]) -> tuple[KaggleApiWrapper, Callable[[], None]]:
        """
        Wait until one of ``username_list`` is free, mark it in use and return a client for it.

        Despite its name this call blocks: it re-checks the accounts once per second.

        :param username_list: list of username, tried in order
        :return: (api, release) where release is a function to release api
        :raises ValueError: if ``username_list`` is empty (it would otherwise wait forever)
        :raises KeyError: if a username is not registered
        """
        if not username_list:
            raise ValueError("username_list must not be empty")
        print("get_unlocked_api_unblocking" + str(username_list))
        while True:
            # `with` guarantees the state lock is released even if a lookup raises;
            # the previous manual acquire/release deadlocked every later caller in that case.
            with self.state_lock:
                for username in username_list:
                    if self._get_lock(username):
                        continue
                    # Build the client before marking the account busy, so a failure here
                    # does not leave the account locked forever.
                    api = self._get_api(username)
                    self._set_lock(username, True)
                    return api, self._make_release(username, api)
            time.sleep(1)
class NbJob:
    """A notebook to keep re-running, plus the Kaggle accounts that may run it."""

    def __init__(self, acc_dict: dict, nb_slug: str, rerun_stt: List[str] = None, not_rerun_stt: List[str] = None, id=None):
        """
        :param acc_dict: {username: secret} of the accounts that may run the notebook
        :param nb_slug: notebook slug, e.g. `username/notebook-name`
        :param rerun_stt: statuses that trigger a re-run (default: ['complete'])
        :param not_rerun_stt: If notebook status in this list, do not rerun it.
            It must contain "queued" and "running" (default: ['queued', 'running', 'cancelAcknowledged'])
        :param id: optional job id used for sheet logging
        """
        self.rerun_stt = rerun_stt
        if self.rerun_stt is None:
            self.rerun_stt = ['complete']
        self.not_rerun_stt = not_rerun_stt
        if self.not_rerun_stt is None:
            self.not_rerun_stt = ['queued', 'running', 'cancelAcknowledged']
        assert "queued" in self.not_rerun_stt
        assert "running" in self.not_rerun_stt
        self.acc_dict = acc_dict
        self.nb_slug = nb_slug
        self.id = id

    def get_acc_dict(self):
        return self.acc_dict

    def get_username_list(self):
        return list(self.acc_dict.keys())

    def is_valid_with_acc(self, api):
        """
        Check that ``api``'s account can pull the notebook, read its status and access its datasets.

        :param api: authenticated KaggleApi client
        :return: True when every check passes
        :raise: ValidateException describing the failure
        """
        notebook = KaggleNotebook(api, self.nb_slug, id=self.id)
        try:
            notebook.pull()  # raise ApiException
            if notebook.get_metadata() is None:
                raise ValidateException(f"Error: cannot pull metadata of notebook {self.nb_slug}")
            has_permission, _ = notebook.check_nb_permission()  # note: raise ApiException
            # A False result used to be ignored, so the job was reported as valid.
            if not has_permission:
                raise ValidateException(f"Error: cannot get status of notebook {self.nb_slug}")
            notebook.check_datasets_permission()  # raise ValidateException
        except ApiException as e:
            raise ValidateException.from_api_exception(e, self.nb_slug) from e
        return True

    def is_valid(self):
        """Validate the notebook against every account in ``acc_dict``."""
        for username, secret in self.acc_dict.items():
            api = KaggleApiWrapper(username=username, secret=secret)
            api.authenticate()
            if not self.is_valid_with_acc(api):
                return False
        return True

    def acc_check_and_rerun_if_need(self, api: KaggleApi) -> bool:
        """
        Pull the notebook with ``api``'s account and push it again if its status is in ``rerun_stt``.

        :return:
            True if rerun success, notebook is running, or no rerun is needed
            False if the status stayed the same after pushing (account seems out of GPU quota)
        :raises
            AssertionError if the account lacks permission on the notebook or its datasets
            Exception if setup error (unexpected status after push)
        """
        notebook = KaggleNotebook(api, self.nb_slug, "./tmp", id=self.id)  # todo: change hardcode container_path here
        notebook.pull()
        # Explicit checks instead of `assert`: under `python -O` the asserts are stripped,
        # which would also skip the permission call placed inside the first one.
        if not notebook.check_datasets_permission():
            raise AssertionError(f"User {api.username} does not have permission on datasets of notebook {self.nb_slug}")
        success, status1 = notebook.check_nb_permission()
        if not success:
            raise AssertionError(f"User {api.username} does not have permission on notebook {self.nb_slug}")

        if status1 not in self.rerun_stt:
            sheet_logger.log(username=api.username, nb=self.nb_slug, log=f"Notebook status is '{status1}' is not in {self.rerun_stt}, do nothing!")
            return True

        status2 = notebook.push()
        time.sleep(10)
        status3 = notebook.status()
        # Same status before the push, right after it and 10 s later -> account seems out of quota.
        if status1 == status2 == status3:
            sheet_logger.log(username=api.username, nb=self.nb_slug, log="Try but no effect. Seem account to be out of quota")
            return False
        if status3 in self.not_rerun_stt:
            sheet_logger.log(username=api.username, nb=self.nb_slug,
                             log=f"Schedule notebook successfully. Current status is '{status3}'")
            return True
        if status3 in ["queued", "running"]:
            return True
        # todo: check when user is out of quota
        print(f"Error: status is {status3}")
        raise Exception("Setup exception")

    @staticmethod
    def from_dict(obj: dict, id=None):
        """Build a job from a config dict with keys 'accounts', 'slug' and optional status lists."""
        # NOTE(review): reads 'rerun_status' but 'not_rerun_stt' -- the key naming is
        # inconsistent; confirm against the config files.
        return NbJob(acc_dict=obj['accounts'], nb_slug=obj['slug'], rerun_stt=obj.get('rerun_status'), not_rerun_stt=obj.get('not_rerun_stt'), id=id)
class KernelRerunService:
    """Keeps a set of NbJob's and validates / re-runs them using a shared pool of accounts."""

    def __init__(self):
        self.jobs: Dict[str, NbJob] = {}
        self.acc_manager = AccountTransactionManager()
        self.username2jobid = {}  # username -> list of job slugs using that account
        self.jobid2username = {}  # job slug -> list of usernames of that job

    def add_job(self, nb_job: NbJob):
        """Register ``nb_job`` and its accounts; a job whose slug is already known is ignored."""
        if nb_job.nb_slug in self.jobs:
            print("Warn: nb_job already in job list")
            return
        self.jobs[nb_job.nb_slug] = nb_job
        self.jobid2username[nb_job.nb_slug] = nb_job.get_username_list()
        for username in nb_job.get_username_list():
            if username not in self.username2jobid:
                self.username2jobid[username] = []
                self.acc_manager.add_account(username, nb_job.acc_dict[username])
            self.username2jobid[username].append(nb_job.nb_slug)

    def remove_job(self, nb_job):
        """Unregister ``nb_job``; accounts no longer used by any job are removed as well."""
        if nb_job.nb_slug not in self.jobs:
            print("Warn: try to remove nb_job not in list")
            return
        for username in self.jobid2username[nb_job.nb_slug]:
            job_slugs = self.username2jobid.get(username)
            if job_slugs is None:
                continue
            # Drop this job from every account, including shared ones; previously the slug
            # stayed behind for shared accounts, so they could never be released later.
            if nb_job.nb_slug in job_slugs:
                job_slugs.remove(nb_job.nb_slug)
            if not job_slugs:
                del self.username2jobid[username]
                self.acc_manager.remove_account(username)
        del self.jobs[nb_job.nb_slug]
        del self.jobid2username[nb_job.nb_slug]

    def validate_all(self):
        """
        Validate every job against each of its accounts and log one validate_status per job.

        :return: True
        :raises: Exception if a job without id fails validation
        """
        # Collect failures across all accounts first, so one account's success cannot
        # overwrite another account's failure for the same job.
        errors = {slug: [] for slug in self.jobs}
        for username in list(self.acc_manager.acc_secret_dict.keys()):
            api, release = self.acc_manager.get_unlocked_api_unblocking([username])
            try:
                api.authenticate()
                print(f"Using username: {api.username}")
                for job in self.jobs.values():
                    if username not in job.get_username_list():
                        continue
                    print(f"Validate user: {username}, job: {job.nb_slug}")
                    try:
                        job.is_valid_with_acc(api)
                    except ValidateException as e:
                        print("Error: not valid")
                        if job.id is None:  # if not have id, raise
                            raise Exception(f"Setup error: {username} does not have permission on notebook {job.nb_slug} or related datasets") from e
                        # if have id, write log
                        errors[job.nb_slug].append(f"Account {username}\n" + str(e) + "\n")
            finally:
                # Always free the account, even when validation raises.
                release()
        for job in self.jobs.values():
            if job.id is None:
                continue
            msgs = errors.get(job.nb_slug, [])
            if msgs:
                sheet_logger.update_job_status(job.id, validate_status="\n".join(msgs))
            else:
                sheet_logger.update_job_status(job.id, validate_status="success")
        return True

    def status_all(self):
        """Print the current status of every job's notebook."""
        for job in self.jobs.values():
            print(f"Job: {job.nb_slug}")
            api, release = self.acc_manager.get_unlocked_api_unblocking(job.get_username_list())
            try:
                api.authenticate()
                print(f"Using username: {api.username}")
                notebook = KaggleNotebook(api, job.nb_slug, id=job.id)
                print(f"Notebook: {notebook.kernel_slug}")
                print(notebook.status())
            finally:
                release()

    def run(self, nb_job: NbJob):
        """
        Try the job's accounts one by one until one of them (re)schedules the notebook.

        :return: True on success, False if an error occurred or every account was out of quota
        """
        username_list = list(nb_job.get_username_list())
        while username_list:
            api, release = self.acc_manager.get_unlocked_api_unblocking(username_list)
            try:
                api.authenticate()
                print(f"Using username: {api.username}")
                try:
                    result = nb_job.acc_check_and_rerun_if_need(api)
                except Exception as e:
                    print(e)
                    return False
            finally:
                # Released on every path; previously a successful run kept the account
                # locked forever, so later runs using it blocked indefinitely.
                release()
            if result:
                return True
            if api.username not in username_list:
                raise Exception(f"Unexpected account {api.username}, expected one of {username_list}")
            username_list.remove(api.username)
        return False

    def run_all(self):
        """Run every registered job and print the outcome."""
        for job in self.jobs.values():
            success = self.run(job)
            print(f"Job: {job.nb_slug} {success}")
# if __name__ == "__main__":
# service = KernelRerunService()
# files = os.listdir("./config")
# for file in files:
# if '.example' not in file:
# with open(os.path.join("./config", file), "r") as f:
# obj = json.loads(f.read())
# print(obj)
# service.add_job(NbJob.from_dict(obj))
# service.run_all()
# try:
# acc_secret_dict = {
# "hahunavth": "secret",
# "hahunavth2": "secret",
# "hahunavth3": "secret",
# "hahunavth4": "secret",
# "hahunavth5": "secret",
# }
# acc_manager = AccountTransactionManager(acc_secret_dict)
#
#
# def test1():
# username_list = ["hahunavth", "hahunavth2", "hahunavth3", "hahunavth4", "hahunavth5"]
# while len(username_list) > 0:
# api, release = acc_manager.get_unlocked_api_unblocking(username_list)
# print("test1 is using " + api.username)
# time.sleep(1)
# release()
# if api.username in username_list:
# username_list.remove(api.username)
# else:
# raise Exception("")
# print("test1 release " + api.username)
#
#
# def test2():
# username_list = ["hahunavth2", "hahunavth3", "hahunavth5"]
# while len(username_list) > 0:
# api, release = acc_manager.get_unlocked_api_unblocking(username_list)
# print("test2 is using " + api.username)
# time.sleep(3)
# release()
# if api.username in username_list:
# username_list.remove(api.username)
# else:
# raise Exception("")
# print("test2 release " + api.username)
#
#
# t1 = threading.Thread(target=test1)
# t2 = threading.Thread(target=test2)
# t1.start()
# t2.start()
# t1.join()
# t2.join()
#
# # kgapi = KaggleApiWrapper("hahunavth", "<KAGGLE_KEY>")  # key redacted: never commit real API keys
# # kgapi.authenticate()
# # # kgapi = get_api()
# # notebook = KaggleNotebook(kgapi, "hahunavth/ess-vlsp2023-denoising", "./tmp")
# # # print(notebook.pull())
# # # print(notebook.check_datasets_permission())
# # print(notebook.check_nb_permission())
# # # print(notebook.status())
# # # notebook.push()
# # # print(notebook.status())
# except ApiException as e:
# print(e.status)
# print(e.reason)
# raise e
# # 403 when the notebook does not exist or is not shared with the account
# # 404 when pushing to an unknown kernel_slug.username
# # 401 when the username or key is invalid