ZhengPeng7 commited on
Commit
b305b72
1 Parent(s): d6ffbb8

Add guideline to use BiRefNet. Remove codes of model.

Browse files
Files changed (1) hide show
  1. README.md +6 -3
README.md CHANGED
@@ -31,9 +31,12 @@ import matplotlib.pyplot as plt
31
  import torch
32
  from torchvision import transforms
33
 
 
 
 
34
  # Input Data
35
  transform_image = transforms.Compose([
36
- transforms.Resize((256, 256)),
37
  transforms.ToTensor(),
38
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
39
  ])
@@ -42,7 +45,7 @@ image = Image.open(imagepath)
42
  input_images = transform_image(image).unsqueeze(0).to('cuda')
43
 
44
  # Load Model
45
- device = '0'
46
  torch.set_float32_matmul_precision(['high', 'highest'][0])
47
  model = BiRefNet.from_pretrained('zhengpeng7/birefnet')
48
  model.to(device)
@@ -55,7 +58,7 @@ with torch.no_grad():
55
  pred = preds[0].squeeze()
56
 
57
  # Show Results
58
- plt.imshow(pred, cmap='gray')
59
  plt.show()
60
 
61
  ```
 
31
  import torch
32
  from torchvision import transforms
33
 
34
+ from models.birefnet import BiRefNet
35
+
36
+
37
  # Input Data
38
  transform_image = transforms.Compose([
39
+ transforms.Resize((1024, 1024)),
40
  transforms.ToTensor(),
41
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
42
  ])
 
45
  input_images = transform_image(image).unsqueeze(0).to('cuda')
46
 
47
  # Load Model
48
+ device = 'cuda'
49
  torch.set_float32_matmul_precision(['high', 'highest'][0])
50
  model = BiRefNet.from_pretrained('zhengpeng7/birefnet')
51
  model.to(device)
 
58
  pred = preds[0].squeeze()
59
 
60
  # Show Results
61
+ plt.imshow(transforms.ToPILImage()(pred).resize(image.size), cmap='gray')
62
  plt.show()
63
 
64
  ```