File size: 22,752 Bytes
c3ece9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a8b18d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
import json
import os
from typing import Callable, List, Union, Dict

# Placeholder (empty) Kaggle credentials, set BEFORE the kaggle import below.
# NOTE(review): presumably the kaggle package tries to authenticate when it is
# imported and fails without credentials — confirm. Real per-account
# credentials are supplied later by KaggleApiWrapper.read_config_environment.
os.environ['KAGGLE_USERNAME']=''
os.environ['KAGGLE_KEY']=''

from kaggle.api.kaggle_api_extended import KaggleApi
from kaggle.rest import ApiException
import shutil
import time
import threading
import copy
from logger import sheet_logger


def get_api():
    """Return a KaggleApi client that has already been authenticated."""
    kaggle_client = KaggleApi()
    kaggle_client.authenticate()
    return kaggle_client


class KaggleApiWrapper(KaggleApi):
    """
    KaggleApi client bound to an explicit username/secret pair.

    Overrides KaggleApi.read_config_environment so the credentials come from
    the constructor arguments instead of environment variables; this lets the
    process hold clients for several accounts at once.
    """

    def __init__(self, username, secret):
        super().__init__()
        self.username = username
        self.secret = secret

    def read_config_environment(self, config_data=None, quiet=False):
        """
        Return the parent's environment-derived config with this instance's
        credentials forced in.

        :param config_data: partially loaded config dict, or None
        :param quiet: forwarded unchanged to the parent implementation
        :return: the dict produced by the parent, with 'username' and 'key'
                 overwritten by this instance's credentials
        """
        config = super().read_config_environment(config_data, quiet)
        config['username'] = self.username
        config['key'] = self.secret
        # Only works on PythonAnywhere:
        # config['proxy'] = "http://proxy.server:3128"

        # Return the dict that was actually updated: config_data may be None
        # (its default), in which case returning it would drop the credentials.
        return config

    def __del__(self):
        # TODO: fix bug when deleting the api; intentionally a no-op for now
        # (AccountTransactionManager's release callback calls it explicitly).
        pass

#     def get_accelerator_quota_with_http_info(self):  # noqa: E501
#         """
#
#         This method makes a synchronous HTTP request by default. To make an
#         asynchronous HTTP request, please pass async_req=True
#         >>> thread = api.competitions_list_with_http_info(async_req=True)
#         >>> result = thread.get()
#
#         :param async_req bool
#         :param str group: Filter competitions by a particular group
#         :param str category: Filter competitions by a particular category
#         :param str sort_by: Sort the results
#         :param int page: Page number
#         :param str search: Search terms
#         :return: Result
#                  If the method is called asynchronously,
#                  returns the request thread.
#         """
#
#         all_params = []  # noqa: E501
#         all_params.append('async_req')
#         all_params.append('_return_http_data_only')
#         all_params.append('_preload_content')
#         all_params.append('_request_timeout')
#
#         params = locals()
#
#         collection_formats = {}
#
#         path_params = {}
#
#         query_params = []
#         # if 'group' in params:
#         #     query_params.append(('group', params['group']))  # noqa: E501
#         # if 'category' in params:
#         #     query_params.append(('category', params['category']))  # noqa: E501
#         # if 'sort_by' in params:
#         #     query_params.append(('sortBy', params['sort_by']))  # noqa: E501
#         # if 'page' in params:
#         #     query_params.append(('page', params['page']))  # noqa: E501
#         # if 'search' in params:
#         #     query_params.append(('search', params['search']))  # noqa: E501
#
#         header_params = {}
#
#         form_params = []
#         local_var_files = {}
#
#         body_params = None
#         # HTTP header `Accept`
#         header_params['Accept'] = self.api_client.select_header_accept(
#             ['application/json'])  # noqa: E501
#
#         # Authentication setting
#         auth_settings = ['basicAuth']  # noqa: E501
#
#         return self.api_client.call_api(
#             'i/kernels.KernelsService/GetAcceleratorQuotaStatistics', 'GET',
#             # '/competitions/list', 'GET',
#             path_params,
#             query_params,
#             header_params,
#             body=body_params,
#             post_params=form_params,
#             files=local_var_files,
#             response_type='Result',  # noqa: E501
#             auth_settings=auth_settings,
#             async_req=params.get('async_req'),
#             _return_http_data_only=params.get('_return_http_data_only'),
#             _preload_content=params.get('_preload_content', True),
#             _request_timeout=params.get('_request_timeout'),
#             collection_formats=collection_formats)
#
# if __name__ == "__main__":
#     api = KaggleApiWrapper('<username>', "<api-key>")  # never commit real credentials; rotate any key that was here
#     api.authenticate()
#     print(api.get_accelerator_quota_with_http_info())

class ValidateException(Exception):
    """Raised when an account cannot access a notebook or one of its datasets."""

    def __init__(self, message: str):
        super().__init__(message)

    @staticmethod
    def from_api_exception(e: "ApiException", kernel_slug: str):
        """Wrap a single ApiException raised while accessing *kernel_slug*."""
        return ValidateException(f"Error: {e.status} {e.reason} with notebook {kernel_slug}")

    @staticmethod
    def from_api_exception_list(el: List["ApiException"], kernel_slug_list: List[str]):
        """
        Combine several ApiExceptions into one ValidateException.

        Exceptions are paired positionally with *kernel_slug_list* (zip stops at
        the shorter list). Each pair is rendered on its own tab-indented line
        under an "Error: " header line.
        """
        lines = [f"\t{e.status} {e.reason} with notebook {k}" for e, k in zip(el, kernel_slug_list)]
        # one failure per line; previously entries were concatenated with no separator
        return ValidateException("Error: \n" + "\n".join(lines))


class KaggleNotebook:
    """One Kaggle notebook (kernel), accessed through a single API client."""

    def __init__(self, api: "KaggleApi", kernel_slug: str, container_path: str = "./tmp", id=None):
        """

        :param api: KaggleApi used for every remote call

        :param kernel_slug: Notebook id, you can find it in the url of the notebook.

                            For example, `username/notebook-name-123456`

        :param container_path: Path to the local folder where the notebook will be downloaded

        :param id: optional job id; when set, status() reports every status it reads
                   through sheet_logger.update_job_status

        """
        self.api = api
        self.kernel_slug = kernel_slug
        self.container_path = container_path
        self.id = id
        if self.id is None:
            print(f"Warn: {self.__class__.__name__}.id is None")

    def status(self) -> Union[str, None]:
        """
        Fetch the notebook's current status from Kaggle.

        :return: the "status" field of the kernels_status response, or None if the
            response itself is None. Values handled in this module include:

            "running"

            "cancelAcknowledged"

            "queued": waiting for run

            "error": when raise exception in notebook

            "complete"

        :raises ApiException: when the request fails (e.g. notebook not shared with this account)

        """
        res = self.api.kernels_status(self.kernel_slug)
        print(f"Status: {res}")
        if res is None:
            if self.id is not None:
                sheet_logger.update_job_status(self.id, notebook_status='None')
            return None
        if self.id is not None:
            sheet_logger.update_job_status(self.id, notebook_status=res['status'])
        return res['status']

    def _get_local_nb_path(self) -> str:
        # kernel_slug ("user/name") becomes a nested folder under container_path
        return os.path.join(self.container_path, self.kernel_slug)

    def pull(self, path=None) -> Union[str, None]:
        """
        Download the notebook together with its kernel-metadata.json.

        The default download folder is wiped first so stale files are not mixed in.

        :param path: target folder; defaults to container_path/kernel_slug

        :return: the result of api.kernels_pull, or None when no kernel-metadata.json
                 appeared (the default folder is wiped again in that case)

        :raises: ApiException if notebook not found or not share to user

        NOTE(review): _clean() always targets the default folder, even when *path* is given.

        """
        self._clean()
        path = path or self._get_local_nb_path()
        metadata_path = os.path.join(path, "kernel-metadata.json")
        res = self.api.kernels_pull(self.kernel_slug, path=path, metadata=True, quiet=False)
        if not os.path.exists(metadata_path):
            print(f"Warn: Not found {metadata_path}. Clean {path}")
            self._clean()
            return None
        return res

    def push(self, path=None) -> Union[str, None]:
        """
        Push the local copy back to Kaggle (this is how the module triggers a rerun).

        Skipped when the notebook is already "queued" or "running".

        :param path: folder to push; defaults to container_path/kernel_slug
        :return: the status read right after pushing, or None if the push was skipped
        """
        status = self.status()
        if status in ['queued', 'running']:
            print("Warn: Notebook is " + status + ". Skip push notebook!")
            return None

        self.api.kernels_push(path or self._get_local_nb_path())
        # short pause before re-reading the status (presumably so Kaggle registers the push)
        time.sleep(1)
        return self.status()

    def _clean(self) -> None:
        """Delete the default local download folder if it exists."""
        if os.path.exists(self._get_local_nb_path()):
            shutil.rmtree(self._get_local_nb_path())

    def get_metadata(self, path=None):
        """
        Read kernel-metadata.json from *path* (default container_path/kernel_slug).

        :return: the parsed JSON, or None if the file does not exist (e.g. pull() has not run)
        """
        path = path or self._get_local_nb_path()
        metadata_path = os.path.join(path, "kernel-metadata.json")
        if not os.path.exists(metadata_path):
            return None
        # context manager guarantees the file handle is closed
        with open(metadata_path, encoding="utf-8") as f:
            return json.load(f)

    def check_nb_permission(self) -> tuple[bool, Union[str, None]]:
        """
        Check that this account can read the notebook's status.

        :return: (True, status) when a status was returned, otherwise (False, None)
        :raises ApiException: propagated from status(), e.g. 403 when the notebook
            does not exist or is not shared with this account
        """
        status = self.status()  # raise ApiException
        if status is None:
            return False, status
        return True, status

    def check_datasets_permission(self) -> bool:
        """
        Check that this account can access every dataset the notebook uses.

        Calls api.dataset_status for each entry of "dataset_sources" in the local
        kernel-metadata.json, so pull() must have run first.

        :return: True when every dataset is reachable
        :raises ValidateException: if the metadata file is missing, or (listing all
            failures) if any dataset_status call raised ApiException
        """
        meta = self.get_metadata()
        if meta is None:
            print("Warn: cannot get metadata. Pull and try again?")
            # previously fell through and crashed with TypeError on meta['dataset_sources']
            raise ValidateException(f"Error: cannot read kernel-metadata.json of notebook {self.kernel_slug}")
        ex_list = []
        slugs = []
        # tolerate metadata without a dataset_sources entry (no datasets to check)
        for dataset in meta.get('dataset_sources') or []:
            try:
                self.api.dataset_status(dataset)
            except ApiException as e:
                print(f"Error: {e.status} {e.reason} with dataset {dataset} in notebook {self.kernel_slug}")
                ex_list.append(e)
                slugs.append(self.kernel_slug)
        if ex_list:
            raise ValidateException.from_api_exception_list(ex_list, slugs)
        return True


class AccountTransactionManager:
    """
    Registry of Kaggle accounts that hands out at most one API client per
    account at a time.

    Every read/write of the registry happens under ``state_lock`` (always via
    ``with``, so an exception can never leave the lock held).
    """

    def __init__(self, acc_secret_dict: dict = None):
        """

        :param acc_secret_dict: {username: secret}. The dict is copied, so later
                                add/remove calls do not mutate the caller's object.

        """
        self.acc_secret_dict = dict(acc_secret_dict) if acc_secret_dict is not None else {}
        # per-account "in use" flag: True while a client handed out by
        # get_unlocked_api_unblocking has not been released yet
        self.lock_dict = {username: False for username in self.acc_secret_dict}
        self.state_lock = threading.Lock()

    def _get_api(self, username: str) -> "KaggleApiWrapper":
        # a fresh client per checkout (a cached-client variant was disabled)
        return KaggleApiWrapper(username, self.acc_secret_dict[username])

    def _get_lock(self, username: str) -> bool:
        # raises KeyError for usernames never registered with add_account
        return self.lock_dict[username]

    def _set_lock(self, username: str, lock: bool) -> None:
        self.lock_dict[username] = lock

    def add_account(self, username, secret):
        """Register *username*; does nothing (secret not updated) if already registered."""
        with self.state_lock:
            if username not in self.acc_secret_dict:
                self.acc_secret_dict[username] = secret
                self.lock_dict[username] = False

    def remove_account(self, username):
        """Unregister *username*; prints a warning if it is not registered."""
        with self.state_lock:
            if username in self.acc_secret_dict:
                del self.acc_secret_dict[username]
                del self.lock_dict[username]
            else:
                print(f"Warn: try to remove account not in the list: {username}, list: {self.acc_secret_dict.keys()}")

    def _make_release(self, username: str, api) -> Callable[[], None]:
        """Build the callback that marks *username* as free again."""
        def release():
            with self.state_lock:
                # the account may have been removed while checked out; don't resurrect it
                if username in self.lock_dict:
                    self._set_lock(username, False)
                api.__del__()
        return release

    def get_unlocked_api_unblocking(self, username_list: List[str]) -> "tuple[KaggleApiWrapper, Callable[[], None]]":
        """
        Wait until one account of *username_list* is free, mark it in use and
        return a client for it.

        Despite the name this call blocks: it re-scans the list once per second
        until one of the accounts is free.

        :param username_list: candidate usernames, tried in order

        :return: (api, release) where release is a function to release api

        :raises ValueError: if username_list is empty (it would otherwise wait forever)
        :raises KeyError: if a username was never registered with add_account

        """
        if not username_list:
            raise ValueError("username_list must not be empty")
        while True:
            print("get_unlocked_api_unblocking" + str(username_list))
            for username in username_list:
                with self.state_lock:
                    if self._get_lock(username):
                        continue
                    self._set_lock(username, True)
                    try:
                        api = self._get_api(username)
                    except Exception:
                        # don't leave the account marked busy if building the client fails
                        self._set_lock(username, False)
                        raise
                return api, self._make_release(username, api)
            time.sleep(1)


class NbJob:
    """A rerun job: one notebook plus the accounts allowed to (re)run it."""

    def __init__(self, acc_dict: dict, nb_slug: str, rerun_stt: List[str] = None, not_rerun_stt: List[str] = None, id=None):
        """

        :param acc_dict: {username: secret} of the accounts allowed to run the notebook

        :param nb_slug: notebook slug, e.g. `username/notebook-name-123456`

        :param rerun_stt: notebook statuses that trigger a rerun (default: ['complete'])

        :param not_rerun_stt: If notebook status in this list, do not rerun it.
                              Must contain "queued" and "running"
                              (default: ['queued', 'running', 'cancelAcknowledged'])

        :param id: optional job id used when reporting through sheet_logger

        """
        self.rerun_stt = rerun_stt
        if self.rerun_stt is None:
            self.rerun_stt = ['complete']
        self.not_rerun_stt = not_rerun_stt

        if self.not_rerun_stt is None:
            self.not_rerun_stt = ['queued', 'running', 'cancelAcknowledged']
        assert "queued" in self.not_rerun_stt
        assert "running" in self.not_rerun_stt

        self.acc_dict = acc_dict
        self.nb_slug = nb_slug
        self.id = id

    def get_acc_dict(self):
        return self.acc_dict

    def get_username_list(self):
        return list(self.acc_dict.keys())

    def is_valid_with_acc(self, api):
        """
        Check that *api*'s account can pull the notebook, read its status and
        access its datasets.

        :param api: authenticated KaggleApi for the account to check

        :return: True (failures are raised, never returned)

        :raise: ValidateException

        """
        notebook = KaggleNotebook(api, self.nb_slug, id=self.id)

        try:
            notebook.pull()  # raise ApiException
            # called for its ApiException on missing access; a (False, None)
            # result is deliberately not treated as a failure here
            notebook.check_nb_permission()
            notebook.check_datasets_permission()  # raise ValidateException
        except ApiException as e:
            raise ValidateException.from_api_exception(e, self.nb_slug)
        return True

    def is_valid(self):
        """Validate the notebook with every account (raises ValidateException on failure)."""
        for username in self.acc_dict.keys():
            secrets = self.acc_dict[username]
            api = KaggleApiWrapper(username=username, secret=secrets)
            api.authenticate()
            if not self.is_valid_with_acc(api):
                return False
        return True

    def acc_check_and_rerun_if_need(self, api: "KaggleApi") -> bool:
        """
        Rerun the notebook with *api*'s account if its status is in rerun_stt.

        :return:

            True if rerun success, notebook is running, or no rerun was needed

            False user does not have enough gpu quotas (status unchanged after push)

        :raises

            Exception if setup error (missing permissions or unexpected status)

        """
        notebook = KaggleNotebook(api, self.nb_slug, "./tmp", id=self.id)  # TODO: make container_path configurable

        notebook.pull()
        # explicit raises instead of assert, so the checks survive `python -O`
        if not notebook.check_datasets_permission():
            raise ValidateException(f"User {api.username} does not have permission on datasets of notebook {self.nb_slug}")
        success, status1 = notebook.check_nb_permission()
        if not success:
            raise ValidateException(f"User {api.username} does not have permission on notebook {self.nb_slug}")

        if status1 not in self.rerun_stt:
            sheet_logger.log(username=api.username, nb=self.nb_slug, log=f"Notebook status is '{status1}' is not in {self.rerun_stt}, do nothing!")
            return True

        status2 = notebook.push()
        time.sleep(10)
        status3 = notebook.status()

        # if 3 times same stt -> acc out of quota
        if status1 == status2 == status3:
            sheet_logger.log(username=api.username, nb=self.nb_slug, log="Try but no effect. Seem account to be out of quota")
            return False
        if status3 in self.not_rerun_stt:
            sheet_logger.log(username=api.username, nb=self.nb_slug,
                             log=f"Schedule notebook successfully. Current status is '{status3}'")
            return True
        # not_rerun_stt always contains "queued" and "running" (checked in
        # __init__), so any status reaching this point is unexpected.
        # TODO: check when user is out of quota
        print(f"Error: status is {status3}")
        raise Exception("Setup exception")

    @staticmethod
    def from_dict(obj: dict, id=None):
        """Build a job from a config dict with keys 'accounts', 'slug' and optional 'rerun_status' / 'not_rerun_stt'."""
        return NbJob(acc_dict=obj['accounts'], nb_slug=obj['slug'], rerun_stt=obj.get('rerun_status'), not_rerun_stt=obj.get('not_rerun_stt'), id=id)


class KernelRerunService:
    """Runs NbJobs, sharing their accounts through one AccountTransactionManager."""

    def __init__(self):
        self.jobs: Dict[str, "NbJob"] = {}
        self.acc_manager = AccountTransactionManager()
        # username -> slugs of the jobs that use that account
        self.username2jobid = {}
        # job slug -> usernames of that job
        self.jobid2username = {}

    def add_job(self, nb_job: "NbJob"):
        """Register *nb_job* and any of its accounts not yet known to the manager."""
        if nb_job.nb_slug in self.jobs:
            print("Warn: nb_job already in job list")
            return
        self.jobs[nb_job.nb_slug] = nb_job
        self.jobid2username[nb_job.nb_slug] = nb_job.get_username_list()
        for username in nb_job.get_username_list():
            if username not in self.username2jobid:
                self.username2jobid[username] = []
                self.acc_manager.add_account(username, nb_job.acc_dict[username])
            self.username2jobid[username].append(nb_job.nb_slug)

    def remove_job(self, nb_job):
        """Unregister *nb_job*; accounts no longer used by any job are removed from the manager."""
        if nb_job.nb_slug not in self.jobs:
            print("Warn: try to remove nb_job not in list")
            return
        for username in self.jobid2username.pop(nb_job.nb_slug):
            slugs = self.username2jobid.get(username, [])
            # drop this job from the account's list even when other jobs share the account
            if nb_job.nb_slug in slugs:
                slugs.remove(nb_job.nb_slug)
            if not slugs:
                self.username2jobid.pop(username, None)
                self.acc_manager.remove_account(username)
        del self.jobs[nb_job.nb_slug]

    def validate_all(self):
        """
        Validate every job with each of its accounts and record the outcome.

        Failures are collected per job across all of its accounts, then written
        once via sheet_logger.update_job_status: the joined messages, or
        "success" when every account passed. Jobs without an id have no status
        to write, so their first failure raises instead.

        :return: True
        :raises Exception: on the first validation failure of a job whose id is None
        """
        for job in self.jobs.values():
            ex_msg_list = []
            for username in job.get_username_list():
                api, release = self.acc_manager.get_unlocked_api_unblocking([username])
                try:
                    api.authenticate()
                    print(f"Using username: {api.username}")
                    print(f"Validate user: {username}, job: {job.nb_slug}")
                    job.is_valid_with_acc(api)
                except ValidateException as e:
                    print("Error: not valid")
                    if job.id is None:  # if not have id, raise
                        raise Exception(f"Setup error: {username} does not have permission on notebook {job.nb_slug} or related datasets") from e
                    ex_msg_list.append(f"Account {username}\n" + str(e) + "\n")
                finally:
                    # always hand the account back, even when validation raised
                    release()
            if job.id is None:
                continue
            if ex_msg_list:
                sheet_logger.update_job_status(job.id, validate_status="\n".join(ex_msg_list))
            else:
                sheet_logger.update_job_status(job.id, validate_status="success")
        return True

    def status_all(self):
        """Print the current status of every job's notebook using one of its free accounts."""
        for job in self.jobs.values():
            print(f"Job: {job.nb_slug}")
            api, release = self.acc_manager.get_unlocked_api_unblocking(job.get_username_list())
            try:
                api.authenticate()
                print(f"Using username: {api.username}")
                notebook = KaggleNotebook(api, job.nb_slug, id=job.id)
                print(f"Notebook: {notebook.kernel_slug}")
                print(notebook.status())
            finally:
                release()

    def run(self, nb_job: "NbJob"):
        """
        Try nb_job's accounts one at a time until one reruns (or finds running) the notebook.

        :return: True as soon as an account reports success; False when every
                 account returned False, or as soon as one raised (the error is
                 printed and no further accounts are tried)
        """
        username_list = list(nb_job.get_username_list())
        while username_list:
            api, release = self.acc_manager.get_unlocked_api_unblocking(username_list)
            try:
                api.authenticate()
                print(f"Using username: {api.username}")
                try:
                    if nb_job.acc_check_and_rerun_if_need(api):
                        return True
                except Exception as e:
                    # setup errors are not retried with other accounts
                    print(e)
                    return False
                if api.username not in username_list:
                    raise Exception(f"Unexpected account {api.username}, expected one of {username_list}")
                username_list.remove(api.username)
            finally:
                # released on every path, including the success return above
                release()
        return False

    def run_all(self):
        """Run every registered job and print whether it succeeded."""
        for job in self.jobs.values():
            success = self.run(job)
            print(f"Job: {job.nb_slug} {success}")


# if __name__ == "__main__":
    # service = KernelRerunService()
    # files = os.listdir("./config")
    # for file in files:
    #     if '.example' not in file:
    #         with open(os.path.join("./config", file), "r") as f:
    #             obj = json.loads(f.read())
    #             print(obj)
    #             service.add_job(NbJob.from_dict(obj))
    # service.run_all()

    # try:
    #     acc_secret_dict = {
    #         "hahunavth": "secret",
    #         "hahunavth2": "secret",
    #         "hahunavth3": "secret",
    #         "hahunavth4": "secret",
    #         "hahunavth5": "secret",
    #     }
    #     acc_manager = AccountTransactionManager(acc_secret_dict)
    #
    #
    #     def test1():
    #         username_list = ["hahunavth", "hahunavth2", "hahunavth3", "hahunavth4", "hahunavth5"]
    #         while len(username_list) > 0:
    #             api, release = acc_manager.get_unlocked_api_unblocking(username_list)
    #             print("test1 is using " + api.username)
    #             time.sleep(1)
    #             release()
    #             if api.username in username_list:
    #                 username_list.remove(api.username)
    #             else:
    #                 raise Exception("")
    #             print("test1 release " + api.username)
    #
    #
    #     def test2():
    #         username_list = ["hahunavth2", "hahunavth3", "hahunavth5"]
    #         while len(username_list) > 0:
    #             api, release = acc_manager.get_unlocked_api_unblocking(username_list)
    #             print("test2 is using " + api.username)
    #             time.sleep(3)
    #             release()
    #             if api.username in username_list:
    #                 username_list.remove(api.username)
    #             else:
    #                 raise Exception("")
    #             print("test2 release " + api.username)
    #
    #
    #     t1 = threading.Thread(target=test1)
    #     t2 = threading.Thread(target=test2)
    #     t1.start()
    #     t2.start()
    #     t1.join()
    #     t2.join()
    #
    #     # kgapi = KaggleApiWrapper("hahunavth", "<api-key>")  # never commit real keys; rotate the one that was here
    #     # kgapi.authenticate()
    #     # # kgapi = get_api()
    #     # notebook = KaggleNotebook(kgapi, "hahunavth/ess-vlsp2023-denoising", "./tmp")
    #     # # print(notebook.pull())
    #     # # print(notebook.check_datasets_permission())
    #     # print(notebook.check_nb_permission())
    #     # # print(notebook.status())
    #     # # notebook.push()
    #     # # print(notebook.status())
    # except ApiException as e:
    #     print(e.status)
    #     print(e.reason)
    #     raise e
    #     # 403 when nb not exists or not share to acc
    #     # 404 when push to unknow kenel_slug.username
    #     # 401 when invalid username, pass