RamAnanth1 commited on
Commit
7da7768
1 Parent(s): e893bfe

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +37 -0
model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from basicsr.utils import img2tensor, tensor2img
8
+ from pytorch_lightning import seed_everything
9
+ from ldm.models.diffusion.plms import PLMSSampler
10
+ from ldm.modules.encoders.adapter import Adapter
11
+ from ldm.util import instantiate_from_config
12
+ from model_edge import pidinet
13
+ import gradio as gr
14
+ from omegaconf import OmegaConf
15
+
16
+
17
+ def load_model_from_config(config, ckpt, verbose=False):
18
+ print(f"Loading model from {ckpt}")
19
+ pl_sd = torch.load(ckpt, map_location="cpu")
20
+ if "global_step" in pl_sd:
21
+ print(f"Global Step: {pl_sd['global_step']}")
22
+ if "state_dict" in pl_sd:
23
+ sd = pl_sd["state_dict"]
24
+ else:
25
+ sd = pl_sd
26
+ model = instantiate_from_config(config.model)
27
+ m, u = model.load_state_dict(sd, strict=False)
28
+ # if len(m) > 0 and verbose:
29
+ # print("missing keys:")
30
+ # print(m)
31
+ # if len(u) > 0 and verbose:
32
+ # print("unexpected keys:")
33
+ # print(u)
34
+
35
+ model.cuda()
36
+ model.eval()
37
+ return model