Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -25,37 +25,21 @@ COLOUR_MODEL = "RGB"
|
|
25 |
|
26 |
MODEL_REPO = "NDugar/horse_to_zebra_cycle_GAN"
|
27 |
MODEL_FILE = "h2z-85epoch.pth"
|
28 |
-
|
29 |
-
# Model Initalisation
|
30 |
-
#shinkai_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_SHINKAI, filename=MODEL_FILE_SHINKAI)
|
31 |
-
#hosoda_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_HOSODA, filename=MODEL_FILE_HOSODA)
|
32 |
-
#miyazaki_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_MIYAZAKI, filename=MODEL_FILE_MIYAZAKI)
|
33 |
model_hfhub = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
|
34 |
-
|
35 |
-
#shinkai_model = Transformer()
|
36 |
-
#hosoda_model = Transformer()
|
37 |
-
#miyazaki_model = Transformer()
|
38 |
-
#model = Transformer()
|
39 |
-
|
40 |
enable_gpu = torch.cuda.is_available()
|
41 |
map_location = torch.device("cuda") if enable_gpu else "cpu"
|
|
|
|
|
42 |
|
43 |
-
model
|
44 |
-
|
|
|
45 |
model.eval()
|
46 |
-
|
47 |
-
|
48 |
-
# Functions
|
49 |
-
|
50 |
def get_model():
|
51 |
return model
|
52 |
-
|
53 |
-
|
54 |
def adjust_image_for_model(img):
|
55 |
logger.info(f"Image Height: {img.height}, Image Width: {img.width}")
|
56 |
return img
|
57 |
-
|
58 |
-
|
59 |
def inference(img, style):
|
60 |
img = adjust_image_for_model(img)
|
61 |
input_image = img.convert(COLOUR_MODEL)
|
@@ -103,4 +87,6 @@ gr.Interface(
|
|
103 |
examples=examples,
|
104 |
allow_flagging="never",
|
105 |
allow_screenshot=False,
|
106 |
-
).launch(enable_queue=True)
|
|
|
|
|
|
25 |
|
26 |
MODEL_REPO = "NDugar/horse_to_zebra_cycle_GAN"
|
27 |
MODEL_FILE = "h2z-85epoch.pth"
|
|
|
|
|
|
|
|
|
|
|
28 |
model_hfhub = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
enable_gpu = torch.cuda.is_available()
|
30 |
map_location = torch.device("cuda") if enable_gpu else "cpu"
|
31 |
+
from huggingface_hub import hf_hub_download
|
32 |
+
from fastai.learner import load_learner
|
33 |
|
34 |
+
model = load_learner(
|
35 |
+
hf_hub_download("NDugar/horse_to_zebra_cycle_GAN", "model.pkl")
|
36 |
+
)
|
37 |
model.eval()
|
|
|
|
|
|
|
|
|
38 |
def get_model():
|
39 |
return model
|
|
|
|
|
40 |
def adjust_image_for_model(img):
|
41 |
logger.info(f"Image Height: {img.height}, Image Width: {img.width}")
|
42 |
return img
|
|
|
|
|
43 |
def inference(img, style):
|
44 |
img = adjust_image_for_model(img)
|
45 |
input_image = img.convert(COLOUR_MODEL)
|
|
|
87 |
examples=examples,
|
88 |
allow_flagging="never",
|
89 |
allow_screenshot=False,
|
90 |
+
).launch(enable_queue=True)
|
91 |
+
|
92 |
+
|