diff --git a/.gitattributes b/.gitattributes
index cad82a7d5c38d2c1723a78dc6ac0ea68ecc7bf88..a6344aac8c09253b3b630fb776ae94478aa0275b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,5 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
-*.json filter=lfs diff=lfs merge=lfs -text
-*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7e99e367f8443d86e5e8825b9fda39dfbb39630d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*.pyc
\ No newline at end of file
diff --git a/README.md b/README.md
index f905b22e8fbc1f5bb61bdbfcc9c5ca997c0f2a0a..d50c6a7d800d200187b6e0256028b4860b2dbb41 100644
--- a/README.md
+++ b/README.md
@@ -5,8 +5,9 @@ colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: "4.24.0"
-app_file: gradio_demo/app.py
+app_file: app.py
pinned: false
+short_description: Virtual Try-on
---
# Virtual Try-On Demo
@@ -18,5 +19,5 @@ This application is a Virtual Try-On model demonstration powered by Gradio. It a
To start the app locally, use the following command:
```bash
-python gradio_demo/app.py
+python app.py
```
diff --git a/gradio_demo/app.py b/app.py
similarity index 96%
rename from gradio_demo/app.py
rename to app.py
index a25ebd10fc3b9ba636884a982d5e46856b089e39..186494d37684b0f381892c11f2e1e2ae10aeee91 100644
--- a/gradio_demo/app.py
+++ b/app.py
@@ -1,7 +1,6 @@
-import sys
-sys.path.append('./')
-from PIL import Image
import gradio as gr
+import spaces
+from PIL import Image
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
@@ -26,7 +25,6 @@ from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
from torchvision.transforms.functional import to_pil_image
-device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
def pil_to_binary_mask(pil_image, threshold=0):
np_image = np.array(pil_image)
@@ -123,7 +121,9 @@ pipe = TryonPipeline.from_pretrained(
)
pipe.unet_encoder = UNet_Encoder
+@spaces.GPU
def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
+ device = "cuda"
openpose_model.preprocessor.body_estimation.model.to(device)
pipe.to(device)
@@ -258,10 +258,10 @@ for ex_human in human_list_path:
##default human
-image_blocks = gr.Blocks().queue()
+image_blocks = gr.Blocks(theme="Nymbo/Alyx_Theme").queue()
with image_blocks as demo:
- gr.Markdown("## IDM-VTON 👕👔👚")
- gr.Markdown("Virtual Try-on with your image and garment image. Check out the [source codes](https://github.com/yisol/IDM-VTON) and the [model](https://huggingface.co/yisol/IDM-VTON)")
+ gr.HTML("
Virtual Try-On
")
+ gr.HTML("Upload an image of a person and an image of a garment ✨
")
with gr.Row():
with gr.Column():
imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
@@ -309,5 +309,4 @@ with image_blocks as demo:
-image_blocks.launch(share=True)
-
+image_blocks.launch()
\ No newline at end of file
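
The app.py hunks above move device selection out of module scope and defer it to the request handler decorated with `@spaces.GPU`, which is the ZeroGPU pattern on Hugging Face Spaces: build models at import time on CPU, touch CUDA only while a decorated call is running. A minimal sketch of that pattern, assuming the `spaces` package and a ZeroGPU runtime; the `infer` function and the tiny stand-in model are hypothetical and not part of this repo:

```python
import gradio as gr
import spaces          # Hugging Face ZeroGPU helper, available on Spaces
import torch

# Heavy objects are built at import time on CPU, as app.py does with the try-on pipeline.
model = torch.nn.Linear(4, 1)  # hypothetical stand-in for the real pipeline

@spaces.GPU            # a GPU is attached only for the duration of each call
def infer(x: float) -> float:
    device = "cuda"    # mirrors app.py: the device is chosen inside the decorated function
    model.to(device)
    with torch.no_grad():
        out = model(torch.full((4,), float(x), device=device))
    return out.item()

gr.Interface(fn=infer, inputs="number", outputs="number").launch()
```
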
diff --git a/gradio_demo/apply_net.py b/apply_net.py
old mode 100644
new mode 100755
similarity index 100%
rename from gradio_demo/apply_net.py
rename to apply_net.py
diff --git a/assets/teaser.png b/assets/teaser.png
deleted file mode 100644
index 931bc3b52dd73cc4763756c59393e5850d1bb97e..0000000000000000000000000000000000000000
--- a/assets/teaser.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e0ff5c96023ddf67864dc49acde2fab6a0c982fd77aa4979d9a2e77f45ad0b82
-size 7055496
diff --git a/assets/teaser2.png b/assets/teaser2.png
deleted file mode 100644
index ad35f385e1cbeae96910d6c9d5ecb23084303302..0000000000000000000000000000000000000000
--- a/assets/teaser2.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4a2c3522cb7805407f437f1639418166477f334cbef739e06947b5dfc68a1968
-size 9020741
diff --git a/ckpt/humanparsing/parsing_atr.onnx b/ckpt/humanparsing/parsing_atr.onnx
index 3aa341342d9ccd412b6fff79f9b0315d7dc6c3c0..28883cf4b0069c96f0f00930798428017425c3fa 100644
--- a/ckpt/humanparsing/parsing_atr.onnx
+++ b/ckpt/humanparsing/parsing_atr.onnx
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:3a248ff77aea7799b1a5ad036837e0b50c59b5fb95a2970e28d86b0a31b0cc73
-size 25
+oid sha256:04c7d1d070d0e0ae943d86b18cb5aaaea9e278d97462e9cfb270cbbe4cd977f4
+size 266859305
diff --git a/ckpt/humanparsing/parsing_lip.onnx b/ckpt/humanparsing/parsing_lip.onnx
index 3decb0d867f96439ca0fda704626848b9ab667b3..7d1a879fa30fc002188b0c9fec3cc05064dd1093 100644
--- a/ckpt/humanparsing/parsing_lip.onnx
+++ b/ckpt/humanparsing/parsing_lip.onnx
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:67f9c02ab458afa087bc02e18a9df5e479aebe48e291f46e9de6452cc3b97d37
-size 25
+oid sha256:8436e1dae96e2601c373d1ace29c8f0978b16357d9038c17a8ba756cca376dbc
+size 266863411
diff --git a/ckpt/image_encoder/config.json b/ckpt/image_encoder/config.json
deleted file mode 100644
index a6b2b6819e723b53e39363134706ffdca74f5179..0000000000000000000000000000000000000000
--- a/ckpt/image_encoder/config.json
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:625d37b31afbf2f0792a87846b3654ee23f20568409e35b78a1f795b04e1a7a1
-size 560
diff --git a/ckpt/image_encoder/model.safetensors b/ckpt/image_encoder/model.safetensors
deleted file mode 100644
index 7b6f224d25646c23132235975f42511ac2920549..0000000000000000000000000000000000000000
--- a/ckpt/image_encoder/model.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ff4265338f828df6c74782f6e1614a5c01d204e7f8bfd46eafa013f3de151d0e
-size 27
diff --git a/ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin b/ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin
deleted file mode 100644
index 45555d22ef8f4eb9ecef7e461142ba39a11151a8..0000000000000000000000000000000000000000
--- a/ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e8a955fca80fb4ba5718156d2e2619c9fbcd323d7748f44c7d4ce5503763baa0
-size 24
diff --git a/ckpt/openpose/ckpts/body_pose_model.pth b/ckpt/openpose/ckpts/body_pose_model.pth
index 774e6e843d5f7e3ff84eaa3f7d17b048860ddf55..9acb77e68f31906a8875f1daef2f3f7ef94acb1e 100644
--- a/ckpt/openpose/ckpts/body_pose_model.pth
+++ b/ckpt/openpose/ckpts/body_pose_model.pth
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:1014e5f51f5e838d02528a05837437f0f7424a9bfb90c2d6a55f0c73db2c0e0f
-size 28
+oid sha256:25a948c16078b0f08e236bda51a385d855ef4c153598947c28c0d47ed94bb746
+size 209267595
diff --git a/gradio_demo/densepose/__init__.py b/densepose/__init__.py
similarity index 100%
rename from gradio_demo/densepose/__init__.py
rename to densepose/__init__.py
diff --git a/gradio_demo/densepose/config.py b/densepose/config.py
similarity index 100%
rename from gradio_demo/densepose/config.py
rename to densepose/config.py
diff --git a/gradio_demo/densepose/converters/__init__.py b/densepose/converters/__init__.py
similarity index 100%
rename from gradio_demo/densepose/converters/__init__.py
rename to densepose/converters/__init__.py
diff --git a/gradio_demo/densepose/converters/base.py b/densepose/converters/base.py
similarity index 100%
rename from gradio_demo/densepose/converters/base.py
rename to densepose/converters/base.py
diff --git a/gradio_demo/densepose/converters/builtin.py b/densepose/converters/builtin.py
similarity index 100%
rename from gradio_demo/densepose/converters/builtin.py
rename to densepose/converters/builtin.py
diff --git a/gradio_demo/densepose/converters/chart_output_hflip.py b/densepose/converters/chart_output_hflip.py
similarity index 100%
rename from gradio_demo/densepose/converters/chart_output_hflip.py
rename to densepose/converters/chart_output_hflip.py
diff --git a/gradio_demo/densepose/converters/chart_output_to_chart_result.py b/densepose/converters/chart_output_to_chart_result.py
similarity index 100%
rename from gradio_demo/densepose/converters/chart_output_to_chart_result.py
rename to densepose/converters/chart_output_to_chart_result.py
diff --git a/gradio_demo/densepose/converters/hflip.py b/densepose/converters/hflip.py
similarity index 100%
rename from gradio_demo/densepose/converters/hflip.py
rename to densepose/converters/hflip.py
diff --git a/gradio_demo/densepose/converters/segm_to_mask.py b/densepose/converters/segm_to_mask.py
similarity index 100%
rename from gradio_demo/densepose/converters/segm_to_mask.py
rename to densepose/converters/segm_to_mask.py
diff --git a/gradio_demo/densepose/converters/to_chart_result.py b/densepose/converters/to_chart_result.py
similarity index 100%
rename from gradio_demo/densepose/converters/to_chart_result.py
rename to densepose/converters/to_chart_result.py
diff --git a/gradio_demo/densepose/converters/to_mask.py b/densepose/converters/to_mask.py
similarity index 100%
rename from gradio_demo/densepose/converters/to_mask.py
rename to densepose/converters/to_mask.py
diff --git a/gradio_demo/densepose/data/__init__.py b/densepose/data/__init__.py
similarity index 100%
rename from gradio_demo/densepose/data/__init__.py
rename to densepose/data/__init__.py
diff --git a/gradio_demo/densepose/data/build.py b/densepose/data/build.py
similarity index 100%
rename from gradio_demo/densepose/data/build.py
rename to densepose/data/build.py
diff --git a/gradio_demo/densepose/data/combined_loader.py b/densepose/data/combined_loader.py
similarity index 100%
rename from gradio_demo/densepose/data/combined_loader.py
rename to densepose/data/combined_loader.py
diff --git a/gradio_demo/densepose/data/dataset_mapper.py b/densepose/data/dataset_mapper.py
similarity index 100%
rename from gradio_demo/densepose/data/dataset_mapper.py
rename to densepose/data/dataset_mapper.py
diff --git a/gradio_demo/densepose/data/datasets/__init__.py b/densepose/data/datasets/__init__.py
similarity index 100%
rename from gradio_demo/densepose/data/datasets/__init__.py
rename to densepose/data/datasets/__init__.py
diff --git a/gradio_demo/densepose/data/datasets/builtin.py b/densepose/data/datasets/builtin.py
similarity index 100%
rename from gradio_demo/densepose/data/datasets/builtin.py
rename to densepose/data/datasets/builtin.py
diff --git a/gradio_demo/densepose/data/datasets/chimpnsee.py b/densepose/data/datasets/chimpnsee.py
similarity index 100%
rename from gradio_demo/densepose/data/datasets/chimpnsee.py
rename to densepose/data/datasets/chimpnsee.py
diff --git a/gradio_demo/densepose/data/datasets/coco.py b/densepose/data/datasets/coco.py
similarity index 100%
rename from gradio_demo/densepose/data/datasets/coco.py
rename to densepose/data/datasets/coco.py
diff --git a/gradio_demo/densepose/data/datasets/dataset_type.py b/densepose/data/datasets/dataset_type.py
similarity index 100%
rename from gradio_demo/densepose/data/datasets/dataset_type.py
rename to densepose/data/datasets/dataset_type.py
diff --git a/gradio_demo/densepose/data/datasets/lvis.py b/densepose/data/datasets/lvis.py
similarity index 100%
rename from gradio_demo/densepose/data/datasets/lvis.py
rename to densepose/data/datasets/lvis.py
diff --git a/gradio_demo/densepose/data/image_list_dataset.py b/densepose/data/image_list_dataset.py
similarity index 100%
rename from gradio_demo/densepose/data/image_list_dataset.py
rename to densepose/data/image_list_dataset.py
diff --git a/gradio_demo/densepose/data/inference_based_loader.py b/densepose/data/inference_based_loader.py
similarity index 100%
rename from gradio_demo/densepose/data/inference_based_loader.py
rename to densepose/data/inference_based_loader.py
diff --git a/gradio_demo/densepose/data/meshes/__init__.py b/densepose/data/meshes/__init__.py
similarity index 100%
rename from gradio_demo/densepose/data/meshes/__init__.py
rename to densepose/data/meshes/__init__.py
diff --git a/gradio_demo/densepose/data/meshes/builtin.py b/densepose/data/meshes/builtin.py
similarity index 100%
rename from gradio_demo/densepose/data/meshes/builtin.py
rename to densepose/data/meshes/builtin.py
diff --git a/gradio_demo/densepose/data/meshes/catalog.py b/densepose/data/meshes/catalog.py
similarity index 100%
rename from gradio_demo/densepose/data/meshes/catalog.py
rename to densepose/data/meshes/catalog.py
diff --git a/gradio_demo/densepose/data/samplers/__init__.py b/densepose/data/samplers/__init__.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/__init__.py
rename to densepose/data/samplers/__init__.py
diff --git a/gradio_demo/densepose/data/samplers/densepose_base.py b/densepose/data/samplers/densepose_base.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/densepose_base.py
rename to densepose/data/samplers/densepose_base.py
diff --git a/gradio_demo/densepose/data/samplers/densepose_confidence_based.py b/densepose/data/samplers/densepose_confidence_based.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/densepose_confidence_based.py
rename to densepose/data/samplers/densepose_confidence_based.py
diff --git a/gradio_demo/densepose/data/samplers/densepose_cse_base.py b/densepose/data/samplers/densepose_cse_base.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/densepose_cse_base.py
rename to densepose/data/samplers/densepose_cse_base.py
diff --git a/gradio_demo/densepose/data/samplers/densepose_cse_confidence_based.py b/densepose/data/samplers/densepose_cse_confidence_based.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/densepose_cse_confidence_based.py
rename to densepose/data/samplers/densepose_cse_confidence_based.py
diff --git a/gradio_demo/densepose/data/samplers/densepose_cse_uniform.py b/densepose/data/samplers/densepose_cse_uniform.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/densepose_cse_uniform.py
rename to densepose/data/samplers/densepose_cse_uniform.py
diff --git a/gradio_demo/densepose/data/samplers/densepose_uniform.py b/densepose/data/samplers/densepose_uniform.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/densepose_uniform.py
rename to densepose/data/samplers/densepose_uniform.py
diff --git a/gradio_demo/densepose/data/samplers/mask_from_densepose.py b/densepose/data/samplers/mask_from_densepose.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/mask_from_densepose.py
rename to densepose/data/samplers/mask_from_densepose.py
diff --git a/gradio_demo/densepose/data/samplers/prediction_to_gt.py b/densepose/data/samplers/prediction_to_gt.py
similarity index 100%
rename from gradio_demo/densepose/data/samplers/prediction_to_gt.py
rename to densepose/data/samplers/prediction_to_gt.py
diff --git a/gradio_demo/densepose/data/transform/__init__.py b/densepose/data/transform/__init__.py
similarity index 100%
rename from gradio_demo/densepose/data/transform/__init__.py
rename to densepose/data/transform/__init__.py
diff --git a/gradio_demo/densepose/data/transform/image.py b/densepose/data/transform/image.py
similarity index 100%
rename from gradio_demo/densepose/data/transform/image.py
rename to densepose/data/transform/image.py
diff --git a/gradio_demo/densepose/data/utils.py b/densepose/data/utils.py
similarity index 100%
rename from gradio_demo/densepose/data/utils.py
rename to densepose/data/utils.py
diff --git a/gradio_demo/densepose/data/video/__init__.py b/densepose/data/video/__init__.py
similarity index 100%
rename from gradio_demo/densepose/data/video/__init__.py
rename to densepose/data/video/__init__.py
diff --git a/gradio_demo/densepose/data/video/frame_selector.py b/densepose/data/video/frame_selector.py
similarity index 100%
rename from gradio_demo/densepose/data/video/frame_selector.py
rename to densepose/data/video/frame_selector.py
diff --git a/gradio_demo/densepose/data/video/video_keyframe_dataset.py b/densepose/data/video/video_keyframe_dataset.py
similarity index 100%
rename from gradio_demo/densepose/data/video/video_keyframe_dataset.py
rename to densepose/data/video/video_keyframe_dataset.py
diff --git a/gradio_demo/densepose/engine/__init__.py b/densepose/engine/__init__.py
similarity index 100%
rename from gradio_demo/densepose/engine/__init__.py
rename to densepose/engine/__init__.py
diff --git a/gradio_demo/densepose/engine/trainer.py b/densepose/engine/trainer.py
similarity index 100%
rename from gradio_demo/densepose/engine/trainer.py
rename to densepose/engine/trainer.py
diff --git a/gradio_demo/densepose/evaluation/__init__.py b/densepose/evaluation/__init__.py
similarity index 100%
rename from gradio_demo/densepose/evaluation/__init__.py
rename to densepose/evaluation/__init__.py
diff --git a/gradio_demo/densepose/evaluation/d2_evaluator_adapter.py b/densepose/evaluation/d2_evaluator_adapter.py
similarity index 100%
rename from gradio_demo/densepose/evaluation/d2_evaluator_adapter.py
rename to densepose/evaluation/d2_evaluator_adapter.py
diff --git a/gradio_demo/densepose/evaluation/densepose_coco_evaluation.py b/densepose/evaluation/densepose_coco_evaluation.py
similarity index 100%
rename from gradio_demo/densepose/evaluation/densepose_coco_evaluation.py
rename to densepose/evaluation/densepose_coco_evaluation.py
diff --git a/gradio_demo/densepose/evaluation/evaluator.py b/densepose/evaluation/evaluator.py
similarity index 100%
rename from gradio_demo/densepose/evaluation/evaluator.py
rename to densepose/evaluation/evaluator.py
diff --git a/gradio_demo/densepose/evaluation/mesh_alignment_evaluator.py b/densepose/evaluation/mesh_alignment_evaluator.py
similarity index 100%
rename from gradio_demo/densepose/evaluation/mesh_alignment_evaluator.py
rename to densepose/evaluation/mesh_alignment_evaluator.py
diff --git a/gradio_demo/densepose/evaluation/tensor_storage.py b/densepose/evaluation/tensor_storage.py
similarity index 100%
rename from gradio_demo/densepose/evaluation/tensor_storage.py
rename to densepose/evaluation/tensor_storage.py
diff --git a/gradio_demo/densepose/modeling/__init__.py b/densepose/modeling/__init__.py
similarity index 100%
rename from gradio_demo/densepose/modeling/__init__.py
rename to densepose/modeling/__init__.py
diff --git a/gradio_demo/densepose/modeling/build.py b/densepose/modeling/build.py
similarity index 100%
rename from gradio_demo/densepose/modeling/build.py
rename to densepose/modeling/build.py
diff --git a/gradio_demo/densepose/modeling/confidence.py b/densepose/modeling/confidence.py
similarity index 100%
rename from gradio_demo/densepose/modeling/confidence.py
rename to densepose/modeling/confidence.py
diff --git a/gradio_demo/densepose/modeling/cse/__init__.py b/densepose/modeling/cse/__init__.py
similarity index 100%
rename from gradio_demo/densepose/modeling/cse/__init__.py
rename to densepose/modeling/cse/__init__.py
diff --git a/gradio_demo/densepose/modeling/cse/embedder.py b/densepose/modeling/cse/embedder.py
similarity index 100%
rename from gradio_demo/densepose/modeling/cse/embedder.py
rename to densepose/modeling/cse/embedder.py
diff --git a/gradio_demo/densepose/modeling/cse/utils.py b/densepose/modeling/cse/utils.py
similarity index 100%
rename from gradio_demo/densepose/modeling/cse/utils.py
rename to densepose/modeling/cse/utils.py
diff --git a/gradio_demo/densepose/modeling/cse/vertex_direct_embedder.py b/densepose/modeling/cse/vertex_direct_embedder.py
similarity index 100%
rename from gradio_demo/densepose/modeling/cse/vertex_direct_embedder.py
rename to densepose/modeling/cse/vertex_direct_embedder.py
diff --git a/gradio_demo/densepose/modeling/cse/vertex_feature_embedder.py b/densepose/modeling/cse/vertex_feature_embedder.py
similarity index 100%
rename from gradio_demo/densepose/modeling/cse/vertex_feature_embedder.py
rename to densepose/modeling/cse/vertex_feature_embedder.py
diff --git a/gradio_demo/densepose/modeling/densepose_checkpoint.py b/densepose/modeling/densepose_checkpoint.py
similarity index 100%
rename from gradio_demo/densepose/modeling/densepose_checkpoint.py
rename to densepose/modeling/densepose_checkpoint.py
diff --git a/gradio_demo/densepose/modeling/filter.py b/densepose/modeling/filter.py
similarity index 100%
rename from gradio_demo/densepose/modeling/filter.py
rename to densepose/modeling/filter.py
diff --git a/gradio_demo/densepose/modeling/hrfpn.py b/densepose/modeling/hrfpn.py
similarity index 100%
rename from gradio_demo/densepose/modeling/hrfpn.py
rename to densepose/modeling/hrfpn.py
diff --git a/gradio_demo/densepose/modeling/hrnet.py b/densepose/modeling/hrnet.py
similarity index 100%
rename from gradio_demo/densepose/modeling/hrnet.py
rename to densepose/modeling/hrnet.py
diff --git a/gradio_demo/densepose/modeling/inference.py b/densepose/modeling/inference.py
similarity index 100%
rename from gradio_demo/densepose/modeling/inference.py
rename to densepose/modeling/inference.py
diff --git a/gradio_demo/densepose/modeling/losses/__init__.py b/densepose/modeling/losses/__init__.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/__init__.py
rename to densepose/modeling/losses/__init__.py
diff --git a/gradio_demo/densepose/modeling/losses/chart.py b/densepose/modeling/losses/chart.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/chart.py
rename to densepose/modeling/losses/chart.py
diff --git a/gradio_demo/densepose/modeling/losses/chart_with_confidences.py b/densepose/modeling/losses/chart_with_confidences.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/chart_with_confidences.py
rename to densepose/modeling/losses/chart_with_confidences.py
diff --git a/gradio_demo/densepose/modeling/losses/cse.py b/densepose/modeling/losses/cse.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/cse.py
rename to densepose/modeling/losses/cse.py
diff --git a/gradio_demo/densepose/modeling/losses/cycle_pix2shape.py b/densepose/modeling/losses/cycle_pix2shape.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/cycle_pix2shape.py
rename to densepose/modeling/losses/cycle_pix2shape.py
diff --git a/gradio_demo/densepose/modeling/losses/cycle_shape2shape.py b/densepose/modeling/losses/cycle_shape2shape.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/cycle_shape2shape.py
rename to densepose/modeling/losses/cycle_shape2shape.py
diff --git a/gradio_demo/densepose/modeling/losses/embed.py b/densepose/modeling/losses/embed.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/embed.py
rename to densepose/modeling/losses/embed.py
diff --git a/gradio_demo/densepose/modeling/losses/embed_utils.py b/densepose/modeling/losses/embed_utils.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/embed_utils.py
rename to densepose/modeling/losses/embed_utils.py
diff --git a/gradio_demo/densepose/modeling/losses/mask.py b/densepose/modeling/losses/mask.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/mask.py
rename to densepose/modeling/losses/mask.py
diff --git a/gradio_demo/densepose/modeling/losses/mask_or_segm.py b/densepose/modeling/losses/mask_or_segm.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/mask_or_segm.py
rename to densepose/modeling/losses/mask_or_segm.py
diff --git a/gradio_demo/densepose/modeling/losses/registry.py b/densepose/modeling/losses/registry.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/registry.py
rename to densepose/modeling/losses/registry.py
diff --git a/gradio_demo/densepose/modeling/losses/segm.py b/densepose/modeling/losses/segm.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/segm.py
rename to densepose/modeling/losses/segm.py
diff --git a/gradio_demo/densepose/modeling/losses/soft_embed.py b/densepose/modeling/losses/soft_embed.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/soft_embed.py
rename to densepose/modeling/losses/soft_embed.py
diff --git a/gradio_demo/densepose/modeling/losses/utils.py b/densepose/modeling/losses/utils.py
similarity index 100%
rename from gradio_demo/densepose/modeling/losses/utils.py
rename to densepose/modeling/losses/utils.py
diff --git a/gradio_demo/densepose/modeling/predictors/__init__.py b/densepose/modeling/predictors/__init__.py
similarity index 100%
rename from gradio_demo/densepose/modeling/predictors/__init__.py
rename to densepose/modeling/predictors/__init__.py
diff --git a/gradio_demo/densepose/modeling/predictors/chart.py b/densepose/modeling/predictors/chart.py
similarity index 100%
rename from gradio_demo/densepose/modeling/predictors/chart.py
rename to densepose/modeling/predictors/chart.py
diff --git a/gradio_demo/densepose/modeling/predictors/chart_confidence.py b/densepose/modeling/predictors/chart_confidence.py
similarity index 100%
rename from gradio_demo/densepose/modeling/predictors/chart_confidence.py
rename to densepose/modeling/predictors/chart_confidence.py
diff --git a/gradio_demo/densepose/modeling/predictors/chart_with_confidence.py b/densepose/modeling/predictors/chart_with_confidence.py
similarity index 100%
rename from gradio_demo/densepose/modeling/predictors/chart_with_confidence.py
rename to densepose/modeling/predictors/chart_with_confidence.py
diff --git a/gradio_demo/densepose/modeling/predictors/cse.py b/densepose/modeling/predictors/cse.py
similarity index 100%
rename from gradio_demo/densepose/modeling/predictors/cse.py
rename to densepose/modeling/predictors/cse.py
diff --git a/gradio_demo/densepose/modeling/predictors/cse_confidence.py b/densepose/modeling/predictors/cse_confidence.py
similarity index 100%
rename from gradio_demo/densepose/modeling/predictors/cse_confidence.py
rename to densepose/modeling/predictors/cse_confidence.py
diff --git a/gradio_demo/densepose/modeling/predictors/cse_with_confidence.py b/densepose/modeling/predictors/cse_with_confidence.py
similarity index 100%
rename from gradio_demo/densepose/modeling/predictors/cse_with_confidence.py
rename to densepose/modeling/predictors/cse_with_confidence.py
diff --git a/gradio_demo/densepose/modeling/predictors/registry.py b/densepose/modeling/predictors/registry.py
similarity index 100%
rename from gradio_demo/densepose/modeling/predictors/registry.py
rename to densepose/modeling/predictors/registry.py
diff --git a/gradio_demo/densepose/modeling/roi_heads/__init__.py b/densepose/modeling/roi_heads/__init__.py
similarity index 100%
rename from gradio_demo/densepose/modeling/roi_heads/__init__.py
rename to densepose/modeling/roi_heads/__init__.py
diff --git a/gradio_demo/densepose/modeling/roi_heads/deeplab.py b/densepose/modeling/roi_heads/deeplab.py
similarity index 100%
rename from gradio_demo/densepose/modeling/roi_heads/deeplab.py
rename to densepose/modeling/roi_heads/deeplab.py
diff --git a/gradio_demo/densepose/modeling/roi_heads/registry.py b/densepose/modeling/roi_heads/registry.py
similarity index 100%
rename from gradio_demo/densepose/modeling/roi_heads/registry.py
rename to densepose/modeling/roi_heads/registry.py
diff --git a/gradio_demo/densepose/modeling/roi_heads/roi_head.py b/densepose/modeling/roi_heads/roi_head.py
similarity index 100%
rename from gradio_demo/densepose/modeling/roi_heads/roi_head.py
rename to densepose/modeling/roi_heads/roi_head.py
diff --git a/gradio_demo/densepose/modeling/roi_heads/v1convx.py b/densepose/modeling/roi_heads/v1convx.py
similarity index 100%
rename from gradio_demo/densepose/modeling/roi_heads/v1convx.py
rename to densepose/modeling/roi_heads/v1convx.py
diff --git a/gradio_demo/densepose/modeling/test_time_augmentation.py b/densepose/modeling/test_time_augmentation.py
similarity index 100%
rename from gradio_demo/densepose/modeling/test_time_augmentation.py
rename to densepose/modeling/test_time_augmentation.py
diff --git a/gradio_demo/densepose/modeling/utils.py b/densepose/modeling/utils.py
similarity index 100%
rename from gradio_demo/densepose/modeling/utils.py
rename to densepose/modeling/utils.py
diff --git a/gradio_demo/densepose/structures/__init__.py b/densepose/structures/__init__.py
similarity index 100%
rename from gradio_demo/densepose/structures/__init__.py
rename to densepose/structures/__init__.py
diff --git a/gradio_demo/densepose/structures/chart.py b/densepose/structures/chart.py
similarity index 100%
rename from gradio_demo/densepose/structures/chart.py
rename to densepose/structures/chart.py
diff --git a/gradio_demo/densepose/structures/chart_confidence.py b/densepose/structures/chart_confidence.py
similarity index 100%
rename from gradio_demo/densepose/structures/chart_confidence.py
rename to densepose/structures/chart_confidence.py
diff --git a/gradio_demo/densepose/structures/chart_result.py b/densepose/structures/chart_result.py
similarity index 100%
rename from gradio_demo/densepose/structures/chart_result.py
rename to densepose/structures/chart_result.py
diff --git a/gradio_demo/densepose/structures/cse.py b/densepose/structures/cse.py
similarity index 100%
rename from gradio_demo/densepose/structures/cse.py
rename to densepose/structures/cse.py
diff --git a/gradio_demo/densepose/structures/cse_confidence.py b/densepose/structures/cse_confidence.py
similarity index 100%
rename from gradio_demo/densepose/structures/cse_confidence.py
rename to densepose/structures/cse_confidence.py
diff --git a/gradio_demo/densepose/structures/data_relative.py b/densepose/structures/data_relative.py
similarity index 100%
rename from gradio_demo/densepose/structures/data_relative.py
rename to densepose/structures/data_relative.py
diff --git a/gradio_demo/densepose/structures/list.py b/densepose/structures/list.py
similarity index 100%
rename from gradio_demo/densepose/structures/list.py
rename to densepose/structures/list.py
diff --git a/gradio_demo/densepose/structures/mesh.py b/densepose/structures/mesh.py
similarity index 100%
rename from gradio_demo/densepose/structures/mesh.py
rename to densepose/structures/mesh.py
diff --git a/gradio_demo/densepose/structures/transform_data.py b/densepose/structures/transform_data.py
similarity index 100%
rename from gradio_demo/densepose/structures/transform_data.py
rename to densepose/structures/transform_data.py
diff --git a/gradio_demo/densepose/utils/__init__.py b/densepose/utils/__init__.py
similarity index 100%
rename from gradio_demo/densepose/utils/__init__.py
rename to densepose/utils/__init__.py
diff --git a/gradio_demo/densepose/utils/dbhelper.py b/densepose/utils/dbhelper.py
similarity index 100%
rename from gradio_demo/densepose/utils/dbhelper.py
rename to densepose/utils/dbhelper.py
diff --git a/gradio_demo/densepose/utils/logger.py b/densepose/utils/logger.py
similarity index 100%
rename from gradio_demo/densepose/utils/logger.py
rename to densepose/utils/logger.py
diff --git a/gradio_demo/densepose/utils/transform.py b/densepose/utils/transform.py
similarity index 100%
rename from gradio_demo/densepose/utils/transform.py
rename to densepose/utils/transform.py
diff --git a/gradio_demo/densepose/vis/__init__.py b/densepose/vis/__init__.py
similarity index 100%
rename from gradio_demo/densepose/vis/__init__.py
rename to densepose/vis/__init__.py
diff --git a/gradio_demo/densepose/vis/base.py b/densepose/vis/base.py
similarity index 100%
rename from gradio_demo/densepose/vis/base.py
rename to densepose/vis/base.py
diff --git a/gradio_demo/densepose/vis/bounding_box.py b/densepose/vis/bounding_box.py
similarity index 100%
rename from gradio_demo/densepose/vis/bounding_box.py
rename to densepose/vis/bounding_box.py
diff --git a/gradio_demo/densepose/vis/densepose_data_points.py b/densepose/vis/densepose_data_points.py
similarity index 100%
rename from gradio_demo/densepose/vis/densepose_data_points.py
rename to densepose/vis/densepose_data_points.py
diff --git a/gradio_demo/densepose/vis/densepose_outputs_iuv.py b/densepose/vis/densepose_outputs_iuv.py
similarity index 100%
rename from gradio_demo/densepose/vis/densepose_outputs_iuv.py
rename to densepose/vis/densepose_outputs_iuv.py
diff --git a/gradio_demo/densepose/vis/densepose_outputs_vertex.py b/densepose/vis/densepose_outputs_vertex.py
similarity index 100%
rename from gradio_demo/densepose/vis/densepose_outputs_vertex.py
rename to densepose/vis/densepose_outputs_vertex.py
diff --git a/gradio_demo/densepose/vis/densepose_results.py b/densepose/vis/densepose_results.py
similarity index 100%
rename from gradio_demo/densepose/vis/densepose_results.py
rename to densepose/vis/densepose_results.py
diff --git a/gradio_demo/densepose/vis/densepose_results_textures.py b/densepose/vis/densepose_results_textures.py
similarity index 100%
rename from gradio_demo/densepose/vis/densepose_results_textures.py
rename to densepose/vis/densepose_results_textures.py
diff --git a/gradio_demo/densepose/vis/extractor.py b/densepose/vis/extractor.py
similarity index 100%
rename from gradio_demo/densepose/vis/extractor.py
rename to densepose/vis/extractor.py
diff --git a/gradio_demo/detectron2/_C.cpython-39-x86_64-linux-gnu.so b/detectron2/_C.cpython-39-x86_64-linux-gnu.so
old mode 100644
new mode 100755
similarity index 100%
rename from gradio_demo/detectron2/_C.cpython-39-x86_64-linux-gnu.so
rename to detectron2/_C.cpython-39-x86_64-linux-gnu.so
diff --git a/gradio_demo/detectron2/__init__.py b/detectron2/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/__init__.py
rename to detectron2/__init__.py
diff --git a/gradio_demo/detectron2/checkpoint/__init__.py b/detectron2/checkpoint/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/checkpoint/__init__.py
rename to detectron2/checkpoint/__init__.py
diff --git a/gradio_demo/detectron2/checkpoint/c2_model_loading.py b/detectron2/checkpoint/c2_model_loading.py
similarity index 100%
rename from gradio_demo/detectron2/checkpoint/c2_model_loading.py
rename to detectron2/checkpoint/c2_model_loading.py
diff --git a/gradio_demo/detectron2/checkpoint/catalog.py b/detectron2/checkpoint/catalog.py
similarity index 100%
rename from gradio_demo/detectron2/checkpoint/catalog.py
rename to detectron2/checkpoint/catalog.py
diff --git a/gradio_demo/detectron2/checkpoint/detection_checkpoint.py b/detectron2/checkpoint/detection_checkpoint.py
similarity index 100%
rename from gradio_demo/detectron2/checkpoint/detection_checkpoint.py
rename to detectron2/checkpoint/detection_checkpoint.py
diff --git a/gradio_demo/detectron2/config/__init__.py b/detectron2/config/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/config/__init__.py
rename to detectron2/config/__init__.py
diff --git a/gradio_demo/detectron2/config/compat.py b/detectron2/config/compat.py
similarity index 100%
rename from gradio_demo/detectron2/config/compat.py
rename to detectron2/config/compat.py
diff --git a/gradio_demo/detectron2/config/config.py b/detectron2/config/config.py
similarity index 100%
rename from gradio_demo/detectron2/config/config.py
rename to detectron2/config/config.py
diff --git a/gradio_demo/detectron2/config/defaults.py b/detectron2/config/defaults.py
similarity index 100%
rename from gradio_demo/detectron2/config/defaults.py
rename to detectron2/config/defaults.py
diff --git a/gradio_demo/detectron2/config/instantiate.py b/detectron2/config/instantiate.py
similarity index 100%
rename from gradio_demo/detectron2/config/instantiate.py
rename to detectron2/config/instantiate.py
diff --git a/gradio_demo/detectron2/config/lazy.py b/detectron2/config/lazy.py
similarity index 100%
rename from gradio_demo/detectron2/config/lazy.py
rename to detectron2/config/lazy.py
diff --git a/gradio_demo/detectron2/data/__init__.py b/detectron2/data/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/data/__init__.py
rename to detectron2/data/__init__.py
diff --git a/gradio_demo/detectron2/data/benchmark.py b/detectron2/data/benchmark.py
similarity index 100%
rename from gradio_demo/detectron2/data/benchmark.py
rename to detectron2/data/benchmark.py
diff --git a/gradio_demo/detectron2/data/build.py b/detectron2/data/build.py
similarity index 100%
rename from gradio_demo/detectron2/data/build.py
rename to detectron2/data/build.py
diff --git a/gradio_demo/detectron2/data/catalog.py b/detectron2/data/catalog.py
similarity index 100%
rename from gradio_demo/detectron2/data/catalog.py
rename to detectron2/data/catalog.py
diff --git a/gradio_demo/detectron2/data/common.py b/detectron2/data/common.py
similarity index 100%
rename from gradio_demo/detectron2/data/common.py
rename to detectron2/data/common.py
diff --git a/gradio_demo/detectron2/data/dataset_mapper.py b/detectron2/data/dataset_mapper.py
similarity index 100%
rename from gradio_demo/detectron2/data/dataset_mapper.py
rename to detectron2/data/dataset_mapper.py
diff --git a/gradio_demo/detectron2/data/datasets/README.md b/detectron2/data/datasets/README.md
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/README.md
rename to detectron2/data/datasets/README.md
diff --git a/gradio_demo/detectron2/data/datasets/__init__.py b/detectron2/data/datasets/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/__init__.py
rename to detectron2/data/datasets/__init__.py
diff --git a/gradio_demo/detectron2/data/datasets/builtin.py b/detectron2/data/datasets/builtin.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/builtin.py
rename to detectron2/data/datasets/builtin.py
diff --git a/gradio_demo/detectron2/data/datasets/builtin_meta.py b/detectron2/data/datasets/builtin_meta.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/builtin_meta.py
rename to detectron2/data/datasets/builtin_meta.py
diff --git a/gradio_demo/detectron2/data/datasets/cityscapes.py b/detectron2/data/datasets/cityscapes.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/cityscapes.py
rename to detectron2/data/datasets/cityscapes.py
diff --git a/gradio_demo/detectron2/data/datasets/cityscapes_panoptic.py b/detectron2/data/datasets/cityscapes_panoptic.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/cityscapes_panoptic.py
rename to detectron2/data/datasets/cityscapes_panoptic.py
diff --git a/gradio_demo/detectron2/data/datasets/coco.py b/detectron2/data/datasets/coco.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/coco.py
rename to detectron2/data/datasets/coco.py
diff --git a/gradio_demo/detectron2/data/datasets/coco_panoptic.py b/detectron2/data/datasets/coco_panoptic.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/coco_panoptic.py
rename to detectron2/data/datasets/coco_panoptic.py
diff --git a/gradio_demo/detectron2/data/datasets/lvis.py b/detectron2/data/datasets/lvis.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/lvis.py
rename to detectron2/data/datasets/lvis.py
diff --git a/gradio_demo/detectron2/data/datasets/lvis_v0_5_categories.py b/detectron2/data/datasets/lvis_v0_5_categories.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/lvis_v0_5_categories.py
rename to detectron2/data/datasets/lvis_v0_5_categories.py
diff --git a/gradio_demo/detectron2/data/datasets/lvis_v1_categories.py b/detectron2/data/datasets/lvis_v1_categories.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/lvis_v1_categories.py
rename to detectron2/data/datasets/lvis_v1_categories.py
diff --git a/gradio_demo/detectron2/data/datasets/lvis_v1_category_image_count.py b/detectron2/data/datasets/lvis_v1_category_image_count.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/lvis_v1_category_image_count.py
rename to detectron2/data/datasets/lvis_v1_category_image_count.py
diff --git a/gradio_demo/detectron2/data/datasets/pascal_voc.py b/detectron2/data/datasets/pascal_voc.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/pascal_voc.py
rename to detectron2/data/datasets/pascal_voc.py
diff --git a/gradio_demo/detectron2/data/datasets/register_coco.py b/detectron2/data/datasets/register_coco.py
similarity index 100%
rename from gradio_demo/detectron2/data/datasets/register_coco.py
rename to detectron2/data/datasets/register_coco.py
diff --git a/gradio_demo/detectron2/data/detection_utils.py b/detectron2/data/detection_utils.py
similarity index 100%
rename from gradio_demo/detectron2/data/detection_utils.py
rename to detectron2/data/detection_utils.py
diff --git a/gradio_demo/detectron2/data/samplers/__init__.py b/detectron2/data/samplers/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/data/samplers/__init__.py
rename to detectron2/data/samplers/__init__.py
diff --git a/gradio_demo/detectron2/data/samplers/distributed_sampler.py b/detectron2/data/samplers/distributed_sampler.py
similarity index 100%
rename from gradio_demo/detectron2/data/samplers/distributed_sampler.py
rename to detectron2/data/samplers/distributed_sampler.py
diff --git a/gradio_demo/detectron2/data/samplers/grouped_batch_sampler.py b/detectron2/data/samplers/grouped_batch_sampler.py
similarity index 100%
rename from gradio_demo/detectron2/data/samplers/grouped_batch_sampler.py
rename to detectron2/data/samplers/grouped_batch_sampler.py
diff --git a/gradio_demo/detectron2/data/transforms/__init__.py b/detectron2/data/transforms/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/data/transforms/__init__.py
rename to detectron2/data/transforms/__init__.py
diff --git a/gradio_demo/detectron2/data/transforms/augmentation.py b/detectron2/data/transforms/augmentation.py
similarity index 100%
rename from gradio_demo/detectron2/data/transforms/augmentation.py
rename to detectron2/data/transforms/augmentation.py
diff --git a/gradio_demo/detectron2/data/transforms/augmentation_impl.py b/detectron2/data/transforms/augmentation_impl.py
similarity index 100%
rename from gradio_demo/detectron2/data/transforms/augmentation_impl.py
rename to detectron2/data/transforms/augmentation_impl.py
diff --git a/gradio_demo/detectron2/data/transforms/transform.py b/detectron2/data/transforms/transform.py
similarity index 100%
rename from gradio_demo/detectron2/data/transforms/transform.py
rename to detectron2/data/transforms/transform.py
diff --git a/gradio_demo/detectron2/engine/__init__.py b/detectron2/engine/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/engine/__init__.py
rename to detectron2/engine/__init__.py
diff --git a/gradio_demo/detectron2/engine/defaults.py b/detectron2/engine/defaults.py
similarity index 100%
rename from gradio_demo/detectron2/engine/defaults.py
rename to detectron2/engine/defaults.py
diff --git a/gradio_demo/detectron2/engine/hooks.py b/detectron2/engine/hooks.py
similarity index 100%
rename from gradio_demo/detectron2/engine/hooks.py
rename to detectron2/engine/hooks.py
diff --git a/gradio_demo/detectron2/engine/launch.py b/detectron2/engine/launch.py
similarity index 100%
rename from gradio_demo/detectron2/engine/launch.py
rename to detectron2/engine/launch.py
diff --git a/gradio_demo/detectron2/engine/train_loop.py b/detectron2/engine/train_loop.py
similarity index 100%
rename from gradio_demo/detectron2/engine/train_loop.py
rename to detectron2/engine/train_loop.py
diff --git a/gradio_demo/detectron2/evaluation/__init__.py b/detectron2/evaluation/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/__init__.py
rename to detectron2/evaluation/__init__.py
diff --git a/gradio_demo/detectron2/evaluation/cityscapes_evaluation.py b/detectron2/evaluation/cityscapes_evaluation.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/cityscapes_evaluation.py
rename to detectron2/evaluation/cityscapes_evaluation.py
diff --git a/gradio_demo/detectron2/evaluation/coco_evaluation.py b/detectron2/evaluation/coco_evaluation.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/coco_evaluation.py
rename to detectron2/evaluation/coco_evaluation.py
diff --git a/gradio_demo/detectron2/evaluation/evaluator.py b/detectron2/evaluation/evaluator.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/evaluator.py
rename to detectron2/evaluation/evaluator.py
diff --git a/gradio_demo/detectron2/evaluation/fast_eval_api.py b/detectron2/evaluation/fast_eval_api.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/fast_eval_api.py
rename to detectron2/evaluation/fast_eval_api.py
diff --git a/gradio_demo/detectron2/evaluation/lvis_evaluation.py b/detectron2/evaluation/lvis_evaluation.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/lvis_evaluation.py
rename to detectron2/evaluation/lvis_evaluation.py
diff --git a/gradio_demo/detectron2/evaluation/panoptic_evaluation.py b/detectron2/evaluation/panoptic_evaluation.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/panoptic_evaluation.py
rename to detectron2/evaluation/panoptic_evaluation.py
diff --git a/gradio_demo/detectron2/evaluation/pascal_voc_evaluation.py b/detectron2/evaluation/pascal_voc_evaluation.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/pascal_voc_evaluation.py
rename to detectron2/evaluation/pascal_voc_evaluation.py
diff --git a/gradio_demo/detectron2/evaluation/rotated_coco_evaluation.py b/detectron2/evaluation/rotated_coco_evaluation.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/rotated_coco_evaluation.py
rename to detectron2/evaluation/rotated_coco_evaluation.py
diff --git a/gradio_demo/detectron2/evaluation/sem_seg_evaluation.py b/detectron2/evaluation/sem_seg_evaluation.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/sem_seg_evaluation.py
rename to detectron2/evaluation/sem_seg_evaluation.py
diff --git a/gradio_demo/detectron2/evaluation/testing.py b/detectron2/evaluation/testing.py
similarity index 100%
rename from gradio_demo/detectron2/evaluation/testing.py
rename to detectron2/evaluation/testing.py
diff --git a/gradio_demo/detectron2/export/README.md b/detectron2/export/README.md
similarity index 100%
rename from gradio_demo/detectron2/export/README.md
rename to detectron2/export/README.md
diff --git a/gradio_demo/detectron2/export/__init__.py b/detectron2/export/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/export/__init__.py
rename to detectron2/export/__init__.py
diff --git a/gradio_demo/detectron2/export/api.py b/detectron2/export/api.py
similarity index 100%
rename from gradio_demo/detectron2/export/api.py
rename to detectron2/export/api.py
diff --git a/gradio_demo/detectron2/export/c10.py b/detectron2/export/c10.py
similarity index 100%
rename from gradio_demo/detectron2/export/c10.py
rename to detectron2/export/c10.py
diff --git a/gradio_demo/detectron2/export/caffe2_export.py b/detectron2/export/caffe2_export.py
similarity index 100%
rename from gradio_demo/detectron2/export/caffe2_export.py
rename to detectron2/export/caffe2_export.py
diff --git a/gradio_demo/detectron2/export/caffe2_inference.py b/detectron2/export/caffe2_inference.py
similarity index 100%
rename from gradio_demo/detectron2/export/caffe2_inference.py
rename to detectron2/export/caffe2_inference.py
diff --git a/gradio_demo/detectron2/export/caffe2_modeling.py b/detectron2/export/caffe2_modeling.py
similarity index 100%
rename from gradio_demo/detectron2/export/caffe2_modeling.py
rename to detectron2/export/caffe2_modeling.py
diff --git a/gradio_demo/detectron2/export/caffe2_patch.py b/detectron2/export/caffe2_patch.py
similarity index 100%
rename from gradio_demo/detectron2/export/caffe2_patch.py
rename to detectron2/export/caffe2_patch.py
diff --git a/gradio_demo/detectron2/export/flatten.py b/detectron2/export/flatten.py
similarity index 100%
rename from gradio_demo/detectron2/export/flatten.py
rename to detectron2/export/flatten.py
diff --git a/gradio_demo/detectron2/export/shared.py b/detectron2/export/shared.py
similarity index 100%
rename from gradio_demo/detectron2/export/shared.py
rename to detectron2/export/shared.py
diff --git a/gradio_demo/detectron2/export/torchscript.py b/detectron2/export/torchscript.py
similarity index 100%
rename from gradio_demo/detectron2/export/torchscript.py
rename to detectron2/export/torchscript.py
diff --git a/gradio_demo/detectron2/export/torchscript_patch.py b/detectron2/export/torchscript_patch.py
similarity index 100%
rename from gradio_demo/detectron2/export/torchscript_patch.py
rename to detectron2/export/torchscript_patch.py
diff --git a/gradio_demo/detectron2/layers/__init__.py b/detectron2/layers/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/layers/__init__.py
rename to detectron2/layers/__init__.py
diff --git a/gradio_demo/detectron2/layers/aspp.py b/detectron2/layers/aspp.py
similarity index 100%
rename from gradio_demo/detectron2/layers/aspp.py
rename to detectron2/layers/aspp.py
diff --git a/gradio_demo/detectron2/layers/batch_norm.py b/detectron2/layers/batch_norm.py
similarity index 100%
rename from gradio_demo/detectron2/layers/batch_norm.py
rename to detectron2/layers/batch_norm.py
diff --git a/gradio_demo/detectron2/layers/blocks.py b/detectron2/layers/blocks.py
similarity index 100%
rename from gradio_demo/detectron2/layers/blocks.py
rename to detectron2/layers/blocks.py
diff --git a/gradio_demo/detectron2/layers/csrc/README.md b/detectron2/layers/csrc/README.md
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/README.md
rename to detectron2/layers/csrc/README.md
diff --git a/gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h b/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h
rename to detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h
diff --git a/gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp b/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp
rename to detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp
diff --git a/gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu b/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu
rename to detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu
diff --git a/gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h b/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
rename to detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
diff --git a/gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp b/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
rename to detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
diff --git a/gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu b/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
rename to detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
diff --git a/gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h b/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
rename to detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
diff --git a/gradio_demo/detectron2/layers/csrc/cocoeval/cocoeval.cpp b/detectron2/layers/csrc/cocoeval/cocoeval.cpp
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/cocoeval/cocoeval.cpp
rename to detectron2/layers/csrc/cocoeval/cocoeval.cpp
diff --git a/gradio_demo/detectron2/layers/csrc/cocoeval/cocoeval.h b/detectron2/layers/csrc/cocoeval/cocoeval.h
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/cocoeval/cocoeval.h
rename to detectron2/layers/csrc/cocoeval/cocoeval.h
diff --git a/gradio_demo/detectron2/layers/csrc/cuda_version.cu b/detectron2/layers/csrc/cuda_version.cu
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/cuda_version.cu
rename to detectron2/layers/csrc/cuda_version.cu
diff --git a/gradio_demo/detectron2/layers/csrc/deformable/deform_conv.h b/detectron2/layers/csrc/deformable/deform_conv.h
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/deformable/deform_conv.h
rename to detectron2/layers/csrc/deformable/deform_conv.h
diff --git a/gradio_demo/detectron2/layers/csrc/deformable/deform_conv_cuda.cu b/detectron2/layers/csrc/deformable/deform_conv_cuda.cu
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/deformable/deform_conv_cuda.cu
rename to detectron2/layers/csrc/deformable/deform_conv_cuda.cu
diff --git a/gradio_demo/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu b/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu
rename to detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu
diff --git a/gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated.h b/detectron2/layers/csrc/nms_rotated/nms_rotated.h
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated.h
rename to detectron2/layers/csrc/nms_rotated/nms_rotated.h
diff --git a/gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp b/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp
rename to detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp
diff --git a/gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu b/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
rename to detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
diff --git a/gradio_demo/detectron2/layers/csrc/vision.cpp b/detectron2/layers/csrc/vision.cpp
similarity index 100%
rename from gradio_demo/detectron2/layers/csrc/vision.cpp
rename to detectron2/layers/csrc/vision.cpp
diff --git a/gradio_demo/detectron2/layers/deform_conv.py b/detectron2/layers/deform_conv.py
similarity index 100%
rename from gradio_demo/detectron2/layers/deform_conv.py
rename to detectron2/layers/deform_conv.py
diff --git a/gradio_demo/detectron2/layers/losses.py b/detectron2/layers/losses.py
similarity index 100%
rename from gradio_demo/detectron2/layers/losses.py
rename to detectron2/layers/losses.py
diff --git a/gradio_demo/detectron2/layers/mask_ops.py b/detectron2/layers/mask_ops.py
similarity index 100%
rename from gradio_demo/detectron2/layers/mask_ops.py
rename to detectron2/layers/mask_ops.py
diff --git a/gradio_demo/detectron2/layers/nms.py b/detectron2/layers/nms.py
similarity index 100%
rename from gradio_demo/detectron2/layers/nms.py
rename to detectron2/layers/nms.py
diff --git a/gradio_demo/detectron2/layers/roi_align.py b/detectron2/layers/roi_align.py
similarity index 100%
rename from gradio_demo/detectron2/layers/roi_align.py
rename to detectron2/layers/roi_align.py
diff --git a/gradio_demo/detectron2/layers/roi_align_rotated.py b/detectron2/layers/roi_align_rotated.py
similarity index 100%
rename from gradio_demo/detectron2/layers/roi_align_rotated.py
rename to detectron2/layers/roi_align_rotated.py
diff --git a/gradio_demo/detectron2/layers/rotated_boxes.py b/detectron2/layers/rotated_boxes.py
similarity index 100%
rename from gradio_demo/detectron2/layers/rotated_boxes.py
rename to detectron2/layers/rotated_boxes.py
diff --git a/gradio_demo/detectron2/layers/shape_spec.py b/detectron2/layers/shape_spec.py
similarity index 100%
rename from gradio_demo/detectron2/layers/shape_spec.py
rename to detectron2/layers/shape_spec.py
diff --git a/gradio_demo/detectron2/layers/wrappers.py b/detectron2/layers/wrappers.py
similarity index 100%
rename from gradio_demo/detectron2/layers/wrappers.py
rename to detectron2/layers/wrappers.py
diff --git a/gradio_demo/detectron2/model_zoo/__init__.py b/detectron2/model_zoo/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/model_zoo/__init__.py
rename to detectron2/model_zoo/__init__.py
diff --git a/gradio_demo/detectron2/model_zoo/model_zoo.py b/detectron2/model_zoo/model_zoo.py
similarity index 100%
rename from gradio_demo/detectron2/model_zoo/model_zoo.py
rename to detectron2/model_zoo/model_zoo.py
diff --git a/gradio_demo/detectron2/modeling/__init__.py b/detectron2/modeling/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/__init__.py
rename to detectron2/modeling/__init__.py
diff --git a/gradio_demo/detectron2/modeling/anchor_generator.py b/detectron2/modeling/anchor_generator.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/anchor_generator.py
rename to detectron2/modeling/anchor_generator.py
diff --git a/gradio_demo/detectron2/modeling/backbone/__init__.py b/detectron2/modeling/backbone/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/__init__.py
rename to detectron2/modeling/backbone/__init__.py
diff --git a/gradio_demo/detectron2/modeling/backbone/backbone.py b/detectron2/modeling/backbone/backbone.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/backbone.py
rename to detectron2/modeling/backbone/backbone.py
diff --git a/gradio_demo/detectron2/modeling/backbone/build.py b/detectron2/modeling/backbone/build.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/build.py
rename to detectron2/modeling/backbone/build.py
diff --git a/gradio_demo/detectron2/modeling/backbone/fpn.py b/detectron2/modeling/backbone/fpn.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/fpn.py
rename to detectron2/modeling/backbone/fpn.py
diff --git a/gradio_demo/detectron2/modeling/backbone/mvit.py b/detectron2/modeling/backbone/mvit.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/mvit.py
rename to detectron2/modeling/backbone/mvit.py
diff --git a/gradio_demo/detectron2/modeling/backbone/regnet.py b/detectron2/modeling/backbone/regnet.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/regnet.py
rename to detectron2/modeling/backbone/regnet.py
diff --git a/gradio_demo/detectron2/modeling/backbone/resnet.py b/detectron2/modeling/backbone/resnet.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/resnet.py
rename to detectron2/modeling/backbone/resnet.py
diff --git a/gradio_demo/detectron2/modeling/backbone/swin.py b/detectron2/modeling/backbone/swin.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/swin.py
rename to detectron2/modeling/backbone/swin.py
diff --git a/gradio_demo/detectron2/modeling/backbone/utils.py b/detectron2/modeling/backbone/utils.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/utils.py
rename to detectron2/modeling/backbone/utils.py
diff --git a/gradio_demo/detectron2/modeling/backbone/vit.py b/detectron2/modeling/backbone/vit.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/backbone/vit.py
rename to detectron2/modeling/backbone/vit.py
diff --git a/gradio_demo/detectron2/modeling/box_regression.py b/detectron2/modeling/box_regression.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/box_regression.py
rename to detectron2/modeling/box_regression.py
diff --git a/gradio_demo/detectron2/modeling/matcher.py b/detectron2/modeling/matcher.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/matcher.py
rename to detectron2/modeling/matcher.py
diff --git a/gradio_demo/detectron2/modeling/meta_arch/__init__.py b/detectron2/modeling/meta_arch/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/meta_arch/__init__.py
rename to detectron2/modeling/meta_arch/__init__.py
diff --git a/gradio_demo/detectron2/modeling/meta_arch/build.py b/detectron2/modeling/meta_arch/build.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/meta_arch/build.py
rename to detectron2/modeling/meta_arch/build.py
diff --git a/gradio_demo/detectron2/modeling/meta_arch/dense_detector.py b/detectron2/modeling/meta_arch/dense_detector.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/meta_arch/dense_detector.py
rename to detectron2/modeling/meta_arch/dense_detector.py
diff --git a/gradio_demo/detectron2/modeling/meta_arch/fcos.py b/detectron2/modeling/meta_arch/fcos.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/meta_arch/fcos.py
rename to detectron2/modeling/meta_arch/fcos.py
diff --git a/gradio_demo/detectron2/modeling/meta_arch/panoptic_fpn.py b/detectron2/modeling/meta_arch/panoptic_fpn.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/meta_arch/panoptic_fpn.py
rename to detectron2/modeling/meta_arch/panoptic_fpn.py
diff --git a/gradio_demo/detectron2/modeling/meta_arch/rcnn.py b/detectron2/modeling/meta_arch/rcnn.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/meta_arch/rcnn.py
rename to detectron2/modeling/meta_arch/rcnn.py
diff --git a/gradio_demo/detectron2/modeling/meta_arch/retinanet.py b/detectron2/modeling/meta_arch/retinanet.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/meta_arch/retinanet.py
rename to detectron2/modeling/meta_arch/retinanet.py
diff --git a/gradio_demo/detectron2/modeling/meta_arch/semantic_seg.py b/detectron2/modeling/meta_arch/semantic_seg.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/meta_arch/semantic_seg.py
rename to detectron2/modeling/meta_arch/semantic_seg.py
diff --git a/gradio_demo/detectron2/modeling/mmdet_wrapper.py b/detectron2/modeling/mmdet_wrapper.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/mmdet_wrapper.py
rename to detectron2/modeling/mmdet_wrapper.py
diff --git a/gradio_demo/detectron2/modeling/poolers.py b/detectron2/modeling/poolers.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/poolers.py
rename to detectron2/modeling/poolers.py
diff --git a/gradio_demo/detectron2/modeling/postprocessing.py b/detectron2/modeling/postprocessing.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/postprocessing.py
rename to detectron2/modeling/postprocessing.py
diff --git a/gradio_demo/detectron2/modeling/proposal_generator/__init__.py b/detectron2/modeling/proposal_generator/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/proposal_generator/__init__.py
rename to detectron2/modeling/proposal_generator/__init__.py
diff --git a/gradio_demo/detectron2/modeling/proposal_generator/build.py b/detectron2/modeling/proposal_generator/build.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/proposal_generator/build.py
rename to detectron2/modeling/proposal_generator/build.py
diff --git a/gradio_demo/detectron2/modeling/proposal_generator/proposal_utils.py b/detectron2/modeling/proposal_generator/proposal_utils.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/proposal_generator/proposal_utils.py
rename to detectron2/modeling/proposal_generator/proposal_utils.py
diff --git a/gradio_demo/detectron2/modeling/proposal_generator/rpn.py b/detectron2/modeling/proposal_generator/rpn.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/proposal_generator/rpn.py
rename to detectron2/modeling/proposal_generator/rpn.py
diff --git a/gradio_demo/detectron2/modeling/proposal_generator/rrpn.py b/detectron2/modeling/proposal_generator/rrpn.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/proposal_generator/rrpn.py
rename to detectron2/modeling/proposal_generator/rrpn.py
diff --git a/gradio_demo/detectron2/modeling/roi_heads/__init__.py b/detectron2/modeling/roi_heads/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/roi_heads/__init__.py
rename to detectron2/modeling/roi_heads/__init__.py
diff --git a/gradio_demo/detectron2/modeling/roi_heads/box_head.py b/detectron2/modeling/roi_heads/box_head.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/roi_heads/box_head.py
rename to detectron2/modeling/roi_heads/box_head.py
diff --git a/gradio_demo/detectron2/modeling/roi_heads/cascade_rcnn.py b/detectron2/modeling/roi_heads/cascade_rcnn.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/roi_heads/cascade_rcnn.py
rename to detectron2/modeling/roi_heads/cascade_rcnn.py
diff --git a/gradio_demo/detectron2/modeling/roi_heads/fast_rcnn.py b/detectron2/modeling/roi_heads/fast_rcnn.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/roi_heads/fast_rcnn.py
rename to detectron2/modeling/roi_heads/fast_rcnn.py
diff --git a/gradio_demo/detectron2/modeling/roi_heads/keypoint_head.py b/detectron2/modeling/roi_heads/keypoint_head.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/roi_heads/keypoint_head.py
rename to detectron2/modeling/roi_heads/keypoint_head.py
diff --git a/gradio_demo/detectron2/modeling/roi_heads/mask_head.py b/detectron2/modeling/roi_heads/mask_head.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/roi_heads/mask_head.py
rename to detectron2/modeling/roi_heads/mask_head.py
diff --git a/gradio_demo/detectron2/modeling/roi_heads/roi_heads.py b/detectron2/modeling/roi_heads/roi_heads.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/roi_heads/roi_heads.py
rename to detectron2/modeling/roi_heads/roi_heads.py
diff --git a/gradio_demo/detectron2/modeling/roi_heads/rotated_fast_rcnn.py b/detectron2/modeling/roi_heads/rotated_fast_rcnn.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/roi_heads/rotated_fast_rcnn.py
rename to detectron2/modeling/roi_heads/rotated_fast_rcnn.py
diff --git a/gradio_demo/detectron2/modeling/sampling.py b/detectron2/modeling/sampling.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/sampling.py
rename to detectron2/modeling/sampling.py
diff --git a/gradio_demo/detectron2/modeling/test_time_augmentation.py b/detectron2/modeling/test_time_augmentation.py
similarity index 100%
rename from gradio_demo/detectron2/modeling/test_time_augmentation.py
rename to detectron2/modeling/test_time_augmentation.py
diff --git a/gradio_demo/detectron2/projects/README.md b/detectron2/projects/README.md
similarity index 100%
rename from gradio_demo/detectron2/projects/README.md
rename to detectron2/projects/README.md
diff --git a/gradio_demo/detectron2/projects/__init__.py b/detectron2/projects/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/projects/__init__.py
rename to detectron2/projects/__init__.py
diff --git a/gradio_demo/detectron2/solver/__init__.py b/detectron2/solver/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/solver/__init__.py
rename to detectron2/solver/__init__.py
diff --git a/gradio_demo/detectron2/solver/build.py b/detectron2/solver/build.py
similarity index 100%
rename from gradio_demo/detectron2/solver/build.py
rename to detectron2/solver/build.py
diff --git a/gradio_demo/detectron2/solver/lr_scheduler.py b/detectron2/solver/lr_scheduler.py
similarity index 100%
rename from gradio_demo/detectron2/solver/lr_scheduler.py
rename to detectron2/solver/lr_scheduler.py
diff --git a/gradio_demo/detectron2/structures/__init__.py b/detectron2/structures/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/structures/__init__.py
rename to detectron2/structures/__init__.py
diff --git a/gradio_demo/detectron2/structures/boxes.py b/detectron2/structures/boxes.py
similarity index 100%
rename from gradio_demo/detectron2/structures/boxes.py
rename to detectron2/structures/boxes.py
diff --git a/gradio_demo/detectron2/structures/image_list.py b/detectron2/structures/image_list.py
similarity index 100%
rename from gradio_demo/detectron2/structures/image_list.py
rename to detectron2/structures/image_list.py
diff --git a/gradio_demo/detectron2/structures/instances.py b/detectron2/structures/instances.py
similarity index 100%
rename from gradio_demo/detectron2/structures/instances.py
rename to detectron2/structures/instances.py
diff --git a/gradio_demo/detectron2/structures/keypoints.py b/detectron2/structures/keypoints.py
similarity index 100%
rename from gradio_demo/detectron2/structures/keypoints.py
rename to detectron2/structures/keypoints.py
diff --git a/gradio_demo/detectron2/structures/masks.py b/detectron2/structures/masks.py
similarity index 100%
rename from gradio_demo/detectron2/structures/masks.py
rename to detectron2/structures/masks.py
diff --git a/gradio_demo/detectron2/structures/rotated_boxes.py b/detectron2/structures/rotated_boxes.py
similarity index 100%
rename from gradio_demo/detectron2/structures/rotated_boxes.py
rename to detectron2/structures/rotated_boxes.py
diff --git a/gradio_demo/detectron2/tracking/__init__.py b/detectron2/tracking/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/tracking/__init__.py
rename to detectron2/tracking/__init__.py
diff --git a/gradio_demo/detectron2/tracking/base_tracker.py b/detectron2/tracking/base_tracker.py
similarity index 100%
rename from gradio_demo/detectron2/tracking/base_tracker.py
rename to detectron2/tracking/base_tracker.py
diff --git a/gradio_demo/detectron2/tracking/bbox_iou_tracker.py b/detectron2/tracking/bbox_iou_tracker.py
similarity index 100%
rename from gradio_demo/detectron2/tracking/bbox_iou_tracker.py
rename to detectron2/tracking/bbox_iou_tracker.py
diff --git a/gradio_demo/detectron2/tracking/hungarian_tracker.py b/detectron2/tracking/hungarian_tracker.py
similarity index 100%
rename from gradio_demo/detectron2/tracking/hungarian_tracker.py
rename to detectron2/tracking/hungarian_tracker.py
diff --git a/gradio_demo/detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py b/detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py
similarity index 100%
rename from gradio_demo/detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py
rename to detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py
diff --git a/gradio_demo/detectron2/tracking/utils.py b/detectron2/tracking/utils.py
similarity index 100%
rename from gradio_demo/detectron2/tracking/utils.py
rename to detectron2/tracking/utils.py
diff --git a/gradio_demo/detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py b/detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py
similarity index 100%
rename from gradio_demo/detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py
rename to detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py
diff --git a/gradio_demo/detectron2/utils/README.md b/detectron2/utils/README.md
similarity index 100%
rename from gradio_demo/detectron2/utils/README.md
rename to detectron2/utils/README.md
diff --git a/gradio_demo/detectron2/utils/__init__.py b/detectron2/utils/__init__.py
similarity index 100%
rename from gradio_demo/detectron2/utils/__init__.py
rename to detectron2/utils/__init__.py
diff --git a/gradio_demo/detectron2/utils/analysis.py b/detectron2/utils/analysis.py
similarity index 100%
rename from gradio_demo/detectron2/utils/analysis.py
rename to detectron2/utils/analysis.py
diff --git a/gradio_demo/detectron2/utils/collect_env.py b/detectron2/utils/collect_env.py
similarity index 100%
rename from gradio_demo/detectron2/utils/collect_env.py
rename to detectron2/utils/collect_env.py
diff --git a/gradio_demo/detectron2/utils/colormap.py b/detectron2/utils/colormap.py
similarity index 100%
rename from gradio_demo/detectron2/utils/colormap.py
rename to detectron2/utils/colormap.py
diff --git a/gradio_demo/detectron2/utils/comm.py b/detectron2/utils/comm.py
similarity index 100%
rename from gradio_demo/detectron2/utils/comm.py
rename to detectron2/utils/comm.py
diff --git a/gradio_demo/detectron2/utils/develop.py b/detectron2/utils/develop.py
similarity index 100%
rename from gradio_demo/detectron2/utils/develop.py
rename to detectron2/utils/develop.py
diff --git a/gradio_demo/detectron2/utils/env.py b/detectron2/utils/env.py
similarity index 100%
rename from gradio_demo/detectron2/utils/env.py
rename to detectron2/utils/env.py
diff --git a/gradio_demo/detectron2/utils/events.py b/detectron2/utils/events.py
similarity index 100%
rename from gradio_demo/detectron2/utils/events.py
rename to detectron2/utils/events.py
diff --git a/gradio_demo/detectron2/utils/file_io.py b/detectron2/utils/file_io.py
similarity index 100%
rename from gradio_demo/detectron2/utils/file_io.py
rename to detectron2/utils/file_io.py
diff --git a/gradio_demo/detectron2/utils/logger.py b/detectron2/utils/logger.py
similarity index 100%
rename from gradio_demo/detectron2/utils/logger.py
rename to detectron2/utils/logger.py
diff --git a/gradio_demo/detectron2/utils/memory.py b/detectron2/utils/memory.py
similarity index 100%
rename from gradio_demo/detectron2/utils/memory.py
rename to detectron2/utils/memory.py
diff --git a/gradio_demo/detectron2/utils/registry.py b/detectron2/utils/registry.py
similarity index 100%
rename from gradio_demo/detectron2/utils/registry.py
rename to detectron2/utils/registry.py
diff --git a/gradio_demo/detectron2/utils/serialize.py b/detectron2/utils/serialize.py
similarity index 100%
rename from gradio_demo/detectron2/utils/serialize.py
rename to detectron2/utils/serialize.py
diff --git a/gradio_demo/detectron2/utils/testing.py b/detectron2/utils/testing.py
similarity index 100%
rename from gradio_demo/detectron2/utils/testing.py
rename to detectron2/utils/testing.py
diff --git a/gradio_demo/detectron2/utils/tracing.py b/detectron2/utils/tracing.py
similarity index 100%
rename from gradio_demo/detectron2/utils/tracing.py
rename to detectron2/utils/tracing.py
diff --git a/gradio_demo/detectron2/utils/video_visualizer.py b/detectron2/utils/video_visualizer.py
similarity index 100%
rename from gradio_demo/detectron2/utils/video_visualizer.py
rename to detectron2/utils/video_visualizer.py
diff --git a/gradio_demo/detectron2/utils/visualizer.py b/detectron2/utils/visualizer.py
similarity index 100%
rename from gradio_demo/detectron2/utils/visualizer.py
rename to detectron2/utils/visualizer.py
diff --git a/environment.yaml b/environment.yaml
deleted file mode 100644
index 094229efacee8fb7b354eac7e4e340fe30b5fecc..0000000000000000000000000000000000000000
--- a/environment.yaml
+++ /dev/null
@@ -1,32 +0,0 @@
-name: idm
-channels:
- - pytorch
- - nvidia
- - defaults
-dependencies:
- - python=3.10.0=h12debd9_5
- - pytorch=2.0.1=py3.10_cuda11.8_cudnn8.7.0_0
- - pytorch-cuda=11.8=h7e8668a_5
- - torchaudio=2.0.2=py310_cu118
- - torchtriton=2.0.0=py310
- - torchvision=0.15.2=py310_cu118
- - pip=23.3.1=py310h06a4308_0
-
- - pip:
- - accelerate==0.25.0
- - torchmetrics==1.2.1
- - tqdm==4.66.1
- - transformers==4.36.2
- - diffusers==0.25.0
- - einops==0.7.0
- - bitsandbytes==0.39.0
- - scipy==1.11.1
- - opencv-python
- - gradio==4.24.0
- - fvcore
- - cloudpickle
- - omegaconf
- - pycocotools
- - basicsr
- - av
- - onnxruntime==1.16.2
diff --git a/gradio_demo/example/cloth/04469_00.jpg b/example/cloth/04469_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/04469_00.jpg
rename to example/cloth/04469_00.jpg
diff --git a/gradio_demo/example/cloth/04743_00.jpg b/example/cloth/04743_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/04743_00.jpg
rename to example/cloth/04743_00.jpg
diff --git a/gradio_demo/example/cloth/09133_00.jpg b/example/cloth/09133_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09133_00.jpg
rename to example/cloth/09133_00.jpg
diff --git a/gradio_demo/example/cloth/09163_00.jpg b/example/cloth/09163_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09163_00.jpg
rename to example/cloth/09163_00.jpg
diff --git a/gradio_demo/example/cloth/09164_00.jpg b/example/cloth/09164_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09164_00.jpg
rename to example/cloth/09164_00.jpg
diff --git a/gradio_demo/example/cloth/09166_00.jpg b/example/cloth/09166_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09166_00.jpg
rename to example/cloth/09166_00.jpg
diff --git a/gradio_demo/example/cloth/09176_00.jpg b/example/cloth/09176_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09176_00.jpg
rename to example/cloth/09176_00.jpg
diff --git a/gradio_demo/example/cloth/09236_00.jpg b/example/cloth/09236_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09236_00.jpg
rename to example/cloth/09236_00.jpg
diff --git a/gradio_demo/example/cloth/09256_00.jpg b/example/cloth/09256_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09256_00.jpg
rename to example/cloth/09256_00.jpg
diff --git a/gradio_demo/example/cloth/09263_00.jpg b/example/cloth/09263_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09263_00.jpg
rename to example/cloth/09263_00.jpg
diff --git a/gradio_demo/example/cloth/09266_00.jpg b/example/cloth/09266_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09266_00.jpg
rename to example/cloth/09266_00.jpg
diff --git a/gradio_demo/example/cloth/09290_00.jpg b/example/cloth/09290_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09290_00.jpg
rename to example/cloth/09290_00.jpg
diff --git a/gradio_demo/example/cloth/09305_00.jpg b/example/cloth/09305_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/09305_00.jpg
rename to example/cloth/09305_00.jpg
diff --git a/gradio_demo/example/cloth/10165_00.jpg b/example/cloth/10165_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/10165_00.jpg
rename to example/cloth/10165_00.jpg
diff --git a/gradio_demo/example/cloth/14627_00.jpg b/example/cloth/14627_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/14627_00.jpg
rename to example/cloth/14627_00.jpg
diff --git a/gradio_demo/example/cloth/14673_00.jpg b/example/cloth/14673_00.jpg
similarity index 100%
rename from gradio_demo/example/cloth/14673_00.jpg
rename to example/cloth/14673_00.jpg
diff --git a/gradio_demo/example/human/00034_00.jpg b/example/human/00034_00.jpg
similarity index 100%
rename from gradio_demo/example/human/00034_00.jpg
rename to example/human/00034_00.jpg
diff --git a/gradio_demo/example/human/00035_00.jpg b/example/human/00035_00.jpg
similarity index 100%
rename from gradio_demo/example/human/00035_00.jpg
rename to example/human/00035_00.jpg
diff --git a/gradio_demo/example/human/00055_00.jpg b/example/human/00055_00.jpg
similarity index 100%
rename from gradio_demo/example/human/00055_00.jpg
rename to example/human/00055_00.jpg
diff --git a/gradio_demo/example/human/00121_00.jpg b/example/human/00121_00.jpg
similarity index 100%
rename from gradio_demo/example/human/00121_00.jpg
rename to example/human/00121_00.jpg
diff --git a/gradio_demo/example/human/01992_00.jpg b/example/human/01992_00.jpg
similarity index 100%
rename from gradio_demo/example/human/01992_00.jpg
rename to example/human/01992_00.jpg
diff --git a/gradio_demo/example/human/Jensen.jpeg b/example/human/Jensen.jpeg
similarity index 100%
rename from gradio_demo/example/human/Jensen.jpeg
rename to example/human/Jensen.jpeg
diff --git a/gradio_demo/example/human/sam1 (1).jpg b/example/human/sam1 (1).jpg
similarity index 100%
rename from gradio_demo/example/human/sam1 (1).jpg
rename to example/human/sam1 (1).jpg
diff --git a/gradio_demo/example/human/taylor-.jpg b/example/human/taylor-.jpg
similarity index 100%
rename from gradio_demo/example/human/taylor-.jpg
rename to example/human/taylor-.jpg
diff --git a/gradio_demo/example/human/will1 (1).jpg b/example/human/will1 (1).jpg
similarity index 100%
rename from gradio_demo/example/human/will1 (1).jpg
rename to example/human/will1 (1).jpg
diff --git a/hf b/hf
deleted file mode 100644
index c3b9c41e2dc0597f8cea3f1a01519861900e95a9..0000000000000000000000000000000000000000
--- a/hf
+++ /dev/null
@@ -1,7 +0,0 @@
------BEGIN OPENSSH PRIVATE KEY-----
-b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
-QyNTUxOQAAACAnrBvCQqFPKBjg50yOVYY4CRZFOXYtf87WU1C1+176WwAAAJjRr9Rv0a/U
-bwAAAAtzc2gtZWQyNTUxOQAAACAnrBvCQqFPKBjg50yOVYY4CRZFOXYtf87WU1C1+176Ww
-AAAEDgMBd8z9N6Q1/05T0A5KiyVisusVzLpxDYpBMZEq6WfiesG8JCoU8oGODnTI5VhjgJ
-FkU5di1/ztZTULX7XvpbAAAAFWlxdGVjaC4yMDIyQGdtYWlsLmNvbQ==
------END OPENSSH PRIVATE KEY-----
diff --git a/hf.pub b/hf.pub
deleted file mode 100644
index e792b00fefbdc1e37c9c3587f2e4cd107611e373..0000000000000000000000000000000000000000
--- a/hf.pub
+++ /dev/null
@@ -1 +0,0 @@
-ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICesG8JCoU8oGODnTI5VhjgJFkU5di1/ztZTULX7Xvpb iqtech.2022@gmail.com
diff --git a/inference.py b/inference.py
deleted file mode 100644
index 23e44ba8bd0f53dd0e3ed21bf6285e385a3c521e..0000000000000000000000000000000000000000
--- a/inference.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal
-from ip_adapter.ip_adapter import Resampler
-
-import argparse
-import logging
-import os
-import torch.utils.data as data
-import torchvision
-import json
-import accelerate
-import numpy as np
-import torch
-from PIL import Image
-import torch.nn.functional as F
-import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
-from packaging import version
-from torchvision import transforms
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLControlNetInpaintPipeline
-from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer
-
-from diffusers.utils.import_utils import is_xformers_available
-
-from src.unet_hacked_tryon import UNet2DConditionModel
-from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
-from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
-
-
-
-logger = get_logger(__name__, log_level="INFO")
-
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
- parser.add_argument("--pretrained_model_name_or_path",type=str,default= "yisol/IDM-VTON",required=False,)
- parser.add_argument("--width",type=int,default=768,)
- parser.add_argument("--height",type=int,default=1024,)
- parser.add_argument("--num_inference_steps",type=int,default=30,)
- parser.add_argument("--output_dir",type=str,default="result",)
- parser.add_argument("--unpaired",action="store_true",)
- parser.add_argument("--data_dir",type=str,default="/home/omnious/workspace/yisol/Dataset/zalando")
- parser.add_argument("--seed", type=int, default=42,)
- parser.add_argument("--test_batch_size", type=int, default=2,)
- parser.add_argument("--guidance_scale",type=float,default=2.0,)
- parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],)
- parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
- args = parser.parse_args()
-
-
- return args
-
-def pil_to_tensor(images):
- images = np.array(images).astype(np.float32) / 255.0
- images = torch.from_numpy(images.transpose(2, 0, 1))
- return images
-
-
-class VitonHDTestDataset(data.Dataset):
- def __init__(
- self,
- dataroot_path: str,
- phase: Literal["train", "test"],
- order: Literal["paired", "unpaired"] = "paired",
- size: Tuple[int, int] = (512, 384),
- ):
- super(VitonHDTestDataset, self).__init__()
- self.dataroot = dataroot_path
- self.phase = phase
- self.height = size[0]
- self.width = size[1]
- self.size = size
- self.transform = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
- self.toTensor = transforms.ToTensor()
-
- with open(
- os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r"
- ) as file1:
- data1 = json.load(file1)
-
- annotation_list = [
- "sleeveLength",
- "neckLine",
- "item",
- ]
-
- self.annotation_pair = {}
- for k, v in data1.items():
- for elem in v:
- annotation_str = ""
- for template in annotation_list:
- for tag in elem["tag_info"]:
- if (
- tag["tag_name"] == template
- and tag["tag_category"] is not None
- ):
- annotation_str += tag["tag_category"]
- annotation_str += " "
- self.annotation_pair[elem["file_name"]] = annotation_str
-
- self.order = order
- self.toTensor = transforms.ToTensor()
-
- im_names = []
- c_names = []
- dataroot_names = []
-
-
- if phase == "train":
- filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
- else:
- filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
-
- with open(filename, "r") as f:
- for line in f.readlines():
- if phase == "train":
- im_name, _ = line.strip().split()
- c_name = im_name
- else:
- if order == "paired":
- im_name, _ = line.strip().split()
- c_name = im_name
- else:
- im_name, c_name = line.strip().split()
-
- im_names.append(im_name)
- c_names.append(c_name)
- dataroot_names.append(dataroot_path)
-
- self.im_names = im_names
- self.c_names = c_names
- self.dataroot_names = dataroot_names
- self.clip_processor = CLIPImageProcessor()
- def __getitem__(self, index):
- c_name = self.c_names[index]
- im_name = self.im_names[index]
- if c_name in self.annotation_pair:
- cloth_annotation = self.annotation_pair[c_name]
- else:
- cloth_annotation = "shirts"
- cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name))
-
- im_pil_big = Image.open(
- os.path.join(self.dataroot, self.phase, "image", im_name)
- ).resize((self.width,self.height))
- image = self.transform(im_pil_big)
-
- mask = Image.open(os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace('.jpg','_mask.png'))).resize((self.width,self.height))
- mask = self.toTensor(mask)
- mask = mask[:1]
- mask = 1-mask
- im_mask = image * mask
-
- pose_img = Image.open(
- os.path.join(self.dataroot, self.phase, "image-densepose", im_name)
- )
- pose_img = self.transform(pose_img) # [-1,1]
-
- result = {}
- result["c_name"] = c_name
- result["im_name"] = im_name
- result["image"] = image
- result["cloth_pure"] = self.transform(cloth)
- result["cloth"] = self.clip_processor(images=cloth, return_tensors="pt").pixel_values
- result["inpaint_mask"] =1-mask
- result["im_mask"] = im_mask
- result["caption_cloth"] = "a photo of " + cloth_annotation
- result["caption"] = "model is wearing a " + cloth_annotation
- result["pose_img"] = pose_img
-
- return result
-
- def __len__(self):
- # model images + cloth image
- return len(self.im_names)
-
-
-
-
-def main():
- args = parse_args()
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir)
- accelerator = Accelerator(
- mixed_precision=args.mixed_precision,
- project_config=accelerator_project_config,
- )
- if accelerator.is_local_main_process:
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
- # If passed along, set the training seed now.
- if args.seed is not None:
- set_seed(args.seed)
-
- # Handle the repository creation
- if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
- weight_dtype = torch.float16
- # if accelerator.mixed_precision == "fp16":
- # weight_dtype = torch.float16
- # args.mixed_precision = accelerator.mixed_precision
- # elif accelerator.mixed_precision == "bf16":
- # weight_dtype = torch.bfloat16
- # args.mixed_precision = accelerator.mixed_precision
-
- # Load scheduler, tokenizer and models.
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="vae",
- torch_dtype=torch.float16,
- )
- unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="unet",
- torch_dtype=torch.float16,
- )
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="image_encoder",
- torch_dtype=torch.float16,
- )
- unet_encoder = UNet2DConditionModel_ref.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="unet_encoder",
- torch_dtype=torch.float16,
- )
- text_encoder_one = CLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder",
- torch_dtype=torch.float16,
- )
- text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder_2",
- torch_dtype=torch.float16,
- )
- tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=None,
- use_fast=False,
- )
- tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=None,
- use_fast=False,
- )
-
-
- # Freeze vae and text_encoder and set unet to trainable
- unet.requires_grad_(False)
- vae.requires_grad_(False)
- image_encoder.requires_grad_(False)
- unet_encoder.requires_grad_(False)
- text_encoder_one.requires_grad_(False)
- text_encoder_two.requires_grad_(False)
- unet_encoder.to(accelerator.device, weight_dtype)
- unet.eval()
- unet_encoder.eval()
-
-
-
- if args.enable_xformers_memory_efficient_attention:
- if is_xformers_available():
- import xformers
-
- xformers_version = version.parse(xformers.__version__)
- if xformers_version == version.parse("0.0.16"):
- logger.warn(
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
- )
- unet.enable_xformers_memory_efficient_attention()
- else:
- raise ValueError("xformers is not available. Make sure it is installed correctly")
-
- test_dataset = VitonHDTestDataset(
- dataroot_path=args.data_dir,
- phase="test",
- order="unpaired" if args.unpaired else "paired",
- size=(args.height, args.width),
- )
- test_dataloader = torch.utils.data.DataLoader(
- test_dataset,
- shuffle=False,
- batch_size=args.test_batch_size,
- num_workers=4,
- )
-
- pipe = TryonPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unet,
- vae=vae,
- feature_extractor= CLIPImageProcessor(),
- text_encoder = text_encoder_one,
- text_encoder_2 = text_encoder_two,
- tokenizer = tokenizer_one,
- tokenizer_2 = tokenizer_two,
- scheduler = noise_scheduler,
- image_encoder=image_encoder,
- unet_encoder = unet_encoder,
- torch_dtype=torch.float16,
- ).to(accelerator.device)
-
- # pipe.enable_sequential_cpu_offload()
- # pipe.enable_model_cpu_offload()
- # pipe.enable_vae_slicing()
-
-
-
- with torch.no_grad():
- # Extract the images
- with torch.cuda.amp.autocast():
- with torch.no_grad():
- for sample in test_dataloader:
- img_emb_list = []
- for i in range(sample['cloth'].shape[0]):
- img_emb_list.append(sample['cloth'][i])
-
- prompt = sample["caption"]
-
- num_prompts = sample['cloth'].shape[0]
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- image_embeds = torch.cat(img_emb_list,dim=0)
-
- with torch.inference_mode():
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = pipe.encode_prompt(
- prompt,
- num_images_per_prompt=1,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
-
-
- prompt = sample["caption_cloth"]
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
-
- with torch.inference_mode():
- (
- prompt_embeds_c,
- _,
- _,
- _,
- ) = pipe.encode_prompt(
- prompt,
- num_images_per_prompt=1,
- do_classifier_free_guidance=False,
- negative_prompt=negative_prompt,
- )
-
-
-
- generator = torch.Generator(pipe.device).manual_seed(args.seed) if args.seed is not None else None
- images = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=args.num_inference_steps,
- generator=generator,
- strength = 1.0,
- pose_img = sample['pose_img'],
- text_embeds_cloth=prompt_embeds_c,
- cloth = sample["cloth_pure"].to(accelerator.device),
- mask_image=sample['inpaint_mask'],
- image=(sample['image']+1.0)/2.0,
- height=args.height,
- width=args.width,
- guidance_scale=args.guidance_scale,
- ip_adapter_image = image_embeds,
- )[0]
-
-
- for i in range(len(images)):
- x_sample = pil_to_tensor(images[i])
- torchvision.utils.save_image(x_sample,os.path.join(args.output_dir,sample['im_name'][i]))
-
-
-
-
-if __name__ == "__main__":
- main()
diff --git a/inference.sh b/inference.sh
deleted file mode 100644
index 89d6747bcee011997a59f070dfba03b51958cad3..0000000000000000000000000000000000000000
--- a/inference.sh
+++ /dev/null
@@ -1,34 +0,0 @@
-#VITON-HD
-##paired setting
-accelerate launch inference.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
- --width 768 --height 1024 --num_inference_steps 30 \
- --output_dir "result" --data_dir "/home/omnious/workspace/yisol/Dataset/zalando" \
- --seed 42 --test_batch_size 2 --guidance_scale 2.0
-
-
-##unpaired setting
-accelerate launch inference.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
- --width 768 --height 1024 --num_inference_steps 30 \
- --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/Dataset/zalando" \
- --seed 42 --test_batch_size 2 --guidance_scale 2.0
-
-
-
-#DressCode
-##upper_body
-accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
- --width 768 --height 1024 --num_inference_steps 30 \
- --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \
- --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "upper_body"
-
-##lower_body
-accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
- --width 768 --height 1024 --num_inference_steps 30 \
- --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \
- --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "lower_body"
-
-##dresses
-accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \
- --width 768 --height 1024 --num_inference_steps 30 \
- --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \
- --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "dresses"
\ No newline at end of file
diff --git a/inference_dc.py b/inference_dc.py
deleted file mode 100644
index fd8e78cd6a206b554189068f95cad3d2dd0eab34..0000000000000000000000000000000000000000
--- a/inference_dc.py
+++ /dev/null
@@ -1,578 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal
-from ip_adapter.ip_adapter import Resampler
-
-import argparse
-import logging
-import os
-import torch.utils.data as data
-import torchvision
-import json
-import accelerate
-import numpy as np
-import torch
-from PIL import Image, ImageDraw
-import torch.nn.functional as F
-import transformers
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
-from packaging import version
-from torchvision import transforms
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLControlNetInpaintPipeline
-from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer
-import cv2
-from diffusers.utils.import_utils import is_xformers_available
-from numpy.linalg import lstsq
-
-from src.unet_hacked_tryon import UNet2DConditionModel
-from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
-from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
-
-
-
-logger = get_logger(__name__, log_level="INFO")
-
-label_map={
- "background": 0,
- "hat": 1,
- "hair": 2,
- "sunglasses": 3,
- "upper_clothes": 4,
- "skirt": 5,
- "pants": 6,
- "dress": 7,
- "belt": 8,
- "left_shoe": 9,
- "right_shoe": 10,
- "head": 11,
- "left_leg": 12,
- "right_leg": 13,
- "left_arm": 14,
- "right_arm": 15,
- "bag": 16,
- "scarf": 17,
-}
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
- parser.add_argument("--pretrained_model_name_or_path",type=str,default= "yisol/IDM-VTON",required=False,)
- parser.add_argument("--width",type=int,default=768,)
- parser.add_argument("--height",type=int,default=1024,)
- parser.add_argument("--num_inference_steps",type=int,default=30,)
- parser.add_argument("--output_dir",type=str,default="result",)
- parser.add_argument("--category",type=str,default="upper_body",choices=["upper_body", "lower_body", "dresses"])
- parser.add_argument("--unpaired",action="store_true",)
- parser.add_argument("--data_dir",type=str,default="/home/omnious/workspace/yisol/Dataset/zalando")
- parser.add_argument("--seed", type=int, default=42,)
- parser.add_argument("--test_batch_size", type=int, default=2,)
- parser.add_argument("--guidance_scale",type=float,default=2.0,)
- parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],)
- parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
- args = parser.parse_args()
-
-
- return args
-
-def pil_to_tensor(images):
- images = np.array(images).astype(np.float32) / 255.0
- images = torch.from_numpy(images.transpose(2, 0, 1))
- return images
-
-
-class DresscodeTestDataset(data.Dataset):
- def __init__(
- self,
- dataroot_path: str,
- phase: Literal["train", "test"],
- order: Literal["paired", "unpaired"] = "paired",
- category = "upper_body",
- size: Tuple[int, int] = (512, 384),
- ):
- super(DresscodeTestDataset, self).__init__()
- self.dataroot = os.path.join(dataroot_path,category)
- self.phase = phase
- self.height = size[0]
- self.width = size[1]
- self.size = size
- self.transform = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
- self.toTensor = transforms.ToTensor()
- self.order = order
- self.radius = 5
- self.category = category
- im_names = []
- c_names = []
-
-
- if phase == "train":
- filename = os.path.join(dataroot_path,category, f"{phase}_pairs.txt")
- else:
- filename = os.path.join(dataroot_path,category, f"{phase}_pairs_{order}.txt")
-
- with open(filename, "r") as f:
- for line in f.readlines():
- im_name, c_name = line.strip().split()
-
- im_names.append(im_name)
- c_names.append(c_name)
-
-
- file_path = os.path.join(dataroot_path,category,"dc_caption.txt")
-
- self.annotation_pair = {}
- with open(file_path, "r") as file:
- for line in file:
- parts = line.strip().split(" ")
- self.annotation_pair[parts[0]] = ' '.join(parts[1:])
-
-
- self.im_names = im_names
- self.c_names = c_names
- self.clip_processor = CLIPImageProcessor()
- def __getitem__(self, index):
- c_name = self.c_names[index]
- im_name = self.im_names[index]
- if c_name in self.annotation_pair:
- cloth_annotation = self.annotation_pair[c_name]
- else:
- cloth_annotation = self.category
- cloth = Image.open(os.path.join(self.dataroot, "images", c_name))
-
- im_pil_big = Image.open(
- os.path.join(self.dataroot, "images", im_name)
- ).resize((self.width,self.height))
- image = self.transform(im_pil_big)
-
-
-
-
- skeleton = Image.open(os.path.join(self.dataroot, 'skeletons', im_name.replace("_0", "_5")))
- skeleton = skeleton.resize((self.width, self.height))
- skeleton = self.transform(skeleton)
-
- # Label Map
- parse_name = im_name.replace('_0.jpg', '_4.png')
- im_parse = Image.open(os.path.join(self.dataroot, 'label_maps', parse_name))
- im_parse = im_parse.resize((self.width, self.height), Image.NEAREST)
- parse_array = np.array(im_parse)
-
- # Load pose points
- pose_name = im_name.replace('_0.jpg', '_2.json')
- with open(os.path.join(self.dataroot, 'keypoints', pose_name), 'r') as f:
- pose_label = json.load(f)
- pose_data = pose_label['keypoints']
- pose_data = np.array(pose_data)
- pose_data = pose_data.reshape((-1, 4))
-
- point_num = pose_data.shape[0]
- pose_map = torch.zeros(point_num, self.height, self.width)
- r = self.radius * (self.height / 512.0)
- for i in range(point_num):
- one_map = Image.new('L', (self.width, self.height))
- draw = ImageDraw.Draw(one_map)
- point_x = np.multiply(pose_data[i, 0], self.width / 384.0)
- point_y = np.multiply(pose_data[i, 1], self.height / 512.0)
- if point_x > 1 and point_y > 1:
- draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white')
- one_map = self.toTensor(one_map)
- pose_map[i] = one_map[0]
-
- agnostic_mask = self.get_agnostic(parse_array, pose_data, self.category, (self.width,self.height))
- # agnostic_mask = transforms.functional.resize(agnostic_mask, (self.height, self.width),
- # interpolation=transforms.InterpolationMode.NEAREST)
-
- mask = 1 - agnostic_mask
- im_mask = image * agnostic_mask
-
- pose_img = Image.open(
- os.path.join(self.dataroot, "image-densepose", im_name)
- )
- pose_img = self.transform(pose_img) # [-1,1]
-
- result = {}
- result["c_name"] = c_name
- result["im_name"] = im_name
- result["image"] = image
- result["cloth_pure"] = self.transform(cloth)
- result["cloth"] = self.clip_processor(images=cloth, return_tensors="pt").pixel_values
- result["inpaint_mask"] =mask
- result["im_mask"] = im_mask
- result["caption_cloth"] = "a photo of " + cloth_annotation
- result["caption"] = "model is wearing a " + cloth_annotation
- result["pose_img"] = pose_img
-
- return result
-
- def __len__(self):
- # model images + cloth image
- return len(self.im_names)
-
-
-
-
- def get_agnostic(self,parse_array, pose_data, category, size):
- parse_shape = (parse_array > 0).astype(np.float32)
-
- parse_head = (parse_array == 1).astype(np.float32) + \
- (parse_array == 2).astype(np.float32) + \
- (parse_array == 3).astype(np.float32) + \
- (parse_array == 11).astype(np.float32)
-
- parser_mask_fixed = (parse_array == label_map["hair"]).astype(np.float32) + \
- (parse_array == label_map["left_shoe"]).astype(np.float32) + \
- (parse_array == label_map["right_shoe"]).astype(np.float32) + \
- (parse_array == label_map["hat"]).astype(np.float32) + \
- (parse_array == label_map["sunglasses"]).astype(np.float32) + \
- (parse_array == label_map["scarf"]).astype(np.float32) + \
- (parse_array == label_map["bag"]).astype(np.float32)
-
- parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32)
-
- arms = (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32)
-
- if category == 'dresses':
- label_cat = 7
- parse_mask = (parse_array == 7).astype(np.float32) + \
- (parse_array == 12).astype(np.float32) + \
- (parse_array == 13).astype(np.float32)
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
-
- elif category == 'upper_body':
- label_cat = 4
- parse_mask = (parse_array == 4).astype(np.float32)
-
- parser_mask_fixed += (parse_array == label_map["skirt"]).astype(np.float32) + \
- (parse_array == label_map["pants"]).astype(np.float32)
-
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
- elif category == 'lower_body':
- label_cat = 6
- parse_mask = (parse_array == 6).astype(np.float32) + \
- (parse_array == 12).astype(np.float32) + \
- (parse_array == 13).astype(np.float32)
-
- parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
- (parse_array == 14).astype(np.float32) + \
- (parse_array == 15).astype(np.float32)
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
-
- parse_head = torch.from_numpy(parse_head) # [0,1]
- parse_mask = torch.from_numpy(parse_mask) # [0,1]
- parser_mask_fixed = torch.from_numpy(parser_mask_fixed)
- parser_mask_changeable = torch.from_numpy(parser_mask_changeable)
-
- # dilation
- parse_without_cloth = np.logical_and(parse_shape, np.logical_not(parse_mask))
- parse_mask = parse_mask.cpu().numpy()
-
- width = size[0]
- height = size[1]
-
- im_arms = Image.new('L', (width, height))
- arms_draw = ImageDraw.Draw(im_arms)
- if category == 'dresses' or category == 'upper_body':
- shoulder_right = tuple(np.multiply(pose_data[2, :2], height / 512.0))
- shoulder_left = tuple(np.multiply(pose_data[5, :2], height / 512.0))
- elbow_right = tuple(np.multiply(pose_data[3, :2], height / 512.0))
- elbow_left = tuple(np.multiply(pose_data[6, :2], height / 512.0))
- wrist_right = tuple(np.multiply(pose_data[4, :2], height / 512.0))
- wrist_left = tuple(np.multiply(pose_data[7, :2], height / 512.0))
- if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
- if elbow_right[0] <= 1. and elbow_right[1] <= 1.:
- arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right], 'white', 30, 'curve')
- else:
- arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right], 'white', 30,
- 'curve')
- elif wrist_left[0] <= 1. and wrist_left[1] <= 1.:
- if elbow_left[0] <= 1. and elbow_left[1] <= 1.:
- arms_draw.line([shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', 30, 'curve')
- else:
- arms_draw.line([elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', 30,
- 'curve')
- else:
- arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right], 'white',
- 30, 'curve')
-
- if height > 512:
- im_arms = cv2.dilate(np.float32(im_arms), np.ones((10, 10), np.uint16), iterations=5)
- elif height > 256:
- im_arms = cv2.dilate(np.float32(im_arms), np.ones((5, 5), np.uint16), iterations=5)
- hands = np.logical_and(np.logical_not(im_arms), arms)
- parse_mask += im_arms
- parser_mask_fixed += hands
-
- # delete neck
- parse_head_2 = torch.clone(parse_head)
- if category == 'dresses' or category == 'upper_body':
- points = []
- points.append(np.multiply(pose_data[2, :2], height / 512.0))
- points.append(np.multiply(pose_data[5, :2], height / 512.0))
- x_coords, y_coords = zip(*points)
- A = np.vstack([x_coords, np.ones(len(x_coords))]).T
- m, c = lstsq(A, y_coords, rcond=None)[0]
- for i in range(parse_array.shape[1]):
- y = i * m + c
- parse_head_2[int(y - 20 * (height / 512.0)):, i] = 0
-
- parser_mask_fixed = np.logical_or(parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16))
- parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16),
- np.logical_not(np.array(parse_head_2, dtype=np.uint16))))
-
- if height > 512:
- parse_mask = cv2.dilate(parse_mask, np.ones((20, 20), np.uint16), iterations=5)
- elif height > 256:
- parse_mask = cv2.dilate(parse_mask, np.ones((10, 10), np.uint16), iterations=5)
- else:
- parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5)
- parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
- parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
- agnostic_mask = parse_mask_total.unsqueeze(0)
- return agnostic_mask
-
-
-
-
-def main():
- args = parse_args()
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir)
- accelerator = Accelerator(
- mixed_precision=args.mixed_precision,
- project_config=accelerator_project_config,
- )
- if accelerator.is_local_main_process:
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
- # If passed along, set the training seed now.
- if args.seed is not None:
- set_seed(args.seed)
-
- # Handle the repository creation
- if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
- weight_dtype = torch.float16
- # if accelerator.mixed_precision == "fp16":
- # weight_dtype = torch.float16
- # args.mixed_precision = accelerator.mixed_precision
- # elif accelerator.mixed_precision == "bf16":
- # weight_dtype = torch.bfloat16
- # args.mixed_precision = accelerator.mixed_precision
-
- # Load scheduler, tokenizer and models.
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="vae",
- torch_dtype=torch.float16,
- )
- unet = UNet2DConditionModel.from_pretrained(
- "yisol/IDM-VTON-DC",
- subfolder="unet",
- torch_dtype=torch.float16,
- )
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="image_encoder",
- torch_dtype=torch.float16,
- )
- UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="unet_encoder",
- torch_dtype=torch.float16,
- )
- text_encoder_one = CLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder",
- torch_dtype=torch.float16,
- )
- text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder_2",
- torch_dtype=torch.float16,
- )
- tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=None,
- use_fast=False,
- )
- tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=None,
- use_fast=False,
- )
-
-
- # Freeze vae and text_encoder and set unet to trainable
- unet.requires_grad_(False)
- vae.requires_grad_(False)
- image_encoder.requires_grad_(False)
- UNet_Encoder.requires_grad_(False)
- text_encoder_one.requires_grad_(False)
- text_encoder_two.requires_grad_(False)
- UNet_Encoder.to(accelerator.device, weight_dtype)
- unet.eval()
- UNet_Encoder.eval()
-
-
-
- if args.enable_xformers_memory_efficient_attention:
- if is_xformers_available():
- import xformers
-
- xformers_version = version.parse(xformers.__version__)
- if xformers_version == version.parse("0.0.16"):
- logger.warn(
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
- )
- unet.enable_xformers_memory_efficient_attention()
- else:
- raise ValueError("xformers is not available. Make sure it is installed correctly")
-
- test_dataset = DresscodeTestDataset(
- dataroot_path=args.data_dir,
- phase="test",
- order="unpaired" if args.unpaired else "paired",
- category = args.category,
- size=(args.height, args.width),
- )
- test_dataloader = torch.utils.data.DataLoader(
- test_dataset,
- shuffle=False,
- batch_size=args.test_batch_size,
- num_workers=4,
- )
-
- pipe = TryonPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unet,
- vae=vae,
- feature_extractor= CLIPImageProcessor(),
- text_encoder = text_encoder_one,
- text_encoder_2 = text_encoder_two,
- tokenizer = tokenizer_one,
- tokenizer_2 = tokenizer_two,
- scheduler = noise_scheduler,
- image_encoder=image_encoder,
- torch_dtype=torch.float16,
- ).to(accelerator.device)
- pipe.unet_encoder = UNet_Encoder
-
- # pipe.enable_sequential_cpu_offload()
- # pipe.enable_model_cpu_offload()
- # pipe.enable_vae_slicing()
-
-
-
- with torch.no_grad():
- # Extract the images
- with torch.cuda.amp.autocast():
- with torch.no_grad():
- for sample in test_dataloader:
- img_emb_list = []
- for i in range(sample['cloth'].shape[0]):
- img_emb_list.append(sample['cloth'][i])
-
- prompt = sample["caption"]
-
- num_prompts = sample['cloth'].shape[0]
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- image_embeds = torch.cat(img_emb_list,dim=0)
-
- with torch.inference_mode():
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = pipe.encode_prompt(
- prompt,
- num_images_per_prompt=1,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
-
-
- prompt = sample["caption_cloth"]
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
-
- with torch.inference_mode():
- (
- prompt_embeds_c,
- _,
- _,
- _,
- ) = pipe.encode_prompt(
- prompt,
- num_images_per_prompt=1,
- do_classifier_free_guidance=False,
- negative_prompt=negative_prompt,
- )
-
-
-
- generator = torch.Generator(pipe.device).manual_seed(args.seed) if args.seed is not None else None
- images = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=args.num_inference_steps,
- generator=generator,
- strength = 1.0,
- pose_img = sample['pose_img'],
- text_embeds_cloth=prompt_embeds_c,
- cloth = sample["cloth_pure"].to(accelerator.device),
- mask_image=sample['inpaint_mask'],
- image=(sample['image']+1.0)/2.0,
- height=args.height,
- width=args.width,
- guidance_scale=args.guidance_scale,
- ip_adapter_image = image_embeds,
- )[0]
-
-
- for i in range(len(images)):
- x_sample = pil_to_tensor(images[i])
- torchvision.utils.save_image(x_sample,os.path.join(args.output_dir,sample['im_name'][i]))
-
-
-
-
-if __name__ == "__main__":
- main()
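Note: the file deleted above was the standalone DressCode evaluation entry point. For reference, the sketch below condenses its per-batch inference loop; it assumes `pipe` (the TryonPipeline), `test_dataloader`, and `args` are constructed exactly as in the removed code, and it is an illustration rather than a drop-in replacement for the script.

```python
# Condensed sketch of the per-batch inference loop in the removed DressCode
# evaluation script. `pipe`, `test_dataloader`, and `args` are assumed to be
# built exactly as in the deleted code above.
import os
import torch
import torchvision
from torchvision.transforms.functional import pil_to_tensor

NEGATIVE = "monochrome, lowres, bad anatomy, worst quality, low quality"

def as_list(prompt, n):
    # Captions may arrive as a single string or as a list per batch.
    return prompt if isinstance(prompt, list) else [prompt] * n

with torch.no_grad(), torch.cuda.amp.autocast():
    for sample in test_dataloader:
        n = sample["cloth"].shape[0]
        image_embeds = torch.cat([sample["cloth"][i] for i in range(n)], dim=0)

        # Person caption -> classifier-free-guidance embeddings for the try-on UNet.
        prompt_embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
            as_list(sample["caption"], n),
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=as_list(NEGATIVE, n),
        )
        # Garment caption -> conditioning for the garment (reference UNet) branch.
        prompt_embeds_c, _, _, _ = pipe.encode_prompt(
            as_list(sample["caption_cloth"], n),
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt=as_list(NEGATIVE, n),
        )

        generator = (
            torch.Generator(pipe.device).manual_seed(args.seed) if args.seed is not None else None
        )
        images = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=neg_embeds,
            pooled_prompt_embeds=pooled,
            negative_pooled_prompt_embeds=neg_pooled,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
            strength=1.0,
            pose_img=sample["pose_img"],
            text_embeds_cloth=prompt_embeds_c,
            cloth=sample["cloth_pure"].to(pipe.device),
            mask_image=sample["inpaint_mask"],
            image=(sample["image"] + 1.0) / 2.0,
            height=args.height,
            width=args.width,
            guidance_scale=args.guidance_scale,
            ip_adapter_image=image_embeds,
        )[0]

        for i, img in enumerate(images):
            torchvision.utils.save_image(
                pil_to_tensor(img), os.path.join(args.output_dir, sample["im_name"][i])
            )
```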
diff --git a/ip_adapter/__init__.py b/ip_adapter/__init__.py
index b275952105f50616770a83609ee1eada68bffd90..3b1f1ff4e54e93ada7e85abc0f6687c5ecd3a338 100644
--- a/ip_adapter/__init__.py
+++ b/ip_adapter/__init__.py
@@ -1,4 +1,4 @@
-from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull,IPAdapterPlus_Lora,IPAdapterPlus_Lora_up
+from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull

__all__ = [
"IPAdapter",
@@ -6,6 +6,4 @@ __all__ = [
"IPAdapterPlusXL",
"IPAdapterXL",
"IPAdapterFull",
- "IPAdapterPlus_Lora",
- 'IPAdapterPlus_Lora_up',
]
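After this change the package root re-exports only the five core adapter classes; the two LoRA variants dropped above are no longer part of the public surface. A minimal sanity check, assuming the repository root is on the import path and `ip_adapter/ip_adapter.py` still defines these classes:

```python
# Importing the removed LoRA variants would now raise ImportError here.
from ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull

print([cls.__name__ for cls in (IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull)])
```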
diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py
index b6c40c41570a0ce05a4a373645747f3ee29275e7..07fd6d8e2fc622b1e5a2bdd52119566b8b5219ef 100644
--- a/ip_adapter/attention_processor.py
+++ b/ip_adapter/attention_processor.py
@@ -2,11 +2,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from diffusers.models.lora import LoRACompatibleLinear
-from diffusers.models.lora import LoRALinearLayer,LoRAConv2dLayer
-from einops import rearrange
-from diffusers.models.transformer_2d import Transformer2DModel


class AttnProcessor(nn.Module):
r"""
@@ -97,1673 +93,6 @@ class IPAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
super().__init__()
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- else:
- # get encoder_hidden_states, ip_hidden_states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- query = attn.head_to_batch_dim(query)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
-
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class AttnProcessor2_0(torch.nn.Module):
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(
- self,
- hidden_size=None,
- cross_attention_dim=None,
- ):
- super().__init__()
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- scale= 1.0,
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- # args = (scale, )
- args = ()
- query = attn.to_q(hidden_states, *args)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states, *args)
- value = attn.to_v(encoder_hidden_states, *args)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class AttnProcessor2_0_attn(torch.nn.Module):
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(
- self,
- hidden_size=None,
- cross_attention_dim=None,
- ):
- super().__init__()
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- is_cloth_pass=False,
- cloth = None,
- up_cnt=None,
- mid_cnt=None,
- down_cnt=None,
- inside_up=None,
- inside_down=None,
- cloth_text=None,
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class AttnProcessor2_0_Lora(torch.nn.Module):
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(
- self,
- scale_lora =1.0,
- hidden_size=None,
- cross_attention_dim=None,
- ):
- super().__init__()
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- self.scale_lora = scale_lora
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- if hasattr(attn,'q_lora'):
- query = attn.to_q(hidden_states)
- q_lora = attn.q_lora(hidden_states)
- query = query + self.scale_lora * q_lora
- else:
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-
- if hasattr(attn,'k_lora'):
- key = attn.to_k(hidden_states)
- k_lora = attn.k_lora(hidden_states)
- key = key + self.scale_lora * k_lora
- else:
- key = attn.to_k(hidden_states)
-
- if hasattr(attn,'v_lora'):
- value = attn.to_v(encoder_hidden_states)
- v_lora = attn.v_lora(hidden_states)
- value = value + self.scale_lora * v_lora
- else:
- value = attn.to_v(encoder_hidden_states)
-
-
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
-
-
- if hasattr(attn,'out_lora'):
- hidden_states = attn.to_out[0](hidden_states)
- out_lora = attn.out_lora(hidden_states)
- hidden_states = hidden_states+ self.scale_lora*out_lora
- else:
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class IPAttnProcessor_clothpass_noip(torch.nn.Module):
- r"""
- Attention processor for IP-Adapater for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- is_cloth_pass=False,
- cloth = None,
- up_cnt=None,
- mid_cnt=None,
- down_cnt=None,
- inside=None,
- ):
-
- if is_cloth_pass or up_cnt is None:
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
- else:
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- if attention_mask is not None:
- print('!!!!attention_mask is not NoNE!!!!')
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # for ip-adapter
- # print(up_cnt*3 + inside)
- cloth_feature = cloth[up_cnt*3 + inside-1]
- cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous()
- # print(cloth_feature.shape)
- # print(self.hidden_size)
- c_key = self.to_k_c(cloth_feature)
- c_value = self.to_v_c(cloth_feature)
-
-
- c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # print(ip_value.shape)
- #$$ attn_mask?
- hidden_states_cloth = F.scaled_dot_product_attention(
- query, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- hidden_states_cloth = hidden_states_cloth.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states_cloth = hidden_states_cloth.to(query.dtype)
-
- hidden_states = hidden_states + self.scale * hidden_states_cloth
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
- return hidden_states
-
-
-class IPAttnProcessor_clothpass(torch.nn.Module):
- r"""
- Attention processor for IP-Adapater for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- is_cloth_pass=False,
- cloth = None,
- up_cnt=None,
- mid_cnt=None,
- down_cnt=None,
- inside=None,
- ):
-
- if is_cloth_pass :
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
- elif up_cnt is None:
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- else:
- # get encoder_hidden_states, ip_hidden_states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-
-
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
- else:
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- if attention_mask is not None:
- print('!!!!attention_mask is not NoNE!!!!')
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- else:
- # get encoder_hidden_states, ip_hidden_states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # for ip-adapter
- # print(up_cnt*3 + inside)
- cloth_feature = cloth[up_cnt*3 + inside-1]
- cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous()
-
-
-
-
-
-
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-
-
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
-
- c_key = self.to_k_c(cloth_feature)
- c_value = self.to_v_c(cloth_feature)
-
- c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-
-
- ip_hidden_states = F.scaled_dot_product_attention(
- ip_hidden_states, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
- return hidden_states
-
-
-
-
-
-
-class IPAttnProcessor_clothpass_extend(torch.nn.Module):
- r"""
- Attention processor for IP-Adapater for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- is_cloth_pass=False,
- cloth = None,
- up_cnt=None,
- mid_cnt=None,
- down_cnt=None,
- inside_up=None,
- inside_down=None,
- ):
- if is_cloth_pass :
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
- # elif up_cnt is None or down_cnt is None:
- # residual = hidden_states
-
- # if attn.spatial_norm is not None:
- # hidden_states = attn.spatial_norm(hidden_states, temb)
-
- # input_ndim = hidden_states.ndim
- # if input_ndim == 4:
- # batch_size, channel, height, width = hidden_states.shape
- # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- # batch_size, sequence_length, _ = (
- # hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- # )
-
- # if attention_mask is not None:
- # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # # scaled_dot_product_attention expects attention_mask shape to be
- # # (batch, heads, source_length, target_length)
- # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- # if attn.group_norm is not None:
- # hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- # query = attn.to_q(hidden_states)
- # if encoder_hidden_states is None:
- # encoder_hidden_states = hidden_states
- # else:
- # # get encoder_hidden_states, ip_hidden_states
- # end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- # encoder_hidden_states, ip_hidden_states = (
- # encoder_hidden_states[:, :end_pos, :],
- # encoder_hidden_states[:, end_pos:, :],
- # )
- # if attn.norm_cross:
- # encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- # key = attn.to_k(encoder_hidden_states)
- # value = attn.to_v(encoder_hidden_states)
-
- # inner_dim = key.shape[-1]
- # head_dim = inner_dim // attn.heads
-
- # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # # TODO: add support for attn.scale when we move to Torch 2.1
- # hidden_states = F.scaled_dot_product_attention(
- # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- # )
-
- # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- # hidden_states = hidden_states.to(query.dtype)
-
- # # for ip-adapter
- # ip_key = self.to_k_ip(ip_hidden_states)
- # ip_value = self.to_v_ip(ip_hidden_states)
-
- # ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- # ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-
-
-
- # # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # # TODO: add support for attn.scale when we move to Torch 2.1
- # ip_hidden_states = F.scaled_dot_product_attention(
- # query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- # )
-
- # ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- # ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-
- # hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # # linear proj
- # hidden_states = attn.to_out[0](hidden_states)
- # # dropout
- # hidden_states = attn.to_out[1](hidden_states)
-
- # if input_ndim == 4:
- # hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- # if attn.residual_connection:
- # hidden_states = hidden_states + residual
-
- # hidden_states = hidden_states / attn.rescale_output_factor
-
- # return hidden_states
- elif down_cnt is not None or up_cnt is not None or mid_cnt is not None:
- residual = hidden_states
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- print('!!!!attention_mask is not NoNE!!!!')
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- else:
- # get encoder_hidden_states, ip_hidden_states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # # for ip-adapter
- # print(self.hidden_size)
- # # print(up_cnt*3 + inside)
- # print(inside_down)
- cloth_feature = cloth[inside_down]
- # print(cloth_feature.shape)
- # if down_cnt is not None:
- # # print("up_index")
- # cloth_feature = cloth[down_cnt*3 + inside_down+1]
- # # print(up_cnt*3 + inside_up)
- # elif mid_cnt is not None:
- # cloth_feature = cloth[9]
- # else:
- # cloth_feature = cloth[11+up_cnt*3 + inside_up]
- # print("down_index")
- # print(down_cnt*3 + inside_down)
-
- cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous()
- # print(cloth_feature.shape)
- # print(self.hidden_size)
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-
-
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
-
- c_key = self.to_k_c(cloth_feature)
- c_value = self.to_v_c(cloth_feature)
-
- c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-
-
- ip_hidden_states = F.scaled_dot_product_attention(
- ip_hidden_states, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
- return hidden_states
- else:
- assert(False)
-
-
-
-class IPAttnProcessorMulti2_0_2(torch.nn.Module):
- r"""
- Attention processor for IP-Adapater for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- is_cloth_pass=False,
- cloth = None,
- up_cnt=None,
- mid_cnt=None,
- down_cnt=None,
- inside=None,
- cloth_text=None,
- ):
-
- if is_cloth_pass or up_cnt is None:
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- else:
- # get encoder_hidden_states, ip_hidden_states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-
-
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
- else:
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- # print(up_cnt)
- # print("hidden_states.shape")
- # print(hidden_states.shape)
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- else:
- # get encoder_hidden_states, ip_hidden_states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # for ip-adapter
- # print(up_cnt*3 + inside)
- cloth_feature = cloth[up_cnt*3 + inside-1]
- cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous()
- # print(cloth_feature.shape)
- # print(self.hidden_size)
-
- # print("cloth_feature.shape")
- # print(cloth_feature.shape)
- query_cloth = self.q_additional(cloth_feature)
- query_cloth = query_cloth.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_key = self.to_k_ip(cloth_text)
- ip_value = self.to_v_ip(cloth_text)
-
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # print(ip_value.shape)
- #$$ attn_mask?
- hidden_states_cloth = F.scaled_dot_product_attention(
- query_cloth, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- hidden_states_cloth = hidden_states_cloth.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states_cloth = hidden_states_cloth.to(query.dtype)
-
- ip_key = self.k_additional(hidden_states_cloth)
- ip_value = self.v_additional(hidden_states_cloth)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
- return hidden_states
-
-
-
-
-
-
-
-
-
-class IPAttnProcessor2_0_paint(torch.nn.Module):
- r"""
- Attention processor for IP-Adapater for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- if cross_attention_dim==None:
- print("cross_attention_dim is none")
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- ):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
-
- # #######
-
- # # for ip-adapter
- ip_key = self.to_k_ip(encoder_hidden_states)
- ip_value = self.to_v_ip(encoder_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-
-
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
-
- # #######
-
-
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-
-
-
-class IPAttnProcessor2_0_variant(torch.nn.Module):
- r"""
- Attention processor for IP-Adapter for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
@@ -1795,12 +124,7 @@ class IPAttnProcessor2_0_variant(torch.nn.Module):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -1822,42 +146,27 @@ class IPAttnProcessor2_0_variant(torch.nn.Module):
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
- # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- # hidden_states = hidden_states.to(query.dtype)
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- hidden_states, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- with torch.no_grad():
- self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ self.attn_map = ip_attention_probs
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
- hidden_states = ip_hidden_states
+ hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@@ -1875,177 +184,20 @@ class IPAttnProcessor2_0_variant(torch.nn.Module):
return hidden_states
-
-class IPAttnProcessor2_0(torch.nn.Module):
+class AttnProcessor2_0(torch.nn.Module):
r"""
- Attention processor for IP-Adapter for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
- The context length of the image features.
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
- def __call__(
+ def __init__(
self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- scale=1.0
+ hidden_size=None,
+ cross_attention_dim=None,
):
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- # args = (scale, )
- args = ()
-
- query = attn.to_q(hidden_states, *args)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- else:
- # get encoder_hidden_states, ip_hidden_states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states, *args)
- value = attn.to_v(encoder_hidden_states, *args)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- with torch.no_grad():
- self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-
-
-
-
-
-
-class IPAttnProcessor_referencenet_2_0(torch.nn.Module):
- r"""
- Attention processor for IP-Adapter for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4,attn_head_dim=10):
super().__init__()
-
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
- self.attn_head_dim=attn_head_dim
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-
def __call__(
self,
attn,
@@ -2082,15 +234,8 @@ class IPAttnProcessor_referencenet_2_0(torch.nn.Module):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
- else:
- # get encoder_hidden_states, ip_hidden_states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
- if attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -2112,26 +257,6 @@ class IPAttnProcessor_referencenet_2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- with torch.no_grad():
- self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
@@ -2148,8 +273,7 @@ class IPAttnProcessor_referencenet_2_0(torch.nn.Module):
return hidden_states
-
-class IPAttnProcessor2_0_Lora(torch.nn.Module):
+class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapter for PyTorch 2.0.
Args:
@@ -2163,7 +287,7 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module):
The context length of the image features.
"""
- def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, scale_lora=1.0, rank = 4,num_tokens=4):
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
@@ -2172,14 +296,10 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module):
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
- self.scale_lora = scale_lora
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_k_ip_lora = LoRALinearLayer(in_features=self.to_k_ip.in_features, out_features=self.to_k_ip.out_features, rank=rank)
- self.to_v_ip_lora =LoRALinearLayer(in_features=self.to_v_ip.in_features, out_features=self.to_v_ip.out_features, rank=rank)
-
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
@@ -2213,12 +333,7 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module):
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- if hasattr(attn,'q_lora'):
- query = attn.to_q(hidden_states)
- q_lora = attn.q_lora(hidden_states)
- query = query + self.scale_lora * q_lora
- else:
- query = attn.to_q(hidden_states)
+ query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
@@ -2232,21 +347,8 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module):
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- if hasattr(attn,'k_lora'):
- key = attn.to_k(encoder_hidden_states)
- k_lora = attn.k_lora(encoder_hidden_states)
- key = key + self.scale_lora * k_lora
- else:
- key = attn.to_k(encoder_hidden_states)
-
- if hasattr(attn,'v_lora'):
- value = attn.to_v(encoder_hidden_states)
- v_lora = attn.v_lora(encoder_hidden_states)
- value = value + self.scale_lora * v_lora
- else:
- value = attn.to_v(encoder_hidden_states)
-
-
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -2266,14 +368,8 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module):
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
-
ip_key = self.to_k_ip(ip_hidden_states)
- ip_key_lora = self.to_k_ip_lora(ip_hidden_states)
- ip_key = ip_key + self.scale_lora * ip_key_lora
ip_value = self.to_v_ip(ip_hidden_states)
- ip_value_lora = self.to_v_ip_lora(ip_hidden_states)
- ip_value = ip_value + self.scale_lora * ip_value_lora
-
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -2283,6 +379,9 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module):
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
+ with torch.no_grad():
+ self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)  # store attention probabilities over the image tokens
+ #print(self.attn_map.shape)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
@@ -2290,14 +389,7 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module):
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
-
- if hasattr(attn,'out_lora'):
- hidden_states = attn.to_out[0](hidden_states)
- out_lora = attn.out_lora(hidden_states)
- hidden_states = hidden_states+ self.scale_lora*out_lora
- else:
- hidden_states = attn.to_out[0](hidden_states)
-
+ hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
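
For orientation, the processors retained in attention_processor.py all share one fusion step: the trailing `num_tokens` entries of `encoder_hidden_states` are the projected image-prompt tokens, they are attended to through the separate `to_k_ip`/`to_v_ip` projections, and the result is blended into the ordinary text cross-attention output with `self.scale`. The sketch below is a minimal, self-contained illustration of that step; the function and tensor names are illustrative and not part of the repository's API.

```python
# Minimal sketch of the IP-Adapter fusion step (illustrative names, not the repo's API).
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_heads(x, heads):
    # (b, seq, heads*head_dim) -> (b, heads, seq, head_dim)
    b, s, d = x.shape
    return x.view(b, s, heads, d // heads).transpose(1, 2)


def merge_heads(x):
    # (b, heads, seq, head_dim) -> (b, seq, heads*head_dim)
    b, h, s, d = x.shape
    return x.transpose(1, 2).reshape(b, s, h * d)


def ip_fusion(hidden_states, encoder_hidden_states, to_q, to_k, to_v,
              to_k_ip, to_v_ip, num_image_tokens=4, heads=8, scale=1.0):
    # the last `num_image_tokens` context entries are the projected image-prompt tokens
    text_ctx = encoder_hidden_states[:, :-num_image_tokens, :]
    image_ctx = encoder_hidden_states[:, -num_image_tokens:, :]

    q = split_heads(to_q(hidden_states), heads)

    # text branch: ordinary cross-attention
    k, v = split_heads(to_k(text_ctx), heads), split_heads(to_v(text_ctx), heads)
    text_out = merge_heads(F.scaled_dot_product_attention(q, k, v))

    # image branch: separate key/value projections trained for the image prompt
    ip_k = split_heads(to_k_ip(image_ctx), heads)
    ip_v = split_heads(to_v_ip(image_ctx), heads)
    ip_out = merge_heads(F.scaled_dot_product_attention(q, ip_k, ip_v))

    # blend the two attention outputs with the image-prompt weight
    return text_out + scale * ip_out


if __name__ == "__main__":
    d = 64
    proj = lambda: nn.Linear(d, d, bias=False)
    latents = torch.randn(2, 16, d)            # spatial tokens
    context = torch.randn(2, 77 + 4, d)        # text tokens followed by 4 image tokens
    out = ip_fusion(latents, context, proj(), proj(), proj(), proj(), proj())
    print(out.shape)                           # torch.Size([2, 16, 64])
```
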
diff --git a/ip_adapter/attention_processor_faceid.py b/ip_adapter/attention_processor_faceid.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc1ca1e5f45e1da21be543f8d28d5de2bb0eba9
--- /dev/null
+++ b/ip_adapter/attention_processor_faceid.py
@@ -0,0 +1,427 @@
+# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.models.lora import LoRALinearLayer
+
+
+class LoRAAttnProcessor(nn.Module):
+ r"""
+ Processor for performing attention-related computations with additional LoRA layers on the query, key, value, and output projections.
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ rank=4,
+ network_alpha=None,
+ lora_scale=1.0,
+ ):
+ super().__init__()
+
+ self.rank = rank
+ self.lora_scale = lora_scale
+
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LoRAIPAttnProcessor(nn.Module):
+ r"""
+ Attention processor for IP-Adapter.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ The weight scale of the image prompt.
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+ The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ self.rank = rank
+ self.lora_scale = lora_scale
+
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ self.attn_map = ip_attention_probs
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LoRAAttnProcessor2_0(nn.Module):
+
+ r"""
+ Processor for performing scaled dot-product attention with additional LoRA layers on the query, key, value, and output projections.
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ rank=4,
+ network_alpha=None,
+ lora_scale=1.0,
+ ):
+ super().__init__()
+
+ self.rank = rank
+ self.lora_scale = lora_scale
+
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LoRAIPAttnProcessor2_0(nn.Module):
+ r"""
+ Processor for implementing the LoRA attention mechanism together with IP-Adapter image-prompt attention.
+
+ Args:
+ hidden_size (`int`, *optional*):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the `encoder_hidden_states`.
+ rank (`int`, defaults to 4):
+ The dimension of the LoRA update matrices.
+ network_alpha (`int`, *optional*):
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ self.rank = rank
+ self.lora_scale = lora_scale
+ self.num_tokens = num_tokens
+
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
+ #query = attn.head_to_batch_dim(query)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ # for text
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # for ip
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
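
Every processor in this new file augments the frozen attention projections with a low-rank update, computing for example `attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)`. The sketch below illustrates that pattern; `LowRankAdapter` is an illustrative stand-in, not diffusers' `LoRALinearLayer`, and the `alpha / rank` scaling is an assumption stated in the comments.

```python
# Minimal sketch of the LoRA-augmented projection pattern used throughout this file.
import torch
import torch.nn as nn


class LowRankAdapter(nn.Module):
    """Illustrative low-rank update added alongside a frozen linear projection."""

    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # start as a no-op so the adapted model initially matches the frozen one
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)
        # assumed Kohya-style rescaling: alpha / rank when an alpha is given
        self.scaling = (network_alpha / rank) if network_alpha is not None else 1.0

    def forward(self, x):
        return self.up(self.down(x)) * self.scaling


if __name__ == "__main__":
    d = 64
    to_q = nn.Linear(d, d)                        # frozen base projection
    to_q_lora = LowRankAdapter(d, d, rank=4, network_alpha=4)
    lora_scale = 1.0
    x = torch.randn(2, 16, d)
    query = to_q(x) + lora_scale * to_q_lora(x)   # same pattern as the processors above
    print(query.shape)                            # torch.Size([2, 16, 64])
```
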
diff --git a/ip_adapter/custom_pipelines.py b/ip_adapter/custom_pipelines.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d43d2c34db9b83f6148fac53425a9fd4c60fc93
--- /dev/null
+++ b/ip_adapter/custom_pipelines.py
@@ -0,0 +1,394 @@
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from diffusers import StableDiffusionXLPipeline
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+
+from .utils import is_torch2_available
+
+if is_torch2_available():
+ from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
+else:
+ from .attention_processor import IPAttnProcessor
+
+
+class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):
+ def set_scale(self, scale):
+ for attn_processor in self.unet.attn_processors.values():
+ if isinstance(attn_processor, IPAttnProcessor):
+ attn_processor.scale = scale
+
+ @torch.no_grad()
+ def __call__( # noqa: C901
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ control_guidance_start: float = 0.0,
+ control_guidance_end: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be the same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ control_guidance_start (`float`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 7.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # get init conditioning scale
+ for attn_processor in self.unet.attn_processors.values():
+ if isinstance(attn_processor, IPAttnProcessor):
+ conditioning_scale = attn_processor.scale
+ break
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end):
+ self.set_scale(0.0)
+ else:
+ self.set_scale(conditioning_scale)
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if output_type != "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
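
The main addition in this custom pipeline over the stock SDXL pipeline is the step-window gating of the IP-Adapter scale: outside the `[control_guidance_start, control_guidance_end]` fraction of the schedule, `set_scale(0.0)` disables the image prompt, and inside it the initial `conditioning_scale` is restored. A minimal sketch of that gating, plus the classifier-free guidance combine used in the same loop, follows; the helper names are illustrative, not the pipeline's API.

```python
# Minimal sketch of the step-window gating and CFG combine from the denoising loop above.

def ip_scale_for_step(i, num_steps, conditioning_scale,
                      control_guidance_start=0.0, control_guidance_end=1.0):
    """Image-prompt scale to apply at step i, mirroring the gating in the loop."""
    if (i / num_steps < control_guidance_start) or ((i + 1) / num_steps > control_guidance_end):
        return 0.0                      # image prompt disabled for this step
    return conditioning_scale


def cfg_combine(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance combine applied to the chunked noise prediction."""
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


if __name__ == "__main__":
    # with a 30-step schedule and a window of [0.0, 0.5], only the first half of the
    # steps keep the image prompt active
    scales = [ip_scale_for_step(i, 30, 1.0, control_guidance_end=0.5) for i in range(30)]
    print(sum(s > 0 for s in scales))   # 15
    print(cfg_combine(0.1, 0.5, 5.0))   # 2.1
```
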
diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py
index d092feb1878aa29829af968091c61990002898ed..dcb8824aebae0554ee51b363d56a40d022edacdc 100644
--- a/ip_adapter/ip_adapter.py
+++ b/ip_adapter/ip_adapter.py
@@ -8,7 +8,7 @@ from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-from .utils import is_torch2_available
+from .utils import is_torch2_available, get_generator
if is_torch2_available():
from .attention_processor import (
@@ -20,11 +20,9 @@ if is_torch2_available():
from .attention_processor import (
IPAttnProcessor2_0 as IPAttnProcessor,
)
- from .attention_processor import IPAttnProcessor2_0_Lora
-# else:
-# from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
+else:
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler
-from diffusers.models.lora import LoRALinearLayer
class ImageProjModel(torch.nn.Module):
@@ -33,6 +31,7 @@ class ImageProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
+ self.generator = None
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
@@ -123,49 +122,19 @@ class IPAdapter:
self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
def load_ip_adapter(self):
- if self.ip_ckpt is not None:
- if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
- state_dict = {"image_proj": {}, "ip_adapter": {}}
- with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
- for key in f.keys():
- if key.startswith("image_proj."):
- state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
- elif key.startswith("ip_adapter."):
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
- else:
- state_dict = torch.load(self.ip_ckpt, map_location="cpu")
- self.image_proj_model.load_state_dict(state_dict["image_proj"])
- ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
- ip_layers.load_state_dict(state_dict["ip_adapter"])
-
-
- # def load_ip_adapter(self):
- # if self.ip_ckpt is not None:
- # if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
- # state_dict = {"image_proj_model": {}, "ip_adapter": {}}
- # with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
- # for key in f.keys():
- # if key.startswith("image_proj_model."):
- # state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = f.get_tensor(key)
- # elif key.startswith("ip_adapter."):
- # state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
- # else:
- # state_dict = torch.load(self.ip_ckpt, map_location="cpu")
-
- # tmp1 = {}
- # for k,v in state_dict.items():
- # if 'image_proj_model' in k:
- # tmp1[k.replace('image_proj_model.','')] = v
- # self.image_proj_model.load_state_dict(tmp1, strict=True)
- # # ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
- # tmp2 = {}
- # for k,v in state_dict.ites():
- # if 'adapter_mode' in k:
- # tmp1[k] = v
-
- # print(ip_layers.state_dict())
- # ip_layers.load_state_dict(state_dict,strict=False)
-
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
@@ -180,19 +149,6 @@ class IPAdapter:
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
- def get_image_embeds_train(self, pil_image=None, clip_image_embeds=None):
- if pil_image is not None:
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds
- else:
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32)
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
- uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
- return image_prompt_embeds, uncond_image_prompt_embeds
-
-
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
@@ -208,7 +164,7 @@ class IPAdapter:
num_samples=4,
seed=None,
guidance_scale=7.5,
- num_inference_steps=50,
+ num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
@@ -248,7 +204,8 @@ class IPAdapter:
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
- generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
+ generator = get_generator(seed, self.device)
+
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@@ -264,71 +221,6 @@ class IPAdapter:
class IPAdapterXL(IPAdapter):
"""SDXL"""
- def generate_test(
- self,
- pil_image,
- prompt=None,
- negative_prompt=None,
- scale=1.0,
- num_samples=4,
- seed=None,
- num_inference_steps=30,
- **kwargs,
- ):
- self.set_scale(scale)
-
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
-
- if prompt is None:
- prompt = "best quality, high quality"
- if negative_prompt is None:
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
-
- with torch.inference_mode():
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = self.pipe.encode_prompt(
- prompt,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
-
- generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
- images = self.pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=num_inference_steps,
- generator=generator,
- **kwargs,
- ).images
-
-
- # with torch.autocast("cuda"):
- # images = self.pipe(
- # prompt_embeds=prompt_embeds,
- # negative_prompt_embeds=negative_prompt_embeds,
- # pooled_prompt_embeds=pooled_prompt_embeds,
- # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- # num_inference_steps=num_inference_steps,
- # generator=generator,
- # **kwargs,
- # ).images
-
- return images
-
-
def generate(
self,
pil_image,
@@ -376,98 +268,24 @@ class IPAdapterXL(IPAdapter):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
- generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
+ self.generator = get_generator(seed, self.device)
+
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
- generator=generator,
+ generator=self.generator,
**kwargs,
).images
-
- # with torch.autocast("cuda"):
- # images = self.pipe(
- # prompt_embeds=prompt_embeds,
- # negative_prompt_embeds=negative_prompt_embeds,
- # pooled_prompt_embeds=pooled_prompt_embeds,
- # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- # num_inference_steps=num_inference_steps,
- # generator=generator,
- # **kwargs,
- # ).images
-
return images
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
- def generate(
- self,
- pil_image=None,
- clip_image_embeds=None,
- prompt=None,
- negative_prompt=None,
- scale=1.0,
- num_samples=4,
- seed=None,
- guidance_scale=7.5,
- num_inference_steps=50,
- **kwargs,
- ):
- self.set_scale(scale)
-
- if pil_image is not None:
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
- else:
- num_prompts = clip_image_embeds.size(0)
-
- if prompt is None:
- prompt = "best quality, high quality"
- if negative_prompt is None:
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
- pil_image=pil_image, clip_image=clip_image_embeds
- )
- bs_embed, seq_len, _ = image_prompt_embeds.shape
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
- with torch.inference_mode():
- prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
- prompt,
- device=self.device,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
- prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-
- generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
- images = self.pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- guidance_scale=guidance_scale,
- num_inference_steps=num_inference_steps,
- generator=generator,
- **kwargs,
- ).images
-
- return images
-
-
def init_proj(self):
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
@@ -482,269 +300,12 @@ class IPAdapterPlus(IPAdapter):
return image_proj_model
@torch.inference_mode()
- def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None):
- if pil_image is not None:
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- else:
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
- uncond_clip_image_embeds = self.image_encoder(
- torch.zeros_like(clip_image), output_hidden_states=True
- ).hidden_states[-2]
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
- return image_prompt_embeds, uncond_image_prompt_embeds
-
-
-
-
-class IPAdapterPlus_Lora(IPAdapter):
- """IP-Adapter with fine-grained features"""
-
- def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32):
- self.rank = rank
- super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens)
-
-
- def generate(
- self,
- pil_image=None,
- clip_image_embeds=None,
- prompt=None,
- negative_prompt=None,
- scale=1.0,
- num_samples=4,
- seed=None,
- guidance_scale=7.5,
- num_inference_steps=50,
- **kwargs,
- ):
- self.set_scale(scale)
-
- if pil_image is not None:
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
- else:
- num_prompts = clip_image_embeds.size(0)
-
- if prompt is None:
- prompt = "best quality, high quality"
- if negative_prompt is None:
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
- pil_image=pil_image, clip_image=clip_image_embeds
- )
- bs_embed, seq_len, _ = image_prompt_embeds.shape
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
- with torch.inference_mode():
- prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
- prompt,
- device=self.device,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
- prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-
- generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
- images = self.pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- guidance_scale=guidance_scale,
- num_inference_steps=num_inference_steps,
- generator=generator,
- **kwargs,
- ).images
-
- return images
-
-
- def init_proj(self):
- image_proj_model = Resampler(
- dim=self.pipe.unet.config.cross_attention_dim,
- depth=4,
- dim_head=64,
- heads=12,
- num_queries=self.num_tokens,
- embedding_dim=self.image_encoder.config.hidden_size,
- output_dim=self.pipe.unet.config.cross_attention_dim,
- ff_mult=4,
- ).to(self.device, dtype=torch.float16)
- return image_proj_model
-
- @torch.inference_mode()
- def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None):
- if pil_image is not None:
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- else:
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
- uncond_clip_image_embeds = self.image_encoder(
- torch.zeros_like(clip_image), output_hidden_states=True
- ).hidden_states[-2]
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
- return image_prompt_embeds, uncond_image_prompt_embeds
-
- def set_ip_adapter(self):
- unet = self.pipe.unet
- attn_procs = {}
- unet_sd = unet.state_dict()
-
- for attn_processor_name, attn_processor in unet.attn_processors.items():
- # Parse the attention module.
- cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if attn_processor_name.startswith("mid_block"):
- hidden_size = unet.config.block_out_channels[-1]
- elif attn_processor_name.startswith("up_blocks"):
- block_id = int(attn_processor_name[len("up_blocks.")])
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
- elif attn_processor_name.startswith("down_blocks"):
- block_id = int(attn_processor_name[len("down_blocks.")])
- hidden_size = unet.config.block_out_channels[block_id]
- if cross_attention_dim is None:
- attn_procs[attn_processor_name] = AttnProcessor()
- else:
- layer_name = attn_processor_name.split(".processor")[0]
- weights = {
- "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
- "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
- }
- attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens)
- attn_procs[attn_processor_name].load_state_dict(weights,strict=False)
-
- attn_module = unet
- for n in attn_processor_name.split(".")[:-1]:
- attn_module = getattr(attn_module, n)
-
- attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank)
- attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank)
- attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank)
- attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank)
-
- unet.set_attn_processor(attn_procs)
- if hasattr(self.pipe, "controlnet"):
- if isinstance(self.pipe.controlnet, MultiControlNetModel):
- for controlnet in self.pipe.controlnet.nets:
- controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
- else:
- self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
-
-
-
-class IPAdapterPlus_Lora_up(IPAdapter):
- """IP-Adapter with fine-grained features"""
-
- def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32):
- self.rank = rank
- super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens)
-
-
- def generate(
- self,
- pil_image=None,
- clip_image_embeds=None,
- prompt=None,
- negative_prompt=None,
- scale=1.0,
- num_samples=4,
- seed=None,
- guidance_scale=7.5,
- num_inference_steps=50,
- **kwargs,
- ):
- self.set_scale(scale)
-
- if pil_image is not None:
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
- else:
- num_prompts = clip_image_embeds.size(0)
-
- if prompt is None:
- prompt = "best quality, high quality"
- if negative_prompt is None:
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
- pil_image=pil_image, clip_image=clip_image_embeds
- )
- bs_embed, seq_len, _ = image_prompt_embeds.shape
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
- with torch.inference_mode():
- prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
- prompt,
- device=self.device,
- num_images_per_prompt=num_samples,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
- prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
- negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-
- generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
- images = self.pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- guidance_scale=guidance_scale,
- num_inference_steps=num_inference_steps,
- generator=generator,
- **kwargs,
- ).images
-
- return images
-
-
- def init_proj(self):
- image_proj_model = Resampler(
- dim=self.pipe.unet.config.cross_attention_dim,
- depth=4,
- dim_head=64,
- heads=12,
- num_queries=self.num_tokens,
- embedding_dim=self.image_encoder.config.hidden_size,
- output_dim=self.pipe.unet.config.cross_attention_dim,
- ff_mult=4,
- ).to(self.device, dtype=torch.float16)
- return image_proj_model
-
- @torch.inference_mode()
- def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None):
- if pil_image is not None:
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- else:
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
+ if isinstance(pil_image, Image.Image):
+ pil_image = [pil_image]
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
@@ -752,55 +313,6 @@ class IPAdapterPlus_Lora_up(IPAdapter):
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
- def set_ip_adapter(self):
- unet = self.pipe.unet
- attn_procs = {}
- unet_sd = unet.state_dict()
-
- for attn_processor_name, attn_processor in unet.attn_processors.items():
- # Parse the attention module.
- cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if attn_processor_name.startswith("mid_block"):
- hidden_size = unet.config.block_out_channels[-1]
- elif attn_processor_name.startswith("up_blocks"):
- block_id = int(attn_processor_name[len("up_blocks.")])
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
- elif attn_processor_name.startswith("down_blocks"):
- block_id = int(attn_processor_name[len("down_blocks.")])
- hidden_size = unet.config.block_out_channels[block_id]
- if cross_attention_dim is None:
- attn_procs[attn_processor_name] = AttnProcessor()
- else:
- layer_name = attn_processor_name.split(".processor")[0]
- weights = {
- "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
- "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
- }
- attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens)
- attn_procs[attn_processor_name].load_state_dict(weights,strict=False)
-
- attn_module = unet
- for n in attn_processor_name.split(".")[:-1]:
- attn_module = getattr(attn_module, n)
-
-
- if "up_blocks" in attn_processor_name:
- attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank)
- attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank)
- attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank)
- attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank)
-
-
-
- unet.set_attn_processor(attn_procs)
- if hasattr(self.pipe, "controlnet"):
- if isinstance(self.pipe.controlnet, MultiControlNetModel):
- for controlnet in self.pipe.controlnet.nets:
- controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
- else:
- self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
-
-
class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter with full features"""
@@ -830,15 +342,12 @@ class IPAdapterPlusXL(IPAdapter):
return image_proj_model
@torch.inference_mode()
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
- if pil_image is not None:
- if isinstance(pil_image, Image.Image):
- pil_image = [pil_image]
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image = clip_image.to(self.device, dtype=torch.float16)
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
- else:
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
+ def get_image_embeds(self, pil_image):
+ if isinstance(pil_image, Image.Image):
+ pil_image = [pil_image]
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
@@ -893,7 +402,8 @@ class IPAdapterPlusXL(IPAdapter):
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
- generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
+ generator = get_generator(seed, self.device)
+
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe98ad540648d429a3a227733c8607626fe0784e
--- /dev/null
+++ b/ip_adapter/ip_adapter_faceid.py
@@ -0,0 +1,542 @@
+import os
+from typing import List
+
+import torch
+from diffusers import StableDiffusionPipeline
+from diffusers.pipelines.controlnet import MultiControlNetModel
+from PIL import Image
+from safetensors import safe_open
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
+from .utils import is_torch2_available, get_generator
+
+USE_DAFAULT_ATTN = False  # set True to force the non-PyTorch-2.0 attention processors (needed for attention-map visualization)
+if is_torch2_available() and (not USE_DAFAULT_ATTN):
+ from .attention_processor_faceid import (
+ LoRAAttnProcessor2_0 as LoRAAttnProcessor,
+ )
+ from .attention_processor_faceid import (
+ LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor,
+ )
+else:
+ from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
+from .resampler import PerceiverAttention, FeedForward
+
+
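+# FacePerceiverResampler: a small stack of PerceiverAttention + FeedForward blocks.
+# The latent tokens repeatedly cross-attend to the projected image features (CLIP
+# hidden states) with residual connections, then are projected to `output_dim`
+# and layer-normalized.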
+class FacePerceiverResampler(torch.nn.Module):
+ def __init__(
+ self,
+ *,
+ dim=768,
+ depth=4,
+ dim_head=64,
+ heads=16,
+ embedding_dim=1280,
+ output_dim=768,
+ ff_mult=4,
+ ):
+ super().__init__()
+
+ self.proj_in = torch.nn.Linear(embedding_dim, dim)
+ self.proj_out = torch.nn.Linear(dim, output_dim)
+ self.norm_out = torch.nn.LayerNorm(output_dim)
+ self.layers = torch.nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ torch.nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, latents, x):
+ x = self.proj_in(x)
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
+
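+# MLPProjModel: maps a face-ID embedding (512-d by default) through a two-layer MLP
+# into `num_tokens` cross-attention tokens of width `cross_attention_dim`.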
+class MLPProjModel(torch.nn.Module):
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+ super().__init__()
+
+ self.cross_attention_dim = cross_attention_dim
+ self.num_tokens = num_tokens
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+ torch.nn.GELU(),
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+ )
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, id_embeds):
+ x = self.proj(id_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ x = self.norm(x)
+ return x
+
+
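+# ProjPlusModel: projects the face-ID embedding into tokens as above, then refines
+# them against CLIP image features with a FacePerceiverResampler; when shortcut=True
+# the refined tokens are added back to the plain ID tokens, weighted by `scale`.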
+class ProjPlusModel(torch.nn.Module):
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
+ super().__init__()
+
+ self.cross_attention_dim = cross_attention_dim
+ self.num_tokens = num_tokens
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+ torch.nn.GELU(),
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+ )
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+ self.perceiver_resampler = FacePerceiverResampler(
+ dim=cross_attention_dim,
+ depth=4,
+ dim_head=64,
+ heads=cross_attention_dim // 64,
+ embedding_dim=clip_embeddings_dim,
+ output_dim=cross_attention_dim,
+ ff_mult=4,
+ )
+
+ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
+
+ x = self.proj(id_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ x = self.norm(x)
+ out = self.perceiver_resampler(x, clip_embeds)
+ if shortcut:
+ out = x + scale * out
+ return out
+
+
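+# IPAdapterFaceID: wraps a Stable Diffusion pipeline, swaps its attention processors
+# for the LoRA / LoRA-IP variants, loads the face-ID checkpoint, and in generate()
+# concatenates the projected ID tokens with the text embeddings before sampling.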
+class IPAdapterFaceID:
+ def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
+ self.device = device
+ self.ip_ckpt = ip_ckpt
+ self.lora_rank = lora_rank
+ self.num_tokens = num_tokens
+ self.torch_dtype = torch_dtype
+
+ self.pipe = sd_pipe.to(self.device)
+ self.set_ip_adapter()
+
+ # image proj model
+ self.image_proj_model = self.init_proj()
+
+ self.load_ip_adapter()
+
+ def init_proj(self):
+ image_proj_model = MLPProjModel(
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+ id_embeddings_dim=512,
+ num_tokens=self.num_tokens,
+ ).to(self.device, dtype=self.torch_dtype)
+ return image_proj_model
+
+ def set_ip_adapter(self):
+ unet = self.pipe.unet
+ attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = LoRAAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
+ ).to(self.device, dtype=self.torch_dtype)
+ else:
+ attn_procs[name] = LoRAIPAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
+ ).to(self.device, dtype=self.torch_dtype)
+ unet.set_attn_processor(attn_procs)
+
+ def load_ip_adapter(self):
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
+
+ @torch.inference_mode()
+ def get_image_embeds(self, faceid_embeds):
+
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
+ image_prompt_embeds = self.image_proj_model(faceid_embeds)
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
+ return image_prompt_embeds, uncond_image_prompt_embeds
+
+ def set_scale(self, scale):
+ for attn_processor in self.pipe.unet.attn_processors.values():
+ if isinstance(attn_processor, LoRAIPAttnProcessor):
+ attn_processor.scale = scale
+
+ def generate(
+ self,
+ faceid_embeds=None,
+ prompt=None,
+ negative_prompt=None,
+ scale=1.0,
+ num_samples=4,
+ seed=None,
+ guidance_scale=7.5,
+ num_inference_steps=30,
+ **kwargs,
+ ):
+ self.set_scale(scale)
+
+
+ num_prompts = faceid_embeds.size(0)
+
+ if prompt is None:
+ prompt = "best quality, high quality"
+ if negative_prompt is None:
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ if not isinstance(prompt, List):
+ prompt = [prompt] * num_prompts
+ if not isinstance(negative_prompt, List):
+ negative_prompt = [negative_prompt] * num_prompts
+
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
+
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+ with torch.inference_mode():
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
+ prompt,
+ device=self.device,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
+
+ generator = get_generator(seed, self.device)
+
+ images = self.pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ **kwargs,
+ ).images
+
+ return images
+
+
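+# IPAdapterFaceIDPlus: like IPAdapterFaceID, but additionally runs a CLIP vision
+# encoder on the face crop and feeds both the ID embedding and the CLIP features
+# through ProjPlusModel (see `shortcut` / `s_scale` in generate()).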
+class IPAdapterFaceIDPlus:
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
+ self.device = device
+ self.image_encoder_path = image_encoder_path
+ self.ip_ckpt = ip_ckpt
+ self.lora_rank = lora_rank
+ self.num_tokens = num_tokens
+ self.torch_dtype = torch_dtype
+
+ self.pipe = sd_pipe.to(self.device)
+ self.set_ip_adapter()
+
+ # load image encoder
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
+ self.device, dtype=self.torch_dtype
+ )
+ self.clip_image_processor = CLIPImageProcessor()
+ # image proj model
+ self.image_proj_model = self.init_proj()
+
+ self.load_ip_adapter()
+
+ def init_proj(self):
+ image_proj_model = ProjPlusModel(
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+ id_embeddings_dim=512,
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
+ num_tokens=self.num_tokens,
+ ).to(self.device, dtype=self.torch_dtype)
+ return image_proj_model
+
+ def set_ip_adapter(self):
+ unet = self.pipe.unet
+ attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = LoRAAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
+ ).to(self.device, dtype=self.torch_dtype)
+ else:
+ attn_procs[name] = LoRAIPAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
+ ).to(self.device, dtype=self.torch_dtype)
+ unet.set_attn_processor(attn_procs)
+
+ def load_ip_adapter(self):
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
+
+ @torch.inference_mode()
+ def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
+        if isinstance(face_image, Image.Image):
+            face_image = [face_image]
+        clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
+ clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+ uncond_clip_image_embeds = self.image_encoder(
+ torch.zeros_like(clip_image), output_hidden_states=True
+ ).hidden_states[-2]
+
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
+ image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
+ return image_prompt_embeds, uncond_image_prompt_embeds
+
+ def set_scale(self, scale):
+ for attn_processor in self.pipe.unet.attn_processors.values():
+ if isinstance(attn_processor, LoRAIPAttnProcessor):
+ attn_processor.scale = scale
+
+ def generate(
+ self,
+ face_image=None,
+ faceid_embeds=None,
+ prompt=None,
+ negative_prompt=None,
+ scale=1.0,
+ num_samples=4,
+ seed=None,
+ guidance_scale=7.5,
+ num_inference_steps=30,
+ s_scale=1.0,
+ shortcut=False,
+ **kwargs,
+ ):
+ self.set_scale(scale)
+
+
+ num_prompts = faceid_embeds.size(0)
+
+ if prompt is None:
+ prompt = "best quality, high quality"
+ if negative_prompt is None:
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ if not isinstance(prompt, List):
+ prompt = [prompt] * num_prompts
+ if not isinstance(negative_prompt, List):
+ negative_prompt = [negative_prompt] * num_prompts
+
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
+
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+ with torch.inference_mode():
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
+ prompt,
+ device=self.device,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
+
+ generator = get_generator(seed, self.device)
+
+ images = self.pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ **kwargs,
+ ).images
+
+ return images
+
+
+class IPAdapterFaceIDXL(IPAdapterFaceID):
+ """SDXL"""
+
+ def generate(
+ self,
+ faceid_embeds=None,
+ prompt=None,
+ negative_prompt=None,
+ scale=1.0,
+ num_samples=4,
+ seed=None,
+ num_inference_steps=30,
+ **kwargs,
+ ):
+ self.set_scale(scale)
+
+ num_prompts = faceid_embeds.size(0)
+
+ if prompt is None:
+ prompt = "best quality, high quality"
+ if negative_prompt is None:
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ if not isinstance(prompt, List):
+ prompt = [prompt] * num_prompts
+ if not isinstance(negative_prompt, List):
+ negative_prompt = [negative_prompt] * num_prompts
+
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
+
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+ with torch.inference_mode():
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.pipe.encode_prompt(
+ prompt,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
+
+ generator = get_generator(seed, self.device)
+
+ images = self.pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ **kwargs,
+ ).images
+
+ return images
+
+
+class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
+ """SDXL"""
+
+ def generate(
+ self,
+ face_image=None,
+ faceid_embeds=None,
+ prompt=None,
+ negative_prompt=None,
+ scale=1.0,
+ num_samples=4,
+ seed=None,
+ guidance_scale=7.5,
+ num_inference_steps=30,
+ s_scale=1.0,
+ shortcut=True,
+ **kwargs,
+ ):
+ self.set_scale(scale)
+
+ num_prompts = faceid_embeds.size(0)
+
+ if prompt is None:
+ prompt = "best quality, high quality"
+ if negative_prompt is None:
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ if not isinstance(prompt, List):
+ prompt = [prompt] * num_prompts
+ if not isinstance(negative_prompt, List):
+ negative_prompt = [negative_prompt] * num_prompts
+
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
+
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+ with torch.inference_mode():
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.pipe.encode_prompt(
+ prompt,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
+
+ generator = get_generator(seed, self.device)
+
+ images = self.pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ guidance_scale=guidance_scale,
+ **kwargs,
+ ).images
+
+ return images
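+
+
+# Minimal usage sketch (illustrative only): the base-model id, checkpoint filename and
+# the pre-computed `faceid_embeds` below are assumptions, not provided by this module.
+#
+#   import torch
+#   from diffusers import StableDiffusionPipeline
+#
+#   pipe = StableDiffusionPipeline.from_pretrained("<sd15-base-model>", torch_dtype=torch.float16)
+#   adapter = IPAdapterFaceID(pipe, "<ip-adapter-faceid_sd15.bin>", device="cuda")
+#   # faceid_embeds: a (1, 512) ArcFace-style identity embedding extracted beforehand
+#   images = adapter.generate(faceid_embeds=faceid_embeds, prompt="a photo of a person",
+#                             num_samples=1, seed=42)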
diff --git a/ip_adapter/ip_adapter_faceid_separate.py b/ip_adapter/ip_adapter_faceid_separate.py
new file mode 100644
index 0000000000000000000000000000000000000000..80ca84c4fdb56e9b0dfb56195426427b31f07fb8
--- /dev/null
+++ b/ip_adapter/ip_adapter_faceid_separate.py
@@ -0,0 +1,547 @@
+import os
+from typing import List
+
+import torch
+from diffusers import StableDiffusionPipeline
+from diffusers.pipelines.controlnet import MultiControlNetModel
+from PIL import Image
+from safetensors import safe_open
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from .utils import is_torch2_available, get_generator
+
+USE_DAFAULT_ATTN = False  # set True to force the non-PyTorch-2.0 attention processors (needed for attention-map visualization)
+if is_torch2_available() and (not USE_DAFAULT_ATTN):
+ from .attention_processor import (
+ AttnProcessor2_0 as AttnProcessor,
+ )
+ from .attention_processor import (
+ IPAttnProcessor2_0 as IPAttnProcessor,
+ )
+else:
+ from .attention_processor import AttnProcessor, IPAttnProcessor
+from .resampler import PerceiverAttention, FeedForward
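+
+# This "separate" variant of the FaceID adapters conditions through the plain IP
+# attention processors (no LoRA layers) and lets IPAdapterFaceID take several identity
+# embeddings per image via `n_cond` and a 3-D `faceid_embeds` tensor.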
+
+
+class FacePerceiverResampler(torch.nn.Module):
+ def __init__(
+ self,
+ *,
+ dim=768,
+ depth=4,
+ dim_head=64,
+ heads=16,
+ embedding_dim=1280,
+ output_dim=768,
+ ff_mult=4,
+ ):
+ super().__init__()
+
+ self.proj_in = torch.nn.Linear(embedding_dim, dim)
+ self.proj_out = torch.nn.Linear(dim, output_dim)
+ self.norm_out = torch.nn.LayerNorm(output_dim)
+ self.layers = torch.nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ torch.nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, latents, x):
+ x = self.proj_in(x)
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
+
+class MLPProjModel(torch.nn.Module):
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+ super().__init__()
+
+ self.cross_attention_dim = cross_attention_dim
+ self.num_tokens = num_tokens
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+ torch.nn.GELU(),
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+ )
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, id_embeds):
+ x = self.proj(id_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ x = self.norm(x)
+ return x
+
+
+class ProjPlusModel(torch.nn.Module):
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
+ super().__init__()
+
+ self.cross_attention_dim = cross_attention_dim
+ self.num_tokens = num_tokens
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+ torch.nn.GELU(),
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+ )
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+ self.perceiver_resampler = FacePerceiverResampler(
+ dim=cross_attention_dim,
+ depth=4,
+ dim_head=64,
+ heads=cross_attention_dim // 64,
+ embedding_dim=clip_embeddings_dim,
+ output_dim=cross_attention_dim,
+ ff_mult=4,
+ )
+
+ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
+
+ x = self.proj(id_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ x = self.norm(x)
+ out = self.perceiver_resampler(x, clip_embeds)
+ if shortcut:
+ out = x + scale * out
+ return out
+
+
+class IPAdapterFaceID:
+ def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, n_cond=1, torch_dtype=torch.float16):
+ self.device = device
+ self.ip_ckpt = ip_ckpt
+ self.num_tokens = num_tokens
+ self.n_cond = n_cond
+ self.torch_dtype = torch_dtype
+
+ self.pipe = sd_pipe.to(self.device)
+ self.set_ip_adapter()
+
+ # image proj model
+ self.image_proj_model = self.init_proj()
+
+ self.load_ip_adapter()
+
+ def init_proj(self):
+ image_proj_model = MLPProjModel(
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+ id_embeddings_dim=512,
+ num_tokens=self.num_tokens,
+ ).to(self.device, dtype=self.torch_dtype)
+ return image_proj_model
+
+ def set_ip_adapter(self):
+ unet = self.pipe.unet
+ attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor()
+ else:
+ attn_procs[name] = IPAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens*self.n_cond,
+ ).to(self.device, dtype=self.torch_dtype)
+ unet.set_attn_processor(attn_procs)
+
+ def load_ip_adapter(self):
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+ ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
+
+ @torch.inference_mode()
+ def get_image_embeds(self, faceid_embeds):
+
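+        # A 3-D input of shape (batch, num_faces, 512) is flattened so each face is
+        # projected independently; the resulting tokens are folded back per batch item
+        # below, concatenating every face's tokens along the sequence dimension.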
+ multi_face = False
+ if faceid_embeds.dim() == 3:
+ multi_face = True
+ b, n, c = faceid_embeds.shape
+ faceid_embeds = faceid_embeds.reshape(b*n, c)
+
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
+ image_prompt_embeds = self.image_proj_model(faceid_embeds)
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
+ if multi_face:
+ c = image_prompt_embeds.size(-1)
+ image_prompt_embeds = image_prompt_embeds.reshape(b, -1, c)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.reshape(b, -1, c)
+
+ return image_prompt_embeds, uncond_image_prompt_embeds
+
+ def set_scale(self, scale):
+ for attn_processor in self.pipe.unet.attn_processors.values():
+ if isinstance(attn_processor, IPAttnProcessor):
+ attn_processor.scale = scale
+
+ def generate(
+ self,
+ faceid_embeds=None,
+ prompt=None,
+ negative_prompt=None,
+ scale=1.0,
+ num_samples=4,
+ seed=None,
+ guidance_scale=7.5,
+ num_inference_steps=30,
+ **kwargs,
+ ):
+ self.set_scale(scale)
+
+
+ num_prompts = faceid_embeds.size(0)
+
+ if prompt is None:
+ prompt = "best quality, high quality"
+ if negative_prompt is None:
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ if not isinstance(prompt, List):
+ prompt = [prompt] * num_prompts
+ if not isinstance(negative_prompt, List):
+ negative_prompt = [negative_prompt] * num_prompts
+
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
+
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+ with torch.inference_mode():
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
+ prompt,
+ device=self.device,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
+
+ generator = get_generator(seed, self.device)
+
+ images = self.pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ **kwargs,
+ ).images
+
+ return images
+
+
+class IPAdapterFaceIDPlus:
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16):
+ self.device = device
+ self.image_encoder_path = image_encoder_path
+ self.ip_ckpt = ip_ckpt
+ self.num_tokens = num_tokens
+ self.torch_dtype = torch_dtype
+
+ self.pipe = sd_pipe.to(self.device)
+ self.set_ip_adapter()
+
+ # load image encoder
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
+ self.device, dtype=self.torch_dtype
+ )
+ self.clip_image_processor = CLIPImageProcessor()
+ # image proj model
+ self.image_proj_model = self.init_proj()
+
+ self.load_ip_adapter()
+
+ def init_proj(self):
+ image_proj_model = ProjPlusModel(
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+ id_embeddings_dim=512,
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
+ num_tokens=self.num_tokens,
+ ).to(self.device, dtype=self.torch_dtype)
+ return image_proj_model
+
+ def set_ip_adapter(self):
+ unet = self.pipe.unet
+ attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor()
+ else:
+ attn_procs[name] = IPAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens,
+ ).to(self.device, dtype=self.torch_dtype)
+ unet.set_attn_processor(attn_procs)
+
+ def load_ip_adapter(self):
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+ ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
+
+ @torch.inference_mode()
+ def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
+        if isinstance(face_image, Image.Image):
+            face_image = [face_image]
+        clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
+ clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+ uncond_clip_image_embeds = self.image_encoder(
+ torch.zeros_like(clip_image), output_hidden_states=True
+ ).hidden_states[-2]
+
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
+ image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
+ return image_prompt_embeds, uncond_image_prompt_embeds
+
+ def set_scale(self, scale):
+ for attn_processor in self.pipe.unet.attn_processors.values():
+            if isinstance(attn_processor, IPAttnProcessor):
+ attn_processor.scale = scale
+
+ def generate(
+ self,
+ face_image=None,
+ faceid_embeds=None,
+ prompt=None,
+ negative_prompt=None,
+ scale=1.0,
+ num_samples=4,
+ seed=None,
+ guidance_scale=7.5,
+ num_inference_steps=30,
+ s_scale=1.0,
+ shortcut=False,
+ **kwargs,
+ ):
+ self.set_scale(scale)
+
+
+ num_prompts = faceid_embeds.size(0)
+
+ if prompt is None:
+ prompt = "best quality, high quality"
+ if negative_prompt is None:
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ if not isinstance(prompt, List):
+ prompt = [prompt] * num_prompts
+ if not isinstance(negative_prompt, List):
+ negative_prompt = [negative_prompt] * num_prompts
+
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
+
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+ with torch.inference_mode():
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
+ prompt,
+ device=self.device,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
+
+ generator = get_generator(seed, self.device)
+
+ images = self.pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ **kwargs,
+ ).images
+
+ return images
+
+
+class IPAdapterFaceIDXL(IPAdapterFaceID):
+ """SDXL"""
+
+ def generate(
+ self,
+ faceid_embeds=None,
+ prompt=None,
+ negative_prompt=None,
+ scale=1.0,
+ num_samples=4,
+ seed=None,
+ num_inference_steps=30,
+ **kwargs,
+ ):
+ self.set_scale(scale)
+
+ num_prompts = faceid_embeds.size(0)
+
+ if prompt is None:
+ prompt = "best quality, high quality"
+ if negative_prompt is None:
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ if not isinstance(prompt, List):
+ prompt = [prompt] * num_prompts
+ if not isinstance(negative_prompt, List):
+ negative_prompt = [negative_prompt] * num_prompts
+
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
+
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+ with torch.inference_mode():
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.pipe.encode_prompt(
+ prompt,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
+
+ generator = get_generator(seed, self.device)
+
+ images = self.pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ **kwargs,
+ ).images
+
+ return images
+
+
+class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
+ """SDXL"""
+
+ def generate(
+ self,
+ face_image=None,
+ faceid_embeds=None,
+ prompt=None,
+ negative_prompt=None,
+ scale=1.0,
+ num_samples=4,
+ seed=None,
+ guidance_scale=7.5,
+ num_inference_steps=30,
+ s_scale=1.0,
+ shortcut=True,
+ **kwargs,
+ ):
+ self.set_scale(scale)
+
+ num_prompts = faceid_embeds.size(0)
+
+ if prompt is None:
+ prompt = "best quality, high quality"
+ if negative_prompt is None:
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+ if not isinstance(prompt, List):
+ prompt = [prompt] * num_prompts
+ if not isinstance(negative_prompt, List):
+ negative_prompt = [negative_prompt] * num_prompts
+
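+ # The Plus variant also conditions on the face image itself; s_scale and shortcut are forwarded to get_image_embeds to control that extra conditioning.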
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
+
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+ with torch.inference_mode():
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.pipe.encode_prompt(
+ prompt,
+ num_images_per_prompt=num_samples,
+ do_classifier_free_guidance=True,
+ negative_prompt=negative_prompt,
+ )
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
+
+ generator = get_generator(seed, self.device)
+
+ images = self.pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ guidance_scale=guidance_scale,
+ **kwargs,
+ ).images
+
+ return images
diff --git a/ip_adapter/resampler.py b/ip_adapter/resampler.py
index 708740b4a3abbe08030e4428c39acd7438c1006d..24266671d02092438ae6576336a59659fef9c054 100644
--- a/ip_adapter/resampler.py
+++ b/ip_adapter/resampler.py
@@ -78,54 +78,6 @@ class PerceiverAttention(nn.Module):
return self.to_out(out)
-class CrossAttention(nn.Module):
- def __init__(self, *, dim, dim_head=64, heads=8):
- super().__init__()
- self.scale = dim_head**-0.5
- self.dim_head = dim_head
- self.heads = heads
- inner_dim = dim_head * heads
-
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
-
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
- self.to_k = nn.Linear(dim, inner_dim, bias=False)
- self.to_v = nn.Linear(dim, inner_dim, bias=False)
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
-
-
- def forward(self, x, x2):
- """
- Args:
- x (torch.Tensor): image features
- shape (b, n1, D)
- latent (torch.Tensor): latent features
- shape (b, n2, D)
- """
- x = self.norm1(x)
- x2 = self.norm2(x2)
-
- b, l, _ = x2.shape
-
- q = self.to_q(x)
- k = self.to_k(x2)
- v = self.to_v(x2)
-
- q = reshape_tensor(q, self.heads)
- k = reshape_tensor(k, self.heads)
- v = reshape_tensor(v, self.heads)
-
- # attention
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
- out = weight @ v
-
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
- return self.to_out(out)
-
-
class Resampler(nn.Module):
def __init__(
self,
@@ -142,6 +94,7 @@ class Resampler(nn.Module):
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
):
super().__init__()
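+ # Learned positional embedding for the input tokens, only created when apply_pos_emb is set.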
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
@@ -150,6 +103,16 @@ class Resampler(nn.Module):
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
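+ # When num_latents_mean_pooled > 0, map the mean-pooled input sequence to that many additional latent tokens.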
+ self.to_latents_from_mean_pooled_seq = (
+ nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, dim * num_latents_mean_pooled),
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
+ )
+ if num_latents_mean_pooled > 0
+ else None
+ )
+
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
@@ -162,11 +125,19 @@ class Resampler(nn.Module):
)
def forward(self, x):
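+ # Add positional embeddings to the input tokens when they are enabled.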
+ if self.pos_emb is not None:
+ n, device = x.shape[1], x.device
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
+ x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
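+ # Derive extra latents from the mean-pooled projected sequence and prepend them to the learned latents.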
+ if self.to_latents_from_mean_pooled_seq:
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
@@ -176,7 +147,6 @@ class Resampler(nn.Module):
return self.norm_out(latents)
-
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
diff --git a/ip_adapter/utils.py b/ip_adapter/utils.py
index 9a105f3701c15e8d3bbf838d79bacc51e91d0696..6a273358585962fdf383d0bb7a0e1c654b4999b8 100644
--- a/ip_adapter/utils.py
+++ b/ip_adapter/utils.py
@@ -1,5 +1,93 @@
+import torch
import torch.nn.functional as F
+import numpy as np
+from PIL import Image
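+# Stores the most recent cross-attention map of each hooked module, keyed by module name.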
+attn_maps = {}
+def hook_fn(name):
+ def forward_hook(module, input, output):
+ if hasattr(module.processor, "attn_map"):
+ attn_maps[name] = module.processor.attn_map
+ del module.processor.attn_map
+ return forward_hook
+
+def register_cross_attention_hook(unet):
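+ # Register a forward hook on every cross-attention ('attn2') module so its attention map is captured into attn_maps.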
+ for name, module in unet.named_modules():
+ if name.split('.')[-1].startswith('attn2'):
+ module.register_forward_hook(hook_fn(name))
+
+ return unet
+
+def upscale(attn_map, target_size):
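+ # Average the stored map over its first dimension, infer the latent resolution whose token count matches it, and bilinearly upsample to target_size.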
+ attn_map = torch.mean(attn_map, dim=0)
+ attn_map = attn_map.permute(1,0)
+ temp_size = None
+
+ for i in range(0, 5):
+ scale = 2 ** i
+ if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
+ temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
+ break
+
+ assert temp_size is not None, "temp_size cannot be None"
+
+ attn_map = attn_map.view(attn_map.shape[0], *temp_size)
+
+ attn_map = F.interpolate(
+ attn_map.unsqueeze(0).to(dtype=torch.float32),
+ size=target_size,
+ mode='bilinear',
+ align_corners=False
+ )[0]
+
+ attn_map = torch.softmax(attn_map, dim=0)
+ return attn_map
+def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
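+ # Pick one chunk of each batched attention map (negative or conditional, via idx), upscale it to image_size, and average over all hooked layers.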
+
+ idx = 0 if instance_or_negative else 1
+ net_attn_maps = []
+
+ for name, attn_map in attn_maps.items():
+ attn_map = attn_map.cpu() if detach else attn_map
+ attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
+ attn_map = upscale(attn_map, image_size)
+ net_attn_maps.append(attn_map)
+
+ net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
+
+ return net_attn_maps
+
+def attnmaps2images(net_attn_maps):
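+ # Min-max normalize each attention map to the 0-255 range and convert it to a PIL image.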
+
+ #total_attn_scores = 0
+ images = []
+
+ for attn_map in net_attn_maps:
+ attn_map = attn_map.cpu().numpy()
+ #total_attn_scores += attn_map.mean().item()
+
+ normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
+ normalized_attn_map = normalized_attn_map.astype(np.uint8)
+ #print("norm: ", normalized_attn_map.shape)
+ image = Image.fromarray(normalized_attn_map)
+
+ #image = fix_save_attn_map(attn_map)
+ images.append(image)
+
+ #print(total_attn_scores)
+ return images
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
+
+def get_generator(seed, device):
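+ # Return a seeded torch.Generator (or a list of generators for a list of seeds); None means the pipeline samples without a fixed seed.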
+
+ if seed is not None:
+ if isinstance(seed, list):
+ generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
+ else:
+ generator = torch.Generator(device).manual_seed(seed)
+ else:
+ generator = None
+
+ return generator
\ No newline at end of file
diff --git a/preprocess/humanparsing/run_parsing.py b/preprocess/humanparsing/run_parsing.py
index db9a629eb6c16ecceb005d5610229cad91f71a6b..14028467468d280139329e1f197e63df886d1fb3 100644
--- a/preprocess/humanparsing/run_parsing.py
+++ b/preprocess/humanparsing/run_parsing.py
@@ -11,12 +11,12 @@ import torch
class Parsing:
def __init__(self, gpu_id: int):
- self.gpu_id = gpu_id
- torch.cuda.set_device(gpu_id)
+ # self.gpu_id = gpu_id
+ # torch.cuda.set_device(gpu_id)
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
- session_options.add_session_config_entry('gpu_id', str(gpu_id))
+ # session_options.add_session_config_entry('gpu_id', str(gpu_id))
self.session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'ckpt/humanparsing/parsing_atr.onnx'),
sess_options=session_options, providers=['CPUExecutionProvider'])
self.lip_session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'ckpt/humanparsing/parsing_lip.onnx'),
diff --git a/preprocess/openpose/annotator/openpose/body.py b/preprocess/openpose/annotator/openpose/body.py
index 7af77f0409f9aa5195e668a7eb782267ea0889c3..27012f28ad0e736f7e7bd877f35db3b64fe8bd9c 100644
--- a/preprocess/openpose/annotator/openpose/body.py
+++ b/preprocess/openpose/annotator/openpose/body.py
@@ -20,9 +20,9 @@ from .model import bodypose_model
class Body(object):
def __init__(self, model_path):
self.model = bodypose_model()
- if torch.cuda.is_available():
- self.model = self.model.cuda()
- # print('cuda')
+ # if torch.cuda.is_available():
+ # self.model = self.model.cuda()
+ # print('cuda')
model_dict = util.transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()
diff --git a/preprocess/openpose/run_openpose.py b/preprocess/openpose/run_openpose.py
index 37e8ee0d2411a218b0ffba01a190ab0e56dae06e..fa0ed1fe39e8c871726555f184f1dcc3c8dd51bc 100644
--- a/preprocess/openpose/run_openpose.py
+++ b/preprocess/openpose/run_openpose.py
@@ -1,6 +1,6 @@
import pdb
-# import config
+import config
from pathlib import Path
import sys
@@ -28,12 +28,12 @@ import pdb
class OpenPose:
def __init__(self, gpu_id: int):
- self.gpu_id = gpu_id
- torch.cuda.set_device(gpu_id)
+ # self.gpu_id = gpu_id
+ # torch.cuda.set_device(gpu_id)
self.preprocessor = OpenposeDetector()
def __call__(self, input_image, resolution=384):
- torch.cuda.set_device(self.gpu_id)
+ # torch.cuda.set_device(self.gpu_id)
if isinstance(input_image, Image.Image):
input_image = np.asarray(input_image)
elif type(input_image) == str:
diff --git a/requirements.txt b/requirements.txt
index e0182cb3ac961882217447e9a53550f5562207b5..a9243fb4b897c3cc398fda901f58f4978a0eba28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,18 +1,23 @@
-huggingface_hub==0.20.2
-accelerate==0.25.0
-torchmetrics==1.2.1
-tqdm==4.66.1
-transformers==4.36.2
-diffusers==0.25.0
-einops==0.7.0
-bitsandbytes==0.39.0
-scipy==1.11.1
-opencv-python
-gradio==4.24.0
-fvcore
-cloudpickle
-omegaconf
-pycocotools
-basicsr
-av
+transformers==4.36.2
+torch==2.0.1
+torchvision==0.15.2
+torchaudio==2.0.2
+numpy==1.24.4
+scipy==1.10.1
+scikit-image==0.21.0
+opencv-python==4.7.0.72
+pillow==9.4.0
+diffusers==0.25.0
+accelerate==0.26.1
+matplotlib==3.7.4
+tqdm==4.64.1
+config==0.5.1
+einops==0.7.0
onnxruntime==1.16.2
+basicsr
+av
+fvcore
+cloudpickle
+omegaconf
+pycocotools
\ No newline at end of file
diff --git a/src/attentionhacked_tryon.py b/src/attentionhacked_tryon.py
index 1e5123ca4e07786bcbb28769b0fc1323c6db3571..9947e7c878f9f1fd3701b939dfc243628fa54652 100644
--- a/src/attentionhacked_tryon.py
+++ b/src/attentionhacked_tryon.py
@@ -331,6 +331,7 @@ class BasicTransformerBlock(nn.Module):
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+ #type2
modify_norm_hidden_states = torch.cat([norm_hidden_states,garment_features[curr_garment_feat_idx]], dim=1)
curr_garment_feat_idx +=1
attn_output = self.attn1(
@@ -345,6 +346,8 @@ class BasicTransformerBlock(nn.Module):
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
+
+ #type2
hidden_states = attn_output[:,:hidden_states.shape[-2],:] + hidden_states
diff --git a/src/tryon_pipeline.py b/src/tryon_pipeline.py
index 397ad18e462ac687858229fdab2a992ab5a722b7..78f22375101a3312a6e3f992126f92fa91839ed4 100644
--- a/src/tryon_pipeline.py
+++ b/src/tryon_pipeline.py
@@ -370,7 +370,6 @@ class StableDiffusionXLInpaintPipeline(
"text_encoder_2",
"image_encoder",
"feature_extractor",
- "unet_encoder",
]
_callback_tensor_inputs = [
"latents",
@@ -392,7 +391,6 @@ class StableDiffusionXLInpaintPipeline(
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
- unet_encoder: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
@@ -409,7 +407,6 @@ class StableDiffusionXLInpaintPipeline(
tokenizer_2=tokenizer_2,
unet=unet,
image_encoder=image_encoder,
- unet_encoder=unet_encoder,
feature_extractor=feature_extractor,
scheduler=scheduler,
)
@@ -1722,8 +1719,8 @@ class StableDiffusionXLInpaintPipeline(
ip_adapter_image, device, batch_size * num_images_per_prompt
)
- #project outside for loop
- image_embeds = self.unet.encoder_hid_proj(image_embeds).to(prompt_embeds.dtype)
+ #project outside for loop
+ image_embeds = self.unet.encoder_hid_proj(image_embeds).to(prompt_embeds.dtype)
# 11. Denoising loop
diff --git a/train_xl.py b/train_xl.py
deleted file mode 100644
index 970bf47e3f50bcda97b5132eb39f407659df643d..0000000000000000000000000000000000000000
--- a/train_xl.py
+++ /dev/null
@@ -1,797 +0,0 @@
-import os
-import random
-import argparse
-import json
-import itertools
-import torch
-import torch.nn.functional as F
-from torchvision import transforms
-from PIL import Image
-from transformers import CLIPImageProcessor
-from accelerate import Accelerator
-from accelerate.utils import ProjectConfiguration
-from diffusers import AutoencoderKL, DDPMScheduler
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
-
-from src.unet_hacked_tryon import UNet2DConditionModel
-from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
-from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
-
-from ip_adapter.ip_adapter import Resampler
-from diffusers.utils.import_utils import is_xformers_available
-from typing import Literal, Tuple,List
-import torch.utils.data as data
-import math
-from tqdm.auto import tqdm
-from diffusers.training_utils import compute_snr
-import torchvision.transforms.functional as TF
-
-
-
-class VitonHDDataset(data.Dataset):
- def __init__(
- self,
- dataroot_path: str,
- phase: Literal["train", "test"],
- order: Literal["paired", "unpaired"] = "paired",
- size: Tuple[int, int] = (512, 384),
- ):
- super(VitonHDDataset, self).__init__()
- self.dataroot = dataroot_path
- self.phase = phase
- self.height = size[0]
- self.width = size[1]
- self.size = size
-
-
- self.norm = transforms.Normalize([0.5], [0.5])
- self.transform = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
- self.transform2D = transforms.Compose(
- [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
- )
- self.toTensor = transforms.ToTensor()
-
- with open(
- os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r"
- ) as file1:
- data1 = json.load(file1)
-
- annotation_list = [
- # "colors",
- # "textures",
- "sleeveLength",
- "neckLine",
- "item",
- ]
-
- self.annotation_pair = {}
- for k, v in data1.items():
- for elem in v:
- annotation_str = ""
- for template in annotation_list:
- for tag in elem["tag_info"]:
- if (
- tag["tag_name"] == template
- and tag["tag_category"] is not None
- ):
- annotation_str += tag["tag_category"]
- annotation_str += " "
- self.annotation_pair[elem["file_name"]] = annotation_str
-
-
- self.order = order
-
- self.toTensor = transforms.ToTensor()
-
- im_names = []
- c_names = []
- dataroot_names = []
-
-
- if phase == "train":
- filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
- else:
- filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
-
- with open(filename, "r") as f:
- for line in f.readlines():
- if phase == "train":
- im_name, _ = line.strip().split()
- c_name = im_name
- else:
- if order == "paired":
- im_name, _ = line.strip().split()
- c_name = im_name
- else:
- im_name, c_name = line.strip().split()
-
- im_names.append(im_name)
- c_names.append(c_name)
- dataroot_names.append(dataroot_path)
-
- self.im_names = im_names
- self.c_names = c_names
- self.dataroot_names = dataroot_names
- self.flip_transform = transforms.RandomHorizontalFlip(p=1)
- self.clip_processor = CLIPImageProcessor()
- def __getitem__(self, index):
- c_name = self.c_names[index]
- im_name = self.im_names[index]
- # subject_txt = self.txt_preprocess['train']("shirt")
- if c_name in self.annotation_pair:
- cloth_annotation = self.annotation_pair[c_name]
- else:
- cloth_annotation = "shirts"
-
- cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name))
-
- im_pil_big = Image.open(
- os.path.join(self.dataroot, self.phase, "image", im_name)
- ).resize((self.width,self.height))
-
- image = self.transform(im_pil_big)
- # load parsing image
-
-
- mask = Image.open(os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace('.jpg','_mask.png'))).resize((self.width,self.height))
- mask = self.toTensor(mask)
- mask = mask[:1]
- densepose_name = im_name
- densepose_map = Image.open(
- os.path.join(self.dataroot, self.phase, "image-densepose", densepose_name)
- )
- pose_img = self.toTensor(densepose_map) # [-1,1]
-
-
-
- if self.phase == "train":
- if random.random() > 0.5:
- cloth = self.flip_transform(cloth)
- mask = self.flip_transform(mask)
- image = self.flip_transform(image)
- pose_img = self.flip_transform(pose_img)
-
-
-
- if random.random()>0.5:
- color_jitter = transforms.ColorJitter(brightness=0.5, contrast=0.3, saturation=0.5, hue=0.5)
- fn_idx, b, c, s, h = transforms.ColorJitter.get_params(color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,color_jitter.hue)
-
- image = TF.adjust_contrast(image, c)
- image = TF.adjust_brightness(image, b)
- image = TF.adjust_hue(image, h)
- image = TF.adjust_saturation(image, s)
-
- cloth = TF.adjust_contrast(cloth, c)
- cloth = TF.adjust_brightness(cloth, b)
- cloth = TF.adjust_hue(cloth, h)
- cloth = TF.adjust_saturation(cloth, s)
-
-
- if random.random() > 0.5:
- scale_val = random.uniform(0.8, 1.2)
- image = transforms.functional.affine(
- image, angle=0, translate=[0, 0], scale=scale_val, shear=0
- )
- mask = transforms.functional.affine(
- mask, angle=0, translate=[0, 0], scale=scale_val, shear=0
- )
- pose_img = transforms.functional.affine(
- pose_img, angle=0, translate=[0, 0], scale=scale_val, shear=0
- )
-
-
-
- if random.random() > 0.5:
- shift_valx = random.uniform(-0.2, 0.2)
- shift_valy = random.uniform(-0.2, 0.2)
- image = transforms.functional.affine(
- image,
- angle=0,
- translate=[shift_valx * image.shape[-1], shift_valy * image.shape[-2]],
- scale=1,
- shear=0,
- )
- mask = transforms.functional.affine(
- mask,
- angle=0,
- translate=[shift_valx * mask.shape[-1], shift_valy * mask.shape[-2]],
- scale=1,
- shear=0,
- )
- pose_img = transforms.functional.affine(
- pose_img,
- angle=0,
- translate=[
- shift_valx * pose_img.shape[-1],
- shift_valy * pose_img.shape[-2],
- ],
- scale=1,
- shear=0,
- )
-
-
-
-
- mask = 1-mask
-
- cloth_trim = self.clip_processor(images=cloth, return_tensors="pt").pixel_values
-
-
- mask[mask < 0.5] = 0
- mask[mask >= 0.5] = 1
-
- im_mask = image * mask
-
- pose_img = self.norm(pose_img)
-
-
- result = {}
- result["c_name"] = c_name
- result["image"] = image
- result["cloth"] = cloth_trim
- result["cloth_pure"] = self.transform(cloth)
- result["inpaint_mask"] = 1-mask
- result["im_mask"] = im_mask
- result["caption"] = "model is wearing " + cloth_annotation
- result["caption_cloth"] = "a photo of " + cloth_annotation
- result["annotation"] = cloth_annotation
- result["pose_img"] = pose_img
-
-
- return result
-
- def __len__(self):
- return len(self.im_names)
-
-
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
- parser.add_argument("--pretrained_model_name_or_path",type=str,default="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",required=False,help="Path to pretrained model or model identifier from huggingface.co/models.",)
- parser.add_argument("--pretrained_garmentnet_path",type=str,default="stabilityai/stable-diffusion-xl-base-1.0",required=False,help="Path to pretrained model or model identifier from huggingface.co/models.",)
- parser.add_argument("--checkpointing_epoch",type=int,default=10,help=("Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"" training using `--resume_from_checkpoint`."),)
- parser.add_argument("--pretrained_ip_adapter_path",type=str,default="ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin",help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",)
- parser.add_argument("--image_encoder_path",type=str,default="ckpt/image_encoder",required=False,help="Path to CLIP image encoder",)
- parser.add_argument("--gradient_checkpointing",action="store_true",help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",)
- parser.add_argument("--width",type=int,default=768,)
- parser.add_argument("--height",type=int,default=1024,)
- parser.add_argument("--gradient_accumulation_steps",type=int,default=1,help="Number of updates steps to accumulate before performing a backward/update pass.",)
- parser.add_argument("--logging_steps",type=int,default=1000,help=("Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"" training using `--resume_from_checkpoint`."),)
- parser.add_argument("--output_dir",type=str,default="output",help="The output directory where the model predictions and checkpoints will be written.",)
- parser.add_argument("--snr_gamma",type=float,default=None,help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. ""More details here: https://arxiv.org/abs/2303.09556.",)
- parser.add_argument("--num_tokens",type=int,default=16,help=("IP adapter token nums"),)
- parser.add_argument("--learning_rate",type=float,default=1e-5,help="Learning rate to use.",)
- parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
- parser.add_argument("--train_batch_size", type=int, default=6, help="Batch size (per device) for the training dataloader.")
- parser.add_argument("--test_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
- parser.add_argument("--num_train_epochs", type=int, default=130)
- parser.add_argument("--max_train_steps",type=int,default=None,help="Total number of training steps to perform. If provided, overrides num_train_epochs.",)
- parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
- parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.")
- parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
- parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],help=("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),)
- parser.add_argument("--guidance_scale",type=float,default=2.0,)
- parser.add_argument("--seed", type=int, default=42,)
- parser.add_argument("--num_inference_steps",type=int,default=30,)
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
- parser.add_argument("--data_dir", type=str, default="/home/omnious/workspace/yisol/Dataset/VITON-HD/zalando", help="For distributed training: local_rank")
-
- args = parser.parse_args()
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
- if env_local_rank != -1 and env_local_rank != args.local_rank:
- args.local_rank = env_local_rank
-
- return args
-
-
-
-
-
-def main():
-
-
- args = parse_args()
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir)
- accelerator = Accelerator(
- mixed_precision=args.mixed_precision,
- gradient_accumulation_steps=args.gradient_accumulation_steps,
- project_config=accelerator_project_config,
- )
-
- if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
- # Load scheduler, tokenizer and models.
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler",rescale_betas_zero_snr=True)
- tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
- text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
- tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
- text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2")
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,subfolder="vae",torch_dtype=torch.float16,)
- unet_encoder = UNet2DConditionModel_ref.from_pretrained(args.pretrained_garmentnet_path, subfolder="unet")
- unet_encoder.config.addition_embed_type = None
- unet_encoder.config["addition_embed_type"] = None
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
-
- #customize unet start
- unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",low_cpu_mem_usage=False, device_map=None)
- unet.config.encoder_hid_dim = image_encoder.config.hidden_size
- unet.config.encoder_hid_dim_type = "ip_image_proj"
- unet.config["encoder_hid_dim"] = image_encoder.config.hidden_size
- unet.config["encoder_hid_dim_type"] = "ip_image_proj"
-
-
- state_dict = torch.load(args.pretrained_ip_adapter_path, map_location="cpu")
-
-
- adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
- adapter_modules.load_state_dict(state_dict["ip_adapter"],strict=True)
-
- #ip-adapter
- image_proj_model = Resampler(
- dim=image_encoder.config.hidden_size,
- depth=4,
- dim_head=64,
- heads=20,
- num_queries=args.num_tokens,
- embedding_dim=image_encoder.config.hidden_size,
- output_dim=unet.config.cross_attention_dim,
- ff_mult=4,
- ).to(accelerator.device, dtype=torch.float32)
-
- image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
- image_proj_model.requires_grad_(True)
-
- unet.encoder_hid_proj = image_proj_model
-
- conv_new = torch.nn.Conv2d(
- in_channels=4+4+1+4,
- out_channels=unet.conv_in.out_channels,
- kernel_size=3,
- padding=1,
- )
- torch.nn.init.kaiming_normal_(conv_new.weight)
- conv_new.weight.data = conv_new.weight.data * 0.
-
- conv_new.weight.data[:, :9] = unet.conv_in.weight.data
- conv_new.bias.data = unet.conv_in.bias.data
-
- unet.conv_in = conv_new # replace conv layer in unet
- unet.config['in_channels'] = 13 # update config
- unet.config.in_channels = 13 # update config
- #customize unet end
-
-
- weight_dtype = torch.float32
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
- vae.to(accelerator.device)
- text_encoder.to(accelerator.device, dtype=weight_dtype)
- text_encoder_2.to(accelerator.device, dtype=weight_dtype)
- image_encoder.to(accelerator.device, dtype=weight_dtype)
- unet_encoder.to(accelerator.device, dtype=weight_dtype)
-
-
- vae.requires_grad_(False)
- text_encoder.requires_grad_(False)
- text_encoder_2.requires_grad_(False)
- image_encoder.requires_grad_(False)
- unet_encoder.requires_grad_(False)
- unet.requires_grad_(True)
-
-
-
-
- if args.enable_xformers_memory_efficient_attention:
- if is_xformers_available():
- import xformers
-
- unet.enable_xformers_memory_efficient_attention()
- else:
- raise ValueError("xformers is not available. Make sure it is installed correctly")
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
- unet_encoder.enable_gradient_checkpointing()
- unet.train()
-
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
- )
-
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
-
- params_to_opt = itertools.chain(unet.parameters())
-
-
- optimizer = optimizer_class(
- params_to_opt,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- )
-
- train_dataset = VitonHDDataset(
- dataroot_path=args.data_dir,
- phase="train",
- order="paired",
- size=(args.height, args.width),
- )
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset,
- pin_memory=True,
- shuffle=False,
- batch_size=args.train_batch_size,
- num_workers=16,
- )
- test_dataset = VitonHDDataset(
- dataroot_path=args.data_dir,
- phase="test",
- order="paired",
- size=(args.height, args.width),
- )
- test_dataloader = torch.utils.data.DataLoader(
- test_dataset,
- shuffle=False,
- batch_size=args.test_batch_size,
- num_workers=4,
- )
-
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
-
-
- unet,image_proj_model,unet_encoder,image_encoder,optimizer,train_dataloader,test_dataloader = accelerator.prepare(unet, image_proj_model,unet_encoder,image_encoder,optimizer,train_dataloader,test_dataloader)
- initial_global_step = 0
-
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- # Afterwards we recalculate our number of training epochs
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
- # Train!
- progress_bar = tqdm(
- range(0, args.max_train_steps),
- initial=initial_global_step,
- desc="Steps",
- # Only show the progress bar once on each machine.
- disable=not accelerator.is_local_main_process,
- )
- global_step = 0
- first_epoch = 0
- train_loss=0.0
- for epoch in range(first_epoch, args.num_train_epochs):
- for step, batch in enumerate(train_dataloader):
- with accelerator.accumulate(unet), accelerator.accumulate(image_proj_model):
- if global_step % args.logging_steps == 0:
- if accelerator.is_main_process:
- with torch.no_grad():
- with torch.cuda.amp.autocast():
- unwrapped_unet= accelerator.unwrap_model(unet)
- newpipe = TryonPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unwrapped_unet,
- vae= vae,
- scheduler=noise_scheduler,
- tokenizer=tokenizer,
- tokenizer_2=tokenizer_2,
- text_encoder=text_encoder,
- text_encoder_2=text_encoder_2,
- image_encoder=image_encoder,
- unet_encoder = unet_encoder,
- torch_dtype=torch.float16,
- add_watermarker=False,
- safety_checker=None,
- ).to(accelerator.device)
- with torch.no_grad():
- for sample in test_dataloader:
- img_emb_list = []
- for i in range(sample['cloth'].shape[0]):
- img_emb_list.append(sample['cloth'][i])
-
- prompt = sample["caption"]
-
- num_prompts = sample['cloth'].shape[0]
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
- image_embeds = torch.cat(img_emb_list,dim=0)
-
-
- with torch.inference_mode():
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = newpipe.encode_prompt(
- prompt,
- num_images_per_prompt=1,
- do_classifier_free_guidance=True,
- negative_prompt=negative_prompt,
- )
-
-
- prompt = sample["caption_cloth"]
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
- if not isinstance(prompt, List):
- prompt = [prompt] * num_prompts
- if not isinstance(negative_prompt, List):
- negative_prompt = [negative_prompt] * num_prompts
-
-
- with torch.inference_mode():
- (
- prompt_embeds_c,
- _,
- _,
- _,
- ) = newpipe.encode_prompt(
- prompt,
- num_images_per_prompt=1,
- do_classifier_free_guidance=False,
- negative_prompt=negative_prompt,
- )
-
-
-
- generator = torch.Generator(newpipe.device).manual_seed(args.seed) if args.seed is not None else None
- images = newpipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=args.num_inference_steps,
- generator=generator,
- strength = 1.0,
- pose_img = sample['pose_img'],
- text_embeds_cloth=prompt_embeds_c,
- cloth = sample["cloth_pure"].to(accelerator.device),
- mask_image=sample['inpaint_mask'],
- image=(sample['image']+1.0)/2.0,
- height=args.height,
- width=args.width,
- guidance_scale=args.guidance_scale,
- ip_adapter_image = image_embeds,
- )[0]
-
- for i in range(len(images)):
- images[i].save(os.path.join(args.output_dir,str(global_step)+"_"+str(i)+"_"+"test.jpg"))
- break
- del unwrapped_unet
- del newpipe
- torch.cuda.empty_cache()
-
-
-
- pixel_values = batch["image"].to(dtype=vae.dtype)
- model_input = vae.encode(pixel_values).latent_dist.sample()
- model_input = model_input * vae.config.scaling_factor
-
- masked_latents = vae.encode(
- batch["im_mask"].reshape(batch["image"].shape).to(dtype=vae.dtype)
- ).latent_dist.sample()
- masked_latents = masked_latents * vae.config.scaling_factor
- masks = batch["inpaint_mask"]
- # resize the mask to latents shape as we concatenate the mask to the latents
- mask = torch.stack(
- [
- torch.nn.functional.interpolate(masks, size=(args.height // 8, args.width // 8))
- ]
- )
- mask = mask.reshape(-1, 1, args.height // 8, args.width // 8)
-
- pose_map = vae.encode(batch["pose_img"].to(dtype=vae.dtype)).latent_dist.sample()
- pose_map = pose_map * vae.config.scaling_factor
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(model_input)
-
- bsz = model_input.shape[0]
- timesteps = torch.randint(
- 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
- )
- # Add noise to the latents according to the noise magnitude at each timestep
- noisy_latents = noise_scheduler.add_noise(model_input, noise, timesteps)
- latent_model_input = torch.cat([noisy_latents, mask,masked_latents,pose_map], dim=1)
-
-
- text_input_ids = tokenizer(
- batch['caption'],
- max_length=tokenizer.model_max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt"
- ).input_ids
- text_input_ids_2 = tokenizer_2(
- batch['caption'],
- max_length=tokenizer_2.model_max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt"
- ).input_ids
-
- encoder_output = text_encoder(text_input_ids.to(accelerator.device), output_hidden_states=True)
- text_embeds = encoder_output.hidden_states[-2]
- encoder_output_2 = text_encoder_2(text_input_ids_2.to(accelerator.device), output_hidden_states=True)
- pooled_text_embeds = encoder_output_2[0]
- text_embeds_2 = encoder_output_2.hidden_states[-2]
- encoder_hidden_states = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat
-
-
- def compute_time_ids(original_size, crops_coords_top_left = (0,0)):
- # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
- target_size = (args.height, args.height)
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
- add_time_ids = torch.tensor([add_time_ids])
- add_time_ids = add_time_ids.to(accelerator.device)
- return add_time_ids
-
- add_time_ids = torch.cat(
- [compute_time_ids((args.height, args.height)) for i in range(bsz)]
- )
-
- img_emb_list = []
- for i in range(bsz):
- img_emb_list.append(batch['cloth'][i])
-
- image_embeds = torch.cat(img_emb_list,dim=0)
- image_embeds = image_encoder(image_embeds, output_hidden_states=True).hidden_states[-2]
- ip_tokens =image_proj_model(image_embeds)
-
-
-
- # add cond
- unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
- unet_added_cond_kwargs["image_embeds"] = ip_tokens
-
- cloth_values = batch["cloth_pure"].to(accelerator.device,dtype=vae.dtype)
- cloth_values = vae.encode(cloth_values).latent_dist.sample()
- cloth_values = cloth_values * vae.config.scaling_factor
-
-
- text_input_ids = tokenizer(
- batch['caption_cloth'],
- max_length=tokenizer.model_max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt"
- ).input_ids
- text_input_ids_2 = tokenizer_2(
- batch['caption_cloth'],
- max_length=tokenizer_2.model_max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt"
- ).input_ids
-
-
- encoder_output = text_encoder(text_input_ids.to(accelerator.device), output_hidden_states=True)
- text_embeds_cloth = encoder_output.hidden_states[-2]
- encoder_output_2 = text_encoder_2(text_input_ids_2.to(accelerator.device), output_hidden_states=True)
- text_embeds_2_cloth = encoder_output_2.hidden_states[-2]
- text_embeds_cloth = torch.concat([text_embeds_cloth, text_embeds_2_cloth], dim=-1) # concat
-
-
- down,reference_features = unet_encoder(cloth_values,timesteps, text_embeds_cloth,return_dict=False)
- reference_features = list(reference_features)
-
- noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states,added_cond_kwargs=unet_added_cond_kwargs,garment_features=reference_features).sample
-
-
- if noise_scheduler.config.prediction_type == "epsilon":
- target = noise
- elif noise_scheduler.config.prediction_type == "v_prediction":
- target = noise_scheduler.get_velocity(model_input, noise, timesteps)
- elif noise_scheduler.config.prediction_type == "sample":
- # We set the target to latents here, but the model_pred will return the noise sample prediction.
- target = model_input
- # We will have to subtract the noise residual from the prediction to get the target sample.
- model_pred = model_pred - noise
- else:
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-
- if args.snr_gamma is None:
- loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
- else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
- # Since we predict the noise instead of x_0, the original formulation is slightly changed.
- # This is discussed in Section 4.2 of the same paper.
- snr = compute_snr(noise_scheduler, timesteps)
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective requires that we add one to SNR values before we divide by them.
- snr = snr + 1
- mse_loss_weights = (
- torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
- )
-
- loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none")
- loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
- loss = loss.mean()
-
- avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
- train_loss += avg_loss.item() / args.gradient_accumulation_steps
-
-
- # Backpropagate
- accelerator.backward(loss)
-
- if accelerator.sync_gradients:
- accelerator.clip_grad_norm_(params_to_opt, 1.0)
-
- optimizer.step()
- optimizer.zero_grad()
- # Load scheduler, tokenizer and models.
- progress_bar.update(1)
- global_step += 1
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
- accelerator.log({"train_loss": train_loss}, step=global_step)
- train_loss = 0.0
- logs = {"step_loss": loss.detach().item()}
- progress_bar.set_postfix(**logs)
-
- if global_step >= args.max_train_steps:
- break
-
- if global_step % args.checkpointing_epoch == 0:
- if accelerator.is_main_process:
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
- unwrapped_unet = accelerator.unwrap_model(
- unet, keep_fp32_wrapper=True
- )
- pipeline = TryonPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unwrapped_unet,
- vae= vae,
- scheduler=noise_scheduler,
- tokenizer=tokenizer,
- tokenizer_2=tokenizer_2,
- text_encoder=text_encoder,
- text_encoder_2=text_encoder_2,
- image_encoder=image_encoder,
- unet_encoder=unet_encoder,
- torch_dtype=torch.float16,
- add_watermarker=False,
- safety_checker=None,
- )
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
- pipeline.save_pretrained(save_path)
- del pipeline
-
-
-if __name__ == "__main__":
- main()
diff --git a/train_xl.sh b/train_xl.sh
deleted file mode 100644
index 84b52ba4e5c6661fc4ded94de359e7657deee0ff..0000000000000000000000000000000000000000
--- a/train_xl.sh
+++ /dev/null
@@ -1 +0,0 @@
-CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch train_xl.py --gradient_checkpointing --use_8bit_adam --output_dir=result --train_batch_size=6 --data_dir=/home/omnious/workspace/yisol/Dataset/VITON-HD/zalando
\ No newline at end of file
diff --git a/gradio_demo/utils_mask.py b/utils_mask.py
similarity index 100%
rename from gradio_demo/utils_mask.py
rename to utils_mask.py
diff --git a/vitonhd_test_tagged.json b/vitonhd_test_tagged.json
deleted file mode 100644
index e68c99399bdea7911b500c901ca599804a11ebeb..0000000000000000000000000000000000000000
--- a/vitonhd_test_tagged.json
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ff196c077cf27ace24a1710fba53a2719606f60983f583aebcb61d0b3730a1dc
-size 1803684
diff --git a/vitonhd_train_tagged.json b/vitonhd_train_tagged.json
deleted file mode 100644
index 4d3ee968d6971eeae198f6fd960bfd4b99a3c92d..0000000000000000000000000000000000000000
--- a/vitonhd_train_tagged.json
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a277d878e33fa4247631ce8db5584739c7f8515183f8a476fd6e64a5f80c94f6
-size 24848224