|
|
|
|
|
|
|
|
|
import streamlit as st |
|
import requests |
|
from PIL import Image |
|
import io |
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
# The Hugging Face token is read from the environment; the module refuses to
# load without it, since every request needs the Authorization header.
hf_api_key = os.environ.get("HF_API_KEY")
if hf_api_key is None or hf_api_key == "":
    raise ValueError("HF_API_KEY not set in environment variables")

# Endpoint that query() posts to, and the bearer-token header sent with it.
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
headers = {"Authorization": "Bearer " + hf_api_key}
|
|
|
|
|
|
|
|
|
|
|
def query(payload, timeout=120):
    """POST ``payload`` as JSON to ``API_URL`` and return the raw response body.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up. Passed
            straight to ``requests.post``.

    Returns:
        The response body as ``bytes`` when the server answers with HTTP 200,
        otherwise ``None``. Every failure (transport error, timeout, or a
        non-200 status) is reported in the Streamlit UI through ``st.error``.
    """
    try:
        # requests applies no timeout by default, so a stalled connection
        # would otherwise block this call (and the UI) forever.
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as exc:
        # Connection failures and timeouts take the same "show error, return
        # None" path as HTTP errors instead of surfacing a raw traceback.
        st.error(f"Error: request failed - {exc}")
        return None

    if response.status_code != 200:
        st.error(f"Error: {response.status_code} - {response.text}")
        return None

    return response.content
|
|
|
def generate_image(prompt):
    """Request an image for ``prompt`` and open the returned bytes with PIL.

    Args:
        prompt: Text sent to ``query`` under the ``"inputs"`` key.

    Returns:
        The object returned by ``Image.open`` for the response bytes, or
        ``None`` when ``query`` yields nothing (``None`` or empty bytes).
    """
    raw = query({"inputs": prompt})
    if not raw:
        return None

    buffer = io.BytesIO(raw)
    return Image.open(buffer)
|
|
|
def main():
    """Render the page: a title, a prompt box, and a button that generates an image."""
    st.title("Stable Diffusion XL 1.0")

    prompt = st.text_input("Enter a prompt for image generation:")

    # Nothing further to do until the button is pressed.
    if not st.button("Generate Image"):
        return

    if not prompt:
        st.warning("Please enter a prompt.")
        return

    image = generate_image(prompt)
    if image:
        st.image(image, caption="Generated Image")
|
|
|
|
|
|
|
|
|
|
|
# Build the UI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()