NDugar commited on
Commit
cb31bcc
1 Parent(s): ca951ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -22
app.py CHANGED
@@ -25,37 +25,21 @@ COLOUR_MODEL = "RGB"
25
 
26
  MODEL_REPO = "NDugar/horse_to_zebra_cycle_GAN"
27
  MODEL_FILE = "h2z-85epoch.pth"
28
-
29
- # Model Initalisation
30
- #shinkai_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_SHINKAI, filename=MODEL_FILE_SHINKAI)
31
- #hosoda_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_HOSODA, filename=MODEL_FILE_HOSODA)
32
- #miyazaki_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_MIYAZAKI, filename=MODEL_FILE_MIYAZAKI)
33
  model_hfhub = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
34
-
35
- #shinkai_model = Transformer()
36
- #hosoda_model = Transformer()
37
- #miyazaki_model = Transformer()
38
- #model = Transformer()
39
-
40
  enable_gpu = torch.cuda.is_available()
41
  map_location = torch.device("cuda") if enable_gpu else "cpu"
 
 
42
 
43
- model.load_state_dict(torch.load(model_hfhub, map_location=map_location))
44
-
 
45
  model.eval()
46
-
47
-
48
- # Functions
49
-
50
  def get_model():
51
  return model
52
-
53
-
54
  def adjust_image_for_model(img):
55
  logger.info(f"Image Height: {img.height}, Image Width: {img.width}")
56
  return img
57
-
58
-
59
  def inference(img, style):
60
  img = adjust_image_for_model(img)
61
  input_image = img.convert(COLOUR_MODEL)
@@ -103,4 +87,6 @@ gr.Interface(
103
  examples=examples,
104
  allow_flagging="never",
105
  allow_screenshot=False,
106
- ).launch(enable_queue=True)
 
 
 
25
 
26
  MODEL_REPO = "NDugar/horse_to_zebra_cycle_GAN"
27
  MODEL_FILE = "h2z-85epoch.pth"
 
 
 
 
 
28
  model_hfhub = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
 
 
 
 
 
 
29
  enable_gpu = torch.cuda.is_available()
30
  map_location = torch.device("cuda") if enable_gpu else "cpu"
31
+ from huggingface_hub import hf_hub_download
32
+ from fastai.learner import load_learner
33
 
34
+ model = load_learner(
35
+ hf_hub_download("NDugar/horse_to_zebra_cycle_GAN", "model.pkl")
36
+ )
37
  model.eval()
 
 
 
 
38
  def get_model():
39
  return model
 
 
40
  def adjust_image_for_model(img):
41
  logger.info(f"Image Height: {img.height}, Image Width: {img.width}")
42
  return img
 
 
43
  def inference(img, style):
44
  img = adjust_image_for_model(img)
45
  input_image = img.convert(COLOUR_MODEL)
 
87
  examples=examples,
88
  allow_flagging="never",
89
  allow_screenshot=False,
90
+ ).launch(enable_queue=True)
91
+
92
+