pranavajay commited on
Commit
639a33b
1 Parent(s): 9f86643

Create tr.py

Browse files
Files changed (1) hide show
  1. tr.py +25 -0
tr.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from diffusers import FluxPipeline
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ # Ensure 'output' folder exists, if not, create it
7
+ output_folder = './output/flux_8step'
8
+ os.makedirs(output_folder, exist_ok=True)
9
+
10
+ # Load base model
11
+ base_model_id = "trongg/Flux-Dev2Pro_nsfw_fluxtastic-v3"
12
+ repo_name = "ByteDance/Hyper-SD"
13
+ ckpt_name = "Hyper-FLUX.1-dev-8steps-lora.safetensors"
14
+ pipe = FluxPipeline.from_pretrained(base_model_id)
15
+
16
+ # Load and fuse LoRA weights
17
+ pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
18
+ pipe.fuse_lora(lora_scale=0.125)
19
+
20
+ # Unload LoRA to return the model to its original state
21
+ pipe.unload_lora_weights()
22
+
23
+ # Save the transformer model in 'output' folder
24
+ model = pipe.transformer
25
+ model.save_pretrained(output_folder)