ritutweets46 Aditibaheti commited on
Commit
c2633fb
1 Parent(s): 58a0d29

updates 2 (#2)

Browse files

- updates 2 (15c309de6c0902cfc09fca32e94107b50ade0e0e)


Co-authored-by: Aditi Baheti <Aditibaheti@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -6,17 +6,15 @@ import torch
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
- # Replace 'path/to/your/safetensors' with the actual path to your fine-tuned model's safetensors file
10
- model_path = "./pytorch_lora_weights.safetensors"
 
11
 
12
- if torch.cuda.is_available():
13
- torch.cuda.max_memory_allocated(device=device)
14
- pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
15
- pipe.enable_xformers_memory_efficient_attention()
16
- pipe = pipe.to(device)
17
- else:
18
- pipe = DiffusionPipeline.from_pretrained(model_path, use_safetensors=True)
19
- pipe = pipe.to(device)
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 1024
@@ -25,9 +23,9 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
25
  if randomize_seed:
26
  seed = random.randint(0, MAX_SEED)
27
 
28
- generator = torch.Generator().manual_seed(seed)
29
 
30
- image = pipe(
31
  prompt=prompt,
32
  negative_prompt=negative_prompt,
33
  guidance_scale=guidance_scale,
 
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
+ # Path to your model repository and safetensors weights
10
+ base_model_repo = "stabilityai/stable-diffusion-3-medium-diffusers"
11
+ lora_weights_path = "./pytorch_lora_weights.safetensors"
12
 
13
+ # Load the base model
14
+ pipeline = DiffusionPipeline.from_pretrained(base_model_repo, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
15
+ pipeline.load_lora_weights(lora_weights_path)
16
+ pipeline.enable_sequential_cpu_offload()
17
+ pipeline = pipeline.to(device)
 
 
 
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  MAX_IMAGE_SIZE = 1024
 
23
  if randomize_seed:
24
  seed = random.randint(0, MAX_SEED)
25
 
26
+ generator = torch.Generator(device=device).manual_seed(seed)
27
 
28
+ image = pipeline(
29
  prompt=prompt,
30
  negative_prompt=negative_prompt,
31
  guidance_scale=guidance_scale,