- spaces/spaces/__init__.py +0 -23
- spaces/spaces/config.py +0 -29
- spaces/spaces/gradio.py +0 -55
- spaces/spaces/utils.py +0 -73
- spaces/spaces/zero/__init__.py +0 -12
- spaces/spaces/zero/api.py +0 -154
- spaces/spaces/zero/bitsandbytes.py +0 -135
- spaces/spaces/zero/client.py +0 -175
- spaces/spaces/zero/decorator.py +0 -117
- spaces/spaces/zero/gradio.py +0 -108
- spaces/spaces/zero/torch.py +0 -279
- spaces/spaces/zero/tqdm.py +0 -14
- spaces/spaces/zero/types.py +0 -44
- spaces/spaces/zero/utils.py +0 -44
- spaces/spaces/zero/wrappers.py +0 -347
spaces/spaces/__init__.py
DELETED
@@ -1,23 +0,0 @@
"""
"""

import sys

if sys.version_info.minor < 8: # pragma: no cover
    raise RuntimeError("Importing PySpaces requires Python 3.8+")


from .zero.decorator import GPU
from .zero.torch import disable_cuda_intercept
from .gradio import gradio_auto_wrap
from .gradio import disable_gradio_auto_wrap
from .gradio import enable_gradio_auto_wrap


__all__ = [
    'GPU',
    'disable_cuda_intercept',
    'gradio_auto_wrap',
    'disable_gradio_auto_wrap',
    'enable_gradio_auto_wrap',
]
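For context, a minimal sketch of how a Space typically consumes this public API (not part of the diff; the Gradio app, model and duration below are illustrative). The `@spaces.GPU` usage matches the docstring in `zero/decorator.py` further down, and the module-level `.to('cuda')` relies on the interception set up in `zero/torch.py`:

```
import gradio as gr
import spaces
import torch

# Under the ZeroGPU torch patch, this .to('cuda') only records a pending
# move; the real CUDA transfer happens inside the GPU-decorated call.
model = torch.nn.Linear(1, 1).to('cuda')

@spaces.GPU(duration=30)  # duration in seconds (illustrative value)
def predict(x: float) -> float:
    t = torch.tensor([[x]], device='cuda')
    return float(model(t).item())

gr.Interface(predict, gr.Number(), gr.Number()).launch()
```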
spaces/spaces/config.py
DELETED
@@ -1,29 +0,0 @@
"""
"""
from __future__ import annotations

import os

from .utils import boolean


class Settings:
    def __init__(self):
        self.zero_gpu = boolean(
            os.getenv('SPACES_ZERO_GPU'))
        self.zero_device_api_url = (
            os.getenv('SPACES_ZERO_DEVICE_API_URL'))
        self.gradio_auto_wrap = boolean(
            os.getenv('SPACES_GRADIO_AUTO_WRAP'))
        self.zero_patch_torch_device = boolean(
            os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))


Config = Settings()


if Config.zero_gpu:
    assert Config.zero_device_api_url is not None, (
        'SPACES_ZERO_DEVICE_API_URL env must be set '
        'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
    )
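As a reference for the `Settings` fields above, a hedged sketch of the environment a ZeroGPU Space would provide; the URL is a placeholder, and "1"/"t"/"true" are the truthy spellings accepted by `boolean` (defined in `spaces/utils.py` below):

```
import os

os.environ['SPACES_ZERO_GPU'] = 'true'
os.environ['SPACES_ZERO_DEVICE_API_URL'] = 'http://device-api.local'  # placeholder
os.environ['SPACES_GRADIO_AUTO_WRAP'] = '1'
os.environ['ZERO_GPU_PATCH_TORCH_DEVICE'] = '1'

# Config is evaluated at import time, so the variables must be set first.
from spaces.config import Config
assert Config.zero_gpu and Config.gradio_auto_wrap
```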
spaces/spaces/gradio.py
DELETED
@@ -1,55 +0,0 @@
"""
"""
from __future__ import annotations

from typing import Callable
from typing import Generator
from typing import TypeVar
from typing import overload
from typing_extensions import ParamSpec

from .config import Config
from .zero.decorator import GPU


Param = ParamSpec('Param')
Res = TypeVar('Res')


gradio_auto_wrap_enabled = Config.gradio_auto_wrap


def disable_gradio_auto_wrap():
    global gradio_auto_wrap_enabled
    gradio_auto_wrap_enabled = False

def enable_gradio_auto_wrap():
    global gradio_auto_wrap_enabled
    gradio_auto_wrap_enabled = True


@overload
def gradio_auto_wrap(
    task:
        Callable[Param, Res],
) -> Callable[Param, Res]:
    ...
@overload
def gradio_auto_wrap(
    task:
        None,
) -> None:
    ...
def gradio_auto_wrap(
    task:
        Callable[Param, Res]
        | None,
) -> (Callable[Param, Res]
      | None):
    """
    """
    if not gradio_auto_wrap_enabled:
        return task
    if not callable(task):
        return task
    return GPU(task) # type: ignore
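A short usage sketch for the auto-wrap helpers above (assumed usage, not from the diff): when auto-wrap is enabled, callables are routed through `GPU`, and anything else is returned unchanged:

```
from spaces import enable_gradio_auto_wrap, gradio_auto_wrap

enable_gradio_auto_wrap()

def predict(prompt: str) -> str:
    return prompt.upper()  # placeholder for CUDA work

wrapped = gradio_auto_wrap(predict)   # same as applying spaces.GPU(predict)
assert gradio_auto_wrap(None) is None  # non-callables pass through untouched
```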
spaces/spaces/utils.py
DELETED
@@ -1,73 +0,0 @@
"""
"""
from __future__ import annotations

import sys
from functools import lru_cache as cache
from functools import partial

import multiprocessing
from multiprocessing.queues import SimpleQueue as _SimpleQueue
from pathlib import Path
from pickle import PicklingError
from typing import Callable
from typing import TypeVar


GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"


T = TypeVar('T')


@cache
def self_cgroup_device_path() -> str:
    cgroup_content = Path('/proc/self/cgroup').read_text()
    for line in cgroup_content.strip().split('\n'):
        contents = line.split(':devices:')
        if len(contents) != 2:
            continue # pragma: no cover
        return contents[1]
    raise Exception # pragma: no cover


if sys.version_info.minor < 9: # pragma: no cover
    _SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore

class SimpleQueue(_SimpleQueue[T]):
    def __init__(self, *args):
        super().__init__(*args, ctx=multiprocessing.get_context('fork'))
    def put(self, obj: T):
        try:
            super().put(obj)
        except PicklingError:
            raise # pragma: no cover
        # https://bugs.python.org/issue29187
        except Exception as e:
            message = str(e)
            if not "pickle" in message:
                raise # pragma: no cover
            raise PicklingError(message)
    def close(self): # Python 3.8 static typing trick
        super().close() # type: ignore


def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
    def drop(*args):
        return fn()
    return drop


def boolean(value: str | None) -> bool:
    return value is not None and value.lower() in ("1", "t", "true")


def gradio_request_var():
    try:
        from gradio.context import LocalContext
    except ImportError: # pragma: no cover
        raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
    return LocalContext.request


debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
spaces/spaces/zero/__init__.py
DELETED
@@ -1,12 +0,0 @@
"""
"""

from ..config import Config
from . import torch

if Config.zero_gpu:
    if torch.is_in_bad_fork():
        raise RuntimeError(
            "CUDA has been initialized before importing the `spaces` package"
        )
    torch.patch() # pragma: no cover
spaces/spaces/zero/api.py
DELETED
@@ -1,154 +0,0 @@
"""
Synced with huggingface/pyspaces:spaces/zero/api.py
"""
from __future__ import annotations

from datetime import timedelta
from typing import Any
from typing import Generator
from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import overload

import httpx
from pydantic import BaseModel
from typing_extensions import assert_never


AllowToken = str
NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
NvidiaUUID = str
CGroupPath = str
VisitorId = str
Score = float


class ScheduleResponse(BaseModel):
    idle: bool
    nvidiaIndex: int
    nvidiaUUID: str
    allowToken: str | None


class QuotaInfos(BaseModel):
    left: int
    wait: timedelta


class ReportUsageMonitoringParams(NamedTuple):
    nvidia_index: int
    visitor_id: str
    duration: timedelta


class QueueEvent(BaseModel):
    event: Literal['ping', 'failed', 'succeeded']
    data: Optional[ScheduleResponse] = None


def sse_parse(text: str):
    event, *data = text.strip().splitlines()
    assert event.startswith('event:')
    event = event[6:].strip()
    if event in ('ping', 'failed'):
        return QueueEvent(event=event)
    assert event == 'succeeded'
    (data,) = data
    assert data.startswith('data:')
    data = data[5:].strip()
    return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))


def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
    for text in res.iter_text():
        if len(text) == 0:
            break # pragma: no cover
        try:
            yield sse_parse(text)
        except GeneratorExit:
            res.close()
            break


class APIClient:

    def __init__(self, client: httpx.Client):
        self.client = client

    def startup_report(self) -> httpx.codes:
        res = self.client.post('/startup-report')
        return httpx.codes(res.status_code)

    def schedule(
        self,
        cgroup_path: str,
        task_id: int = 0,
        token: str | None = None,
        duration_seconds: int | None = None,
        enable_queue: bool = True,
    ):
        params: dict[str, str | int | bool] = {
            'cgroupPath': cgroup_path,
            'taskId': task_id,
            'enableQueue': enable_queue,
        }
        if duration_seconds is not None:
            params['durationSeconds'] = duration_seconds
        if token is not None:
            params['token'] = token
        res = self.client.send(
            request=self.client.build_request(
                method='POST',
                url='/schedule',
                params=params,
            ),
            stream=True,
        )
        status = httpx.codes(res.status_code)
        if (status is not httpx.codes.OK and
            status is not httpx.codes.TOO_MANY_REQUESTS
        ):
            res.close()
            return status
        if "text/event-stream" in res.headers['content-type']:
            return sse_stream(res)
        res.read()
        if status is httpx.codes.TOO_MANY_REQUESTS:
            return QuotaInfos(**res.json()) # pragma: no cover
        if status is httpx.codes.OK:
            return ScheduleResponse(**res.json())
        assert_never(status)

    def allow(
        self,
        allow_token: str,
        pid: int,
    ):
        res = self.client.post('/allow', params={
            'allowToken': allow_token,
            'pid': pid,
        })
        return httpx.codes(res.status_code)

    def release(
        self,
        nvidia_index: int,
        cgroup_path: str,
        task_id: int = 0,
        fail: bool = False,
    ) -> httpx.codes:
        res = self.client.post('/release', params={
            'nvidiaIndex': nvidia_index,
            'cgroupPath': cgroup_path,
            'taskId': task_id,
            'fail': fail,
        })
        return httpx.codes(res.status_code)

    def get_queue_size(self) -> int:
        res = self.client.get('/queue-size')
        assert res.status_code == 200, res.status_code
        size = res.json()
        assert isinstance(size, int)
        return size
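For orientation, the way `zero/client.py` (further down in this diff) wires this `APIClient` to the device API, shown here as a standalone sketch; the base URL is a placeholder:

```
import httpx

api = APIClient(httpx.Client(base_url='http://device-api.local', timeout=60))
status = api.startup_report()      # httpx.codes
queue_size = api.get_queue_size()  # int
```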
spaces/spaces/zero/bitsandbytes.py
DELETED
@@ -1,135 +0,0 @@
"""
"""
# pyright: reportPrivateImportUsage=false

from __future__ import annotations

import importlib
from typing import TYPE_CHECKING
from typing import Tuple

from .utils import cuda_unavailable
from .utils import maybe_import_torch
from .utils import maybe_import_bitsandbytes

if TYPE_CHECKING:
    import torch as Torch


if (torch := maybe_import_torch()) and (bnb := maybe_import_bitsandbytes()):

    from torch.utils.weak import WeakTensorKeyDictionary

    with cuda_unavailable(torch):
        from bitsandbytes import cextension
        from bitsandbytes import functional
        try: # bitsandbytes < 0.44
            from bitsandbytes.cuda_setup.main import CUDASetup
        except ModuleNotFoundError: # pragma: no cover
            CUDASetup = None
        from bitsandbytes.nn import Int8Params
        from bitsandbytes.nn import Params4bit

    _param_to_8bit = Int8Params.to # type: ignore
    _param_cuda_8bit = Int8Params.cuda
    _param_to_4bit = Params4bit.to # type: ignore
    _param_cuda_4bit = Params4bit.cuda

    TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]

    to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
    to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore

    def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
        parsed = torch._C._nn._parse_to(*args, **kwargs)
        device, *_ = parsed
        if not isinstance(device, torch.device): # pragma: no cover
            return _param_to_8bit(self, *args, **kwargs)
        if device.type != 'cuda':
            return _param_to_8bit(self, *args, **kwargs)
        to_ops_8bit[self] = parsed
        return self

    def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
        parsed = torch._C._nn._parse_to(*args, **kwargs)
        device, *_ = parsed
        if not isinstance(device, torch.device): # pragma: no cover
            return _param_to_4bit(self, *args, **kwargs)
        if device.type != 'cuda':
            return _param_to_4bit(self, *args, **kwargs)
        to_ops_4bit[self] = parsed
        return self

    def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
        if device is None: # pragma: no cover
            return True
        if isinstance(device, int):
            return True
        if isinstance(device, str): # pragma: no cover
            device = torch.device(device)
        return device.type == 'cuda' # pragma: no cover

    def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
        if not _cuda_op_arg_check(device): # pragma: no cover
            # Let PyTorch handle the fail
            return _param_cuda_8bit(self, device, **kwargs)
        to_ops_8bit[self] = None
        return self

    def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
        if not _cuda_op_arg_check(device): # pragma: no cover
            # Let PyTorch handle the fail
            return _param_cuda_4bit(self, device, **kwargs)
        to_ops_4bit[self] = None
        return self

    def _patch():
        Int8Params.to = _to_op_register_8bit # type: ignore
        Int8Params.cuda = _cuda_op_register_8bit # type: ignore
        Params4bit.to = _to_op_register_4bit # type: ignore
        Params4bit.cuda = _cuda_op_register_4bit # type: ignore

    def _unpatch():
        Int8Params.to = _param_to_8bit # type: ignore
        Int8Params.cuda = _param_cuda_8bit
        Params4bit.to = _param_to_4bit # type: ignore
        Params4bit.cuda = _param_cuda_4bit

    def _move():
        if CUDASetup is not None:
            CUDASetup._instance = None
        importlib.reload(cextension)
        functional.lib = cextension.lib
        for op in to_ops_8bit.items():
            tensor, parsed_args = op
            if parsed_args:
                _, dtype, _, memory_format = parsed_args
            else:
                dtype, memory_format = None, None
            tensor.data = _param_to_8bit(tensor,
                device='cuda',
                dtype=dtype,
                memory_format=memory_format,
            ) # type: ignore
        for op in to_ops_4bit.items():
            tensor, parsed_args = op
            if parsed_args:
                _, dtype, _, memory_format = parsed_args
            else:
                dtype, memory_format = None, None
            tensor.data = _param_to_4bit(tensor,
                device='cuda',
                dtype=dtype,
                memory_format=memory_format,
            ) # type: ignore

else:

    _patch = lambda: None
    _unpatch = lambda: None
    _move = lambda: None


patch = _patch
unpatch = _unpatch
move = _move
spaces/spaces/zero/client.py
DELETED
@@ -1,175 +0,0 @@
"""
"""
from __future__ import annotations

import os
import time
import warnings
from datetime import timedelta

import gradio as gr
import httpx

from .. import utils
from ..config import Config
from .api import APIClient
from .api import QuotaInfos
from .api import ScheduleResponse
from .gradio import get_event


TOKEN_HEADER = 'X-IP-Token'
DEFAULT_SCHEDULE_DURATION = 60

QUOTA_MESSAGE = "You have exceeded your GPU quota"
UNUSED_MESSAGE = "GPU device not used"
NO_GPU_MESSAGE_REGULAR = "No GPU is currently available"
NO_GPU_MESSAGE_INQUEUE = "No GPU is currently available for you after 60s"


def api_client():
    assert Config.zero_device_api_url is not None
    httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
    return APIClient(httpx_client)


def startup_report():
    retries, max_retries = 0, 2
    client = api_client()
    while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
        time.sleep(1)
        if (retries := retries + 1) > max_retries:
            raise RuntimeError("Error while initializing ZeroGPU: NotFound")
    if status is not httpx.codes.OK: # pragma: no cover
        raise RuntimeError("Error while initializing ZeroGPU: Unknown")


def schedule(
    task_id: int,
    request: gr.Request | None = None,
    duration: timedelta | None = None,
    _first_attempt: bool = True,
) -> ScheduleResponse:

    if not gr.__version__.startswith('4.'): # pragma: no cover
        raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")

    res = api_client().schedule(
        cgroup_path=utils.self_cgroup_device_path(),
        task_id=task_id,
        token=_get_token(request),
        duration_seconds=duration.seconds if duration is not None else None,
    )

    if isinstance(res, ScheduleResponse):
        return res

    if isinstance(res, QuotaInfos): # pragma: no cover
        requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
        if res.wait < timedelta(0):
            message = (
                f"The requested GPU duration ({requested}s) "
                f"is larger than the maximum allowed"
            )
        else:
            message = (
                f"You have exceeded your GPU quota "
                f"({res.left}s left vs. {requested}s requested). "
                f"Please retry in {res.wait}"
            )
        raise gr.Error(message)

    if not isinstance(res, httpx.codes): # pragma: no cover
        gr.Info("Waiting for a GPU to become available")
        connection_event = get_event()
        if connection_event is None and request is not None:
            warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
        while True:
            try:
                event = next(res)
            except StopIteration:
                raise RuntimeError("Unexpected end of stream")
            except httpx.RemoteProtocolError:
                if not _first_attempt:
                    raise RuntimeError("Error while re-trying after queue disconnect")
                return schedule(task_id, request, duration, _first_attempt=False)
            if event.event == 'ping':
                if connection_event is not None and not connection_event.alive:
                    res.close()
                    raise RuntimeError("Connection closed by visitor while queueing")
                continue
            if event.event == 'failed':
                raise gr.Error(NO_GPU_MESSAGE_INQUEUE)
            if event.event == 'succeeded':
                assert event.data is not None
                if connection_event is not None and not connection_event.alive:
                    release(task_id, event.data.nvidiaIndex)
                    raise RuntimeError("Connection closed by visitor on queue success")
                gr.Info("Successfully acquired a GPU")
                return event.data

    if res is httpx.codes.SERVICE_UNAVAILABLE:
        raise gr.Error(NO_GPU_MESSAGE_REGULAR)

    # TODO: Find a way to log 'detail' response field
    raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover


def allow(allow_token: str) -> None:
    pid = os.getpid()
    assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
    assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK


def release(
    task_id: int,
    nvidia_index: int,
    fail: bool = False,
    allow_404: bool = False,
) -> None:

    res = api_client().release(
        cgroup_path=utils.self_cgroup_device_path(),
        task_id=task_id,
        nvidia_index=nvidia_index,
        fail=fail,
    )

    if res is httpx.codes.NO_CONTENT: # pragma: no cover
        try:
            gr.Warning(UNUSED_MESSAGE)
        except AttributeError:
            pass
        warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
        return None

    if res is httpx.codes.NOT_FOUND:
        if not allow_404:
            warnings.warn("ZeroGPU API /release warning: 404 Not Found")
        return None

    if httpx.codes.is_success(res):
        return None

    # TODO: Find a way to log 'detail' response field
    # TODO: Only raise in dev environment. Simply warn in production ?
    raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover


def _get_token(request: gr.Request | None) -> str | None:

    if request is None:
        return None

    headers = getattr(request, 'headers', None)
    if headers is None or not hasattr(headers, '__dict__'):
        raise gr.Error("Internal Gradio error")

    # Compatibility trick
    if not hasattr(headers, 'get'):
        headers = headers.__dict__ # pragma: no cover

    if not (token := headers.get(TOKEN_HEADER.lower())):
        raise gr.Error("Internal infra error")

    return token
spaces/spaces/zero/decorator.py
DELETED
@@ -1,117 +0,0 @@
"""
"""
from __future__ import annotations

import inspect
import sys
import warnings
from datetime import timedelta
from functools import partial
from typing import Callable
from typing import TypeVar
from typing import overload
from typing_extensions import ParamSpec
from typing_extensions import Unpack

import gradio as gr

from ..config import Config
from . import client
from .types import EmptyKwargs
from .wrappers import regular_function_wrapper
from .wrappers import generator_function_wrapper


P = ParamSpec('P')
R = TypeVar('R')


decorated_cache: dict[Callable, Callable] = {}


@overload
def GPU(
    task: None = None, *,
    duration: int | timedelta | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    ...
@overload
def GPU(
    task: Callable[P, R], *,
    duration: int | timedelta | None = None,
) -> Callable[P, R]:
    ...
def GPU(
    task: Callable[P, R] | None = None, *,
    duration: int | timedelta | None = None,
    **kwargs: Unpack[EmptyKwargs],
) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
    """
    ZeroGPU decorator

    Basic usage:
    ```
    @spaces.GPU
    def fn(...):
        # CUDA is available here
        pass
    ```

    With custom duration:
    ```
    @spaces.GPU(duration=45) # Expressed in seconds
    def fn(...):
        # CUDA is available here
        pass
    ```

    Args:
        task (`Callable | None`): Python function that requires CUDA
        duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`

    Returns:
        `Callable`: GPU-ready function
    """
    if "enable_queue" in kwargs:
        warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
    if task is None:
        return partial(_GPU, duration=duration)
    return _GPU(task, duration)


def _GPU(
    task: Callable[P, R],
    duration: int | timedelta | None,
) -> Callable[P, R]:

    if not Config.zero_gpu:
        # TODO: still prepend gr.Request for type consistency ?
        return task # type: ignore

    if sys.version_info.minor < 9: # pragma: no cover
        raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")

    if task in decorated_cache:
        # TODO: Assert same duration ?
        return decorated_cache[task] # type: ignore

    if inspect.iscoroutinefunction(task):
        raise NotImplementedError

    if duration is None or isinstance(duration, timedelta):
        timedelta_duration = duration
    else:
        timedelta_duration = timedelta(seconds=duration)

    if inspect.isgeneratorfunction(task):
        decorated = generator_function_wrapper(task, timedelta_duration)
    else:
        decorated = regular_function_wrapper(task, timedelta_duration)

    client.startup_report()
    decorated_cache.update({
        task: decorated,
        decorated: decorated,
    })

    return decorated # type: ignore
spaces/spaces/zero/gradio.py
DELETED
@@ -1,108 +0,0 @@
"""
"""
from __future__ import annotations

from typing import NamedTuple
import warnings

from gradio.context import LocalContext
from gradio.helpers import Progress
from gradio.helpers import TrackedIterable
from gradio.queueing import Queue
from typing_extensions import assert_type

from ..utils import SimpleQueue
from .types import GeneratorResQueueResult
from .types import GradioQueueEvent
from .types import RegularResQueueResult


QUEUE_RPC_METHODS = [
    "set_progress",
    "log_message",
]


class GradioPartialContext(NamedTuple):
    event_id: str | None
    in_event_listener: bool
    progress: Progress | None

    @staticmethod
    def get():
        TrackedIterable.__reduce__ = tracked_iterable__reduce__
        return GradioPartialContext(
            event_id=LocalContext.event_id.get(),
            in_event_listener=LocalContext.in_event_listener.get(),
            progress=LocalContext.progress.get(),
        )

    @staticmethod
    def apply(context: 'GradioPartialContext'):
        LocalContext.event_id.set(context.event_id)
        LocalContext.in_event_listener.set(context.in_event_listener)
        LocalContext.progress.set(context.progress)


def get_queue_instance():
    blocks = LocalContext.blocks.get()
    if blocks is None: # pragma: no cover
        return None
    return blocks._queue


def get_event():
    queue = get_queue_instance()
    event_id = LocalContext.event_id.get()
    if queue is None:
        return None
    if event_id is None: # pragma: no cover
        return None
    for job in queue.active_jobs:
        if job is None: # pragma: no cover
            continue
        for event in job:
            if event._id == event_id:
                return event


def try_process_queue_event(method_name: str, *args, **kwargs):
    queue = get_queue_instance()
    if queue is None: # pragma: no cover
        warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
        return
    method = getattr(queue, method_name, None)
    assert callable(method)
    method(*args, **kwargs)


def patch_gradio_queue(
    res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
):

    def rpc_method(method_name: str):
        def method(*args, **kwargs):
            if args and isinstance(args[0], Queue):
                args = args[1:] # drop `self`
            res_queue.put(GradioQueueEvent(method_name, args, kwargs))
        return method

    for method_name in QUEUE_RPC_METHODS:
        if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
            warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
            continue
        if not callable(method): # pragma: no cover
            warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
            continue
        setattr(Queue, method_name, rpc_method(method_name))

    TrackedIterable.__reduce__ = tracked_iterable__reduce__


def tracked_iterable__reduce__(self):
    res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
    cls, base, state, *_ = res
    return cls, base, {**state, **{
        'iterable': None,
        '_tqdm': None,
    }}
spaces/spaces/zero/torch.py
DELETED
@@ -1,279 +0,0 @@
"""
"""
# pyright: reportPrivateImportUsage=false

from __future__ import annotations

import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor
from contextlib import suppress
from functools import partial
from types import SimpleNamespace
from typing import TYPE_CHECKING
from typing import Any
from typing import Optional
from typing import Tuple

from ..config import Config
from . import bitsandbytes
from .utils import maybe_import_torch

if TYPE_CHECKING:
    import torch as Torch


# Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
CUDA_TOTAL_MEMORY = 42144366592
CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
CUDA_DEVICE_CAPABILITY = (8, 0)
CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)

GENERIC_METHOD_NAMES = [
    'arange',
    'as_tensor',
    'asarray',
    'bartlett_window',
    'blackman_window',
    'empty',
    'empty_like',
    'empty_strided',
    'eye',
    'full',
    'full_like',
    'hamming_window',
    'hann_window',
    'kaiser_window',
    'linspace',
    'logspace',
    'obj',
    'ones',
    'ones_like',
    'rand',
    'rand_like',
    'randint',
    'randint_like',
    'randn',
    'randn_like',
    'randperm',
    'range',
    'sparse_bsc_tensor',
    'sparse_bsr_tensor',
    'sparse_compressed_tensor',
    'sparse_coo_tensor',
    'sparse_csc_tensor',
    'sparse_csr_tensor',
    'tensor',
    'tril_indices',
    'triu_indices',
    'zeros',
    'zeros_like',
]


if (torch := maybe_import_torch()):

    from torch.utils.weak import WeakTensorKeyDictionary

    TO_CUDA = (torch.device('cuda'), None, False, None)

    _tensor__deepcopy__ = torch.Tensor.__deepcopy__
    _tensor_to = torch.Tensor.to
    _tensor_cuda = torch.Tensor.cuda
    _tensor_cpu = torch.Tensor.cpu
    _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
    _cuda_init = torch._C._cuda_init
    _cuda_available = torch.cuda.is_available
    _cuda_device_count = torch.cuda.device_count
    _cuda_current_device = torch.cuda.current_device
    _cuda_mem_get_info = torch.cuda.mem_get_info
    _cuda_get_device_capability = torch.cuda.get_device_capability
    _cuda_get_device_properties = torch.cuda.get_device_properties
    _cuda_get_device_name = torch.cuda.get_device_name

    TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]

    to_ops: dict[Torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore

    def _tensor_new_register(*args, **kwargs):
        new_tensor: Torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
        if (base_tensor := new_tensor._base) is not None:
            if base_tensor in to_ops:
                to_ops[new_tensor] = to_ops[base_tensor]
        return new_tensor

    def _tensor_deepcopy_register(self: Torch.Tensor, memo):
        new_tensor = _tensor__deepcopy__(self, memo)
        if isinstance(new_tensor, torch.Tensor):
            if self in to_ops:
                to_ops[new_tensor] = to_ops[self]
        return new_tensor

    @property
    def _tensor_device_property(self: Torch.Tensor):
        if self in to_ops:
            return torch.device(type='cuda', index=0)
        del torch.Tensor.device
        try:
            return self.device
        finally:
            torch.Tensor.device = _tensor_device_property # type: ignore

    @property
    def _tensor_dtype_property(self: Torch.Tensor):
        if self in to_ops:
            if (to_dtype := to_ops[self][1]) is not None:
                return to_dtype
        del torch.Tensor.dtype
        try:
            return self.dtype
        finally:
            torch.Tensor.dtype = _tensor_dtype_property # type: ignore

    def _to_op_register(self: Torch.Tensor, *args, **kwargs):
        parsed = torch._C._nn._parse_to(*args, **kwargs)
        device, dtype, *_ = parsed
        try:
            to_args = to_ops.pop(self)
        except KeyError:
            to_args = None
        if device is None:
            if to_args is not None:
                to_ops[self] = (to_args[0], dtype, *to_args[2:])
                return self
            return _tensor_to(self, *args, **kwargs)
        if device.type != 'cuda':
            if to_args is not None:
                if (to_dtype := to_args[1]) is not None:
                    kwargs = {'dtype': to_dtype, **kwargs}
            return _tensor_to(self, *args, **kwargs)
        to_ops[self] = parsed
        return self

    def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
        if device is None:
            return True
        if isinstance(device, int):
            return True
        if isinstance(device, str):
            device = torch.device(device)
        return device.type == 'cuda'

    def _cuda_op_register(self: Torch.Tensor, device: Torch.device | int | str | None = None, **kwargs):
        if not _cuda_op_arg_check(device):
            # Let PyTorch handle the fail
            return _tensor_cuda(self, device, **kwargs)
        to_ops[self] = TO_CUDA
        return self

    def _cpu_op_remove(self: Torch.Tensor, **kwargs):
        try:
            to_args = to_ops.pop(self)
        except KeyError:
            to_args = None
        if to_args is not None:
            if (to_dtype := to_args[1]) is not None:
                return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
        return _tensor_cpu(self, **kwargs)

    def _cuda_init_raise():
        raise RuntimeError(
            "CUDA must not be initialized in the main process "
            "on Spaces with Stateless GPU environment.\n"
            "You can look at this Stacktrace to find out "
            "which part of your code triggered a CUDA init"
        )

    def _generic_method_register(name: str, *args: Any, **kwargs: Any):
        try:
            device = torch.device(kwargs.get('device', "cpu"))
        except Exception:
            return _torch_generics[name](*args, **kwargs)
        if device.type != 'cuda':
            return _torch_generics[name](*args, **kwargs)
        tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
        to_ops[tensor] = TO_CUDA
        return tensor

    def _patch():
        torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
        torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
        torch.Tensor.to = _to_op_register # type: ignore
        torch.Tensor.cuda = _cuda_op_register # type: ignore
        torch.Tensor.cpu = _cpu_op_remove # type: ignore
        if Config.zero_patch_torch_device:
            torch.Tensor.device = _tensor_device_property # type: ignore
            torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
        for name in GENERIC_METHOD_NAMES:
            setattr(torch, name, partial(_generic_method_register, name))
        torch._C._cuda_init = _cuda_init_raise
        torch.cuda.is_available = lambda: True
        torch.cuda.device_count = lambda: 1
        torch.cuda.current_device = lambda: 0
        torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
        torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
        torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
        torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
        bitsandbytes.patch()

    def _unpatch():
        torch.Tensor.__deepcopy__ = _tensor__deepcopy__
        with suppress(AttributeError):
            del torch.Tensor.__new__
        torch.Tensor.to = _tensor_to
        torch.Tensor.cuda = _tensor_cuda
        torch.Tensor.cpu = _tensor_cpu
        with suppress(AttributeError):
            del torch.Tensor.device
        with suppress(AttributeError):
            del torch.Tensor.dtype
        for name in GENERIC_METHOD_NAMES:
            setattr(torch, name, _torch_generics[name])
        torch._C._cuda_init = _cuda_init
        torch.cuda.is_available = _cuda_available
        torch.cuda.device_count = _cuda_device_count
        torch.cuda.current_device = _cuda_current_device
        torch.cuda.mem_get_info = _cuda_mem_get_info
        torch.cuda.get_device_capability = _cuda_get_device_capability
        torch.cuda.get_device_properties = _cuda_get_device_properties
        torch.cuda.get_device_name = _cuda_get_device_name
        bitsandbytes.unpatch()

    def _move(nvidia_uuid: str):
        os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
        torch.Tensor([0]).cuda() # CUDA init
        for op in to_ops.items():
            tensor, parsed_args = op
            _, dtype, _, memory_format = parsed_args
            tensor.data = _tensor_to(tensor,
                device='cuda',
                dtype=dtype,
                memory_format=memory_format,
            ) # type: ignore
        bitsandbytes.move()
        torch.cuda.synchronize()

    def _is_in_bad_fork():
        with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
            f = e.submit(torch.cuda._is_in_bad_fork)
            return f.result()

    def _disable_cuda_intercept():
        torch.Tensor.to = _tensor_to
        torch.Tensor.cuda = _tensor_cuda

else:

    _patch = lambda: None
    _unpatch = lambda: None
    _move = lambda nvidia_uuid: None
    _is_in_bad_fork = lambda: False
    _disable_cuda_intercept = lambda: None


patch = _patch
unpatch = _unpatch
move = _move
is_in_bad_fork = _is_in_bad_fork
disable_cuda_intercept = _disable_cuda_intercept
spaces/spaces/zero/tqdm.py
DELETED
@@ -1,14 +0,0 @@
"""
"""

from multiprocessing.synchronize import RLock as MultiprocessingRLock


def remove_tqdm_multiprocessing_lock():
    from tqdm import tqdm
    tqdm_lock = tqdm.get_lock()
    assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
    tqdm_lock.locks = [
        lock for lock in tqdm_lock.locks
        if not isinstance(lock, MultiprocessingRLock)
    ]
spaces/spaces/zero/types.py
DELETED
@@ -1,44 +0,0 @@
"""
"""
from __future__ import annotations


from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Tuple
from typing import TypedDict
from typing_extensions import Generic
from typing_extensions import ParamSpec
from typing_extensions import TypeAlias
from typing_extensions import TypeVar


Params = Tuple[Tuple[object, ...], Dict[str, Any]]
Res = TypeVar('Res')
Param = ParamSpec('Param')

class EmptyKwargs(TypedDict):
    pass

@dataclass
class OkResult(Generic[Res]):
    value: Res
@dataclass
class ExceptionResult:
    value: Exception
@dataclass
class AbortedResult:
    pass
@dataclass
class EndResult:
    pass
@dataclass
class GradioQueueEvent:
    method_name: str
    args: tuple[Any, ...]
    kwargs: dict[str, Any]

RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
spaces/spaces/zero/utils.py
DELETED
@@ -1,44 +0,0 @@
"""
"""
from __future__ import annotations

from contextlib import contextmanager
from importlib import metadata
from types import ModuleType

from packaging import version

from ..config import Config


def maybe_import_torch():
    if not Config.zero_gpu:
        return None
    try:
        import torch
    except ImportError:
        return None
    return torch


@contextmanager
def cuda_unavailable(torch: ModuleType):
    _is_available = torch.cuda.is_available
    torch.cuda.is_available = lambda: False
    yield
    torch.cuda.is_available = _is_available


def maybe_import_bitsandbytes():
    if (torch := maybe_import_torch()) is None:
        return None # pragma: no cover
    with cuda_unavailable(torch):
        try:
            import bitsandbytes
        except ImportError:
            bitsandbytes = None
        else:
            if (bnb_version := version.parse(metadata.version('bitsandbytes'))) < version.parse('0.40.0'):
                raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})") # pragma: no cover
            print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
    return bitsandbytes
spaces/spaces/zero/wrappers.py
DELETED
@@ -1,347 +0,0 @@
"""
"""
from __future__ import annotations

import multiprocessing
import os
import signal
import traceback
from concurrent.futures import ThreadPoolExecutor
from contextvars import copy_context
from datetime import timedelta
from functools import partial
from functools import wraps
from multiprocessing.context import ForkProcess
from pickle import PicklingError
from queue import Empty
from queue import Queue as ThreadQueue
from threading import Thread
from typing import TYPE_CHECKING
from typing import Callable
from typing import Generator
from typing import Generic
from typing_extensions import assert_never

import gradio as gr
import psutil

from ..utils import debug
from ..utils import drop_params
from ..utils import gradio_request_var
from ..utils import SimpleQueue as Queue
from . import client
from . import torch
from .api import AllowToken
from .api import NvidiaIndex
from .api import NvidiaUUID
from .gradio import GradioPartialContext
from .gradio import patch_gradio_queue
from .gradio import try_process_queue_event
from .tqdm import remove_tqdm_multiprocessing_lock
from .types import * # TODO: Please don't do that


GENERATOR_GLOBAL_TIMEOUT = 20 * 60


Process = multiprocessing.get_context('fork').Process
forked = False


class Worker(Generic[Res]):
    process: ForkProcess
    arg_queue: Queue[tuple[Params, GradioPartialContext]]
    res_queue: Queue[Res | None]
    _sentinel: Thread

    def __init__(
        self,
        target: Callable[[
            Queue[tuple[Params, GradioPartialContext]],
            Queue[Res | None],
            AllowToken | None,
            NvidiaUUID,
            list[int],
        ], None],
        allow_token: str | None,
        nvidia_uuid: str,
    ):
        self._sentinel = Thread(target=self._close_on_exit)
        self.arg_queue = Queue()
        self.res_queue = Queue()
        fds = [c.fd for c in psutil.Process().connections()]
        args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
        if TYPE_CHECKING:
            target(*args)
        self.process = Process(
            target=target,
            args=args,
            daemon=True,
        )
        self.process.start()
        self._sentinel.start()

    def _close_on_exit(self):
        self.process.join()
        self.res_queue.put(None)


def worker_init(
    res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
    allow_token: str | None,
    nvidia_uuid: str,
    fds: list[int],
) -> None | ExceptionResult:
    try: # Unrecoverable init part
        if allow_token is not None:
            client.allow(allow_token)
        torch.unpatch()
        torch.move(nvidia_uuid)
        patch_gradio_queue(res_queue)
    except Exception as e: # pragma: no cover
        traceback.print_exc()
        return ExceptionResult(e)
    try:
        remove_tqdm_multiprocessing_lock()
    except Exception: # pragma: no cover
        print("Error while trying to remove tqdm mp_lock:")
        traceback.print_exc()
    for fd in fds:
        try:
            os.close(fd)
        except Exception as e: # pragma: no cover
            if isinstance(e, OSError) and e.errno == 9:
                continue
            traceback.print_exc()
            return ExceptionResult(e)


def regular_function_wrapper(
    task: Callable[Param, Res],
    duration: timedelta | None,
) -> Callable[Param, Res]:

    request_var = gradio_request_var()
    workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res]]] = {}
    task_id = id(task)

    @wraps(task)
    def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:

        if forked:
            return task(*args, **kwargs)

        request = request_var.get()
        schedule_response = client.schedule(task_id=task_id, request=request, duration=duration)
        allow_token = schedule_response.allowToken
        nvidia_index = schedule_response.nvidiaIndex
        nvidia_uuid = schedule_response.nvidiaUUID
        release = partial(client.release, task_id=task_id, nvidia_index=nvidia_index)

        worker = workers.get(nvidia_index)
        if worker is None or not worker.process.is_alive():
            worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
            workers[nvidia_index] = worker

        try:
            worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
        except PicklingError:
            release(fail=True)
            # TODO: Better error message (check what arg / kwarg is problematic ?)
            raise

        while True:
            res = worker.res_queue.get()
            if res is None:
                release(fail=True, allow_404=True)
                raise gr.Error("GPU task aborted")
            if isinstance(res, ExceptionResult):
                release(fail=True)
                raise res.value
            if isinstance(res, OkResult):
                release()
                return res.value
            if isinstance(res, GradioQueueEvent):
                try_process_queue_event(res.method_name, *res.args, **res.kwargs)
                continue
            assert_never(res)


    def thread_wrapper(
        arg_queue: Queue[tuple[Params, GradioPartialContext]],
        res_queue: Queue[RegularResQueueResult[Res] | None],
        allow_token: str | None,
        nvidia_uuid: str,
        fds: list[int],
    ):
        global forked
        forked = True
        if (res := worker_init(
            res_queue=res_queue,
            allow_token=allow_token,
            nvidia_uuid=nvidia_uuid,
            fds=fds,
        )) is not None: # pragma: no cover
            res_queue.put(res)
            return
        signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
        while True:
            try:
                (args, kwargs), gradio_context = arg_queue.get()
            except OSError:
                break
            GradioPartialContext.apply(gradio_context)
            context = copy_context()
            with ThreadPoolExecutor() as executor:
                future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
            try:
                res = future.result()
            except Exception as e:
                traceback.print_exc()
                res = ExceptionResult(e)
            else:
                res = OkResult(res)
            try:
                res_queue.put(res)
            except PicklingError as e:
                res_queue.put(ExceptionResult(e))

    # https://github.com/python/cpython/issues/91002
    if not hasattr(task, '__annotations__'):
        gradio_handler.__annotations__ = {}

    return gradio_handler


def generator_function_wrapper(
    task: Callable[Param, Generator[Res, None, None]],
    duration: timedelta | None,
) -> Callable[Param, Generator[Res, None, None]]:

    request_var = gradio_request_var()
    workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res]]] = {}
    task_id = id(task)

    @wraps(task)
    def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:

        if forked:
            yield from task(*args, **kwargs)
            return

        request = request_var.get()
        schedule_response = client.schedule(task_id=task_id, request=request, duration=duration)
        allow_token = schedule_response.allowToken
        nvidia_index = schedule_response.nvidiaIndex
        nvidia_uuid = schedule_response.nvidiaUUID
        release = partial(client.release, task_id=task_id, nvidia_index=nvidia_index)

        worker = workers.get(nvidia_index)
        if worker is None or not worker.process.is_alive():
            worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
            workers[nvidia_index] = worker

        try:
            worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
        except PicklingError:
            release(fail=True)
            raise

        yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
        def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res]]):
            while True:
                res = worker.res_queue.get()
                if res is None:
                    release(fail=True, allow_404=True)
                    yield_queue.put(AbortedResult())
                    return
                if isinstance(res, ExceptionResult):
                    release(fail=True)
                    yield_queue.put(ExceptionResult(res.value))
                    return
                if isinstance(res, EndResult):
                    release()
                    yield_queue.put(EndResult())
                    return
                if isinstance(res, OkResult):
                    yield_queue.put(OkResult(res.value))
                    continue
                if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
                    try_process_queue_event(res.method_name, *res.args, **res.kwargs)
                    continue
                debug(f"fill_yield_queue: assert_never({res=})")
                assert_never(res)
        from typing_extensions import assert_never
        with ThreadPoolExecutor() as e:
            f = e.submit(fill_yield_queue, worker)
            f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
            while True:
                try:
                    res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
                except Empty: # pragma: no cover
                    debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
                    raise
                if isinstance(res, AbortedResult):
                    raise gr.Error("GPU task aborted")
                if isinstance(res, ExceptionResult):
                    raise res.value
                if isinstance(res, EndResult):
                    break
                if isinstance(res, OkResult):
                    yield res.value
                    continue
                debug(f"gradio_handler: assert_never({res=})")
                assert_never(res)


    def thread_wrapper(
        arg_queue: Queue[tuple[Params, GradioPartialContext]],
        res_queue: Queue[GeneratorResQueueResult[Res] | None],
        allow_token: str | None,
        nvidia_uuid: str,
        fds: list[int],
    ):
        global forked
        forked = True
        if (res := worker_init(
            res_queue=res_queue,
            allow_token=allow_token,
            nvidia_uuid=nvidia_uuid,
            fds=fds,
        )) is not None: # pragma: no cover
            res_queue.put(res)
            return
        signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
        while True:
            try:
                (args, kwargs), gradio_context = arg_queue.get()
            except OSError:
                break
            def iterate():
                gen = task(*args, **kwargs) # type: ignore
                while True:
                    try:
                        res = next(gen)
                    except StopIteration:
                        break
                    except Exception as e:
                        res_queue.put(ExceptionResult(e))
                        break
                    try:
                        res_queue.put(OkResult(res))
                    except PicklingError as e:
                        res_queue.put(ExceptionResult(e))
                        break
                    else:
                        continue
            GradioPartialContext.apply(gradio_context)
            context = copy_context()
            with ThreadPoolExecutor() as executor:
                executor.submit(context.run, iterate)
            res_queue.put(EndResult())

    # https://github.com/python/cpython/issues/91002
    if not hasattr(task, '__annotations__'):
        gradio_handler.__annotations__ = {}

    return gradio_handler