K00B404 commited on
Commit
d6b6b95
1 Parent(s): ee09d9a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +29 -2
README.md CHANGED
@@ -1,7 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Pix2Pix UNet Model
2
-
 
3
  - **Image Size:** 1024
4
  - **Model Type:** big_UNet (1024)
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  ## Model Architecture
6
  UNet(
7
  (encoder): Sequential(
@@ -28,4 +55,4 @@ UNet(
28
  (8): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
29
  (9): Tanh()
30
  )
31
- )
 
1
+ ---
2
+ tags:
3
+ - unet
4
+ - pix2pix
5
+ - pytorch
6
+ library_name: pytorch
7
+ license: wtfpl
8
+ datasets:
9
+ - K00B404/pix2pix_flux_set
10
+ language:
11
+ - en
12
+ pipeline_tag: image-to-image
13
+ ---
14
  # Pix2Pix UNet Model
15
+ ## Model Description
16
+ Custom UNet model for Pix2Pix image translation.
17
  - **Image Size:** 1024
18
  - **Model Type:** big_UNet (1024)
19
+ ## Usage
20
+ ```python
21
+ import torch
22
+ from small_256_model import UNet as small_UNet
23
+ from big_1024_model import UNet as big_UNet
24
+ big = True
25
+ # Load the model
26
+ name='big_model_weights.pth' if big else 'small_model_weights.pth'
27
+ checkpoint = torch.load(name)
28
+ model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
29
+ model.load_state_dict(checkpoint['model_state_dict'])
30
+ model.eval()
31
+ ```
32
  ## Model Architecture
33
  UNet(
34
  (encoder): Sequential(
 
55
  (8): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
56
  (9): Tanh()
57
  )
58
+ )