diff --git a/.gitattributes b/.gitattributes
index cad82a7d5c38d2c1723a78dc6ac0ea68ecc7bf88..a6344aac8c09253b3b630fb776ae94478aa0275b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,5 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-*.json filter=lfs diff=lfs merge=lfs -text
-*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7e99e367f8443d86e5e8825b9fda39dfbb39630d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*.pyc
\ No newline at end of file
diff --git a/README.md b/README.md
index f905b22e8fbc1f5bb61bdbfcc9c5ca997c0f2a0a..d50c6a7d800d200187b6e0256028b4860b2dbb41 100644
--- a/README.md
+++ b/README.md
@@ -5,8 +5,9 @@ colorFrom: blue
 colorTo: purple
 sdk: gradio
 sdk_version: "4.24.0"
-app_file: gradio_demo/app.py
+app_file: app.py
 pinned: false
+short_description: Virtual Try-on
 ---
 
 # Virtual Try-On Demo
@@ -18,5 +19,5 @@ This application is a Virtual Try-On model demonstration powered by Gradio. It a
 To start the app locally, use the following command:
 
 ```bash
-python gradio_demo/app.py
+python app.py
 ```
diff --git a/gradio_demo/app.py b/app.py
similarity index 96%
rename from gradio_demo/app.py
rename to app.py
index a25ebd10fc3b9ba636884a982d5e46856b089e39..186494d37684b0f381892c11f2e1e2ae10aeee91 100644
--- a/gradio_demo/app.py
+++ b/app.py
@@ -1,7 +1,6 @@
-import sys
-sys.path.append('./')
-from PIL import Image
 import gradio as gr
+import spaces
+from PIL import Image
 from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
 from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
 from src.unet_hacked_tryon import UNet2DConditionModel
@@ -26,7 +25,6 @@ from preprocess.openpose.run_openpose import OpenPose
 from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
 from torchvision.transforms.functional import to_pil_image
 
-device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
 def pil_to_binary_mask(pil_image, threshold=0):
     np_image = np.array(pil_image)
@@ -123,7 +121,9 @@ pipe = TryonPipeline.from_pretrained(
 )
 pipe.unet_encoder = UNet_Encoder
 
+@spaces.GPU
 def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
+    device = "cuda"
     openpose_model.preprocessor.body_estimation.model.to(device)
     pipe.to(device)
 
@@ -258,10 +258,10 @@ for ex_human in human_list_path:
 
 ##default human
 
-image_blocks = gr.Blocks().queue()
+image_blocks = gr.Blocks(theme="Nymbo/Alyx_Theme").queue()
 with image_blocks as demo:
-    gr.Markdown("## IDM-VTON 👕👔👚")
-    gr.Markdown("Virtual Try-on with your image and garment image. Check out the [source codes](https://github.com/yisol/IDM-VTON) and the [model](https://huggingface.co/yisol/IDM-VTON)")
+    gr.HTML("Virtual Try-On")
+    gr.HTML("Upload an image of a person and an image of a garment ✨")
     with gr.Row():
         with gr.Column():
             imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
@@ -309,5 +309,4 @@
 
 
 
-image_blocks.launch(share=True)
-
+image_blocks.launch()
\ No newline at end of file
diff --git a/gradio_demo/apply_net.py b/apply_net.py
old mode 100644
new mode 100755
similarity index 100%
rename from gradio_demo/apply_net.py
rename to apply_net.py
diff --git a/assets/teaser.png b/assets/teaser.png
deleted file mode 100644
index 931bc3b52dd73cc4763756c59393e5850d1bb97e..0000000000000000000000000000000000000000
--- a/assets/teaser.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e0ff5c96023ddf67864dc49acde2fab6a0c982fd77aa4979d9a2e77f45ad0b82
-size 7055496
diff --git a/assets/teaser2.png b/assets/teaser2.png
deleted file mode 100644
index ad35f385e1cbeae96910d6c9d5ecb23084303302..0000000000000000000000000000000000000000
--- a/assets/teaser2.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4a2c3522cb7805407f437f1639418166477f334cbef739e06947b5dfc68a1968
-size 9020741
diff --git a/ckpt/humanparsing/parsing_atr.onnx b/ckpt/humanparsing/parsing_atr.onnx
index 3aa341342d9ccd412b6fff79f9b0315d7dc6c3c0..28883cf4b0069c96f0f00930798428017425c3fa 100644
--- a/ckpt/humanparsing/parsing_atr.onnx
+++ b/ckpt/humanparsing/parsing_atr.onnx
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3a248ff77aea7799b1a5ad036837e0b50c59b5fb95a2970e28d86b0a31b0cc73
-size 25
+oid sha256:04c7d1d070d0e0ae943d86b18cb5aaaea9e278d97462e9cfb270cbbe4cd977f4
+size 266859305
diff --git a/ckpt/humanparsing/parsing_lip.onnx b/ckpt/humanparsing/parsing_lip.onnx
index 3decb0d867f96439ca0fda704626848b9ab667b3..7d1a879fa30fc002188b0c9fec3cc05064dd1093 100644
--- a/ckpt/humanparsing/parsing_lip.onnx
+++ b/ckpt/humanparsing/parsing_lip.onnx
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:67f9c02ab458afa087bc02e18a9df5e479aebe48e291f46e9de6452cc3b97d37
-size 25
+oid sha256:8436e1dae96e2601c373d1ace29c8f0978b16357d9038c17a8ba756cca376dbc
+size 266863411
diff --git a/ckpt/image_encoder/config.json b/ckpt/image_encoder/config.json
deleted file mode 100644
index a6b2b6819e723b53e39363134706ffdca74f5179..0000000000000000000000000000000000000000
--- a/ckpt/image_encoder/config.json
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:625d37b31afbf2f0792a87846b3654ee23f20568409e35b78a1f795b04e1a7a1
-size 560
diff --git a/ckpt/image_encoder/model.safetensors b/ckpt/image_encoder/model.safetensors
deleted file mode 100644
index 7b6f224d25646c23132235975f42511ac2920549..0000000000000000000000000000000000000000
--- a/ckpt/image_encoder/model.safetensors
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ff4265338f828df6c74782f6e1614a5c01d204e7f8bfd46eafa013f3de151d0e
-size 27
diff --git a/ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin b/ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin
deleted file mode 100644
index 45555d22ef8f4eb9ecef7e461142ba39a11151a8..0000000000000000000000000000000000000000
--- a/ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e8a955fca80fb4ba5718156d2e2619c9fbcd323d7748f44c7d4ce5503763baa0
-size 24
diff --git a/ckpt/openpose/ckpts/body_pose_model.pth b/ckpt/openpose/ckpts/body_pose_model.pth
index
774e6e843d5f7e3ff84eaa3f7d17b048860ddf55..9acb77e68f31906a8875f1daef2f3f7ef94acb1e 100644 --- a/ckpt/openpose/ckpts/body_pose_model.pth +++ b/ckpt/openpose/ckpts/body_pose_model.pth @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1014e5f51f5e838d02528a05837437f0f7424a9bfb90c2d6a55f0c73db2c0e0f -size 28 +oid sha256:25a948c16078b0f08e236bda51a385d855ef4c153598947c28c0d47ed94bb746 +size 209267595 diff --git a/gradio_demo/densepose/__init__.py b/densepose/__init__.py similarity index 100% rename from gradio_demo/densepose/__init__.py rename to densepose/__init__.py diff --git a/gradio_demo/densepose/config.py b/densepose/config.py similarity index 100% rename from gradio_demo/densepose/config.py rename to densepose/config.py diff --git a/gradio_demo/densepose/converters/__init__.py b/densepose/converters/__init__.py similarity index 100% rename from gradio_demo/densepose/converters/__init__.py rename to densepose/converters/__init__.py diff --git a/gradio_demo/densepose/converters/base.py b/densepose/converters/base.py similarity index 100% rename from gradio_demo/densepose/converters/base.py rename to densepose/converters/base.py diff --git a/gradio_demo/densepose/converters/builtin.py b/densepose/converters/builtin.py similarity index 100% rename from gradio_demo/densepose/converters/builtin.py rename to densepose/converters/builtin.py diff --git a/gradio_demo/densepose/converters/chart_output_hflip.py b/densepose/converters/chart_output_hflip.py similarity index 100% rename from gradio_demo/densepose/converters/chart_output_hflip.py rename to densepose/converters/chart_output_hflip.py diff --git a/gradio_demo/densepose/converters/chart_output_to_chart_result.py b/densepose/converters/chart_output_to_chart_result.py similarity index 100% rename from gradio_demo/densepose/converters/chart_output_to_chart_result.py rename to densepose/converters/chart_output_to_chart_result.py diff --git a/gradio_demo/densepose/converters/hflip.py b/densepose/converters/hflip.py similarity index 100% rename from gradio_demo/densepose/converters/hflip.py rename to densepose/converters/hflip.py diff --git a/gradio_demo/densepose/converters/segm_to_mask.py b/densepose/converters/segm_to_mask.py similarity index 100% rename from gradio_demo/densepose/converters/segm_to_mask.py rename to densepose/converters/segm_to_mask.py diff --git a/gradio_demo/densepose/converters/to_chart_result.py b/densepose/converters/to_chart_result.py similarity index 100% rename from gradio_demo/densepose/converters/to_chart_result.py rename to densepose/converters/to_chart_result.py diff --git a/gradio_demo/densepose/converters/to_mask.py b/densepose/converters/to_mask.py similarity index 100% rename from gradio_demo/densepose/converters/to_mask.py rename to densepose/converters/to_mask.py diff --git a/gradio_demo/densepose/data/__init__.py b/densepose/data/__init__.py similarity index 100% rename from gradio_demo/densepose/data/__init__.py rename to densepose/data/__init__.py diff --git a/gradio_demo/densepose/data/build.py b/densepose/data/build.py similarity index 100% rename from gradio_demo/densepose/data/build.py rename to densepose/data/build.py diff --git a/gradio_demo/densepose/data/combined_loader.py b/densepose/data/combined_loader.py similarity index 100% rename from gradio_demo/densepose/data/combined_loader.py rename to densepose/data/combined_loader.py diff --git a/gradio_demo/densepose/data/dataset_mapper.py b/densepose/data/dataset_mapper.py similarity index 100% rename from 
gradio_demo/densepose/data/dataset_mapper.py rename to densepose/data/dataset_mapper.py diff --git a/gradio_demo/densepose/data/datasets/__init__.py b/densepose/data/datasets/__init__.py similarity index 100% rename from gradio_demo/densepose/data/datasets/__init__.py rename to densepose/data/datasets/__init__.py diff --git a/gradio_demo/densepose/data/datasets/builtin.py b/densepose/data/datasets/builtin.py similarity index 100% rename from gradio_demo/densepose/data/datasets/builtin.py rename to densepose/data/datasets/builtin.py diff --git a/gradio_demo/densepose/data/datasets/chimpnsee.py b/densepose/data/datasets/chimpnsee.py similarity index 100% rename from gradio_demo/densepose/data/datasets/chimpnsee.py rename to densepose/data/datasets/chimpnsee.py diff --git a/gradio_demo/densepose/data/datasets/coco.py b/densepose/data/datasets/coco.py similarity index 100% rename from gradio_demo/densepose/data/datasets/coco.py rename to densepose/data/datasets/coco.py diff --git a/gradio_demo/densepose/data/datasets/dataset_type.py b/densepose/data/datasets/dataset_type.py similarity index 100% rename from gradio_demo/densepose/data/datasets/dataset_type.py rename to densepose/data/datasets/dataset_type.py diff --git a/gradio_demo/densepose/data/datasets/lvis.py b/densepose/data/datasets/lvis.py similarity index 100% rename from gradio_demo/densepose/data/datasets/lvis.py rename to densepose/data/datasets/lvis.py diff --git a/gradio_demo/densepose/data/image_list_dataset.py b/densepose/data/image_list_dataset.py similarity index 100% rename from gradio_demo/densepose/data/image_list_dataset.py rename to densepose/data/image_list_dataset.py diff --git a/gradio_demo/densepose/data/inference_based_loader.py b/densepose/data/inference_based_loader.py similarity index 100% rename from gradio_demo/densepose/data/inference_based_loader.py rename to densepose/data/inference_based_loader.py diff --git a/gradio_demo/densepose/data/meshes/__init__.py b/densepose/data/meshes/__init__.py similarity index 100% rename from gradio_demo/densepose/data/meshes/__init__.py rename to densepose/data/meshes/__init__.py diff --git a/gradio_demo/densepose/data/meshes/builtin.py b/densepose/data/meshes/builtin.py similarity index 100% rename from gradio_demo/densepose/data/meshes/builtin.py rename to densepose/data/meshes/builtin.py diff --git a/gradio_demo/densepose/data/meshes/catalog.py b/densepose/data/meshes/catalog.py similarity index 100% rename from gradio_demo/densepose/data/meshes/catalog.py rename to densepose/data/meshes/catalog.py diff --git a/gradio_demo/densepose/data/samplers/__init__.py b/densepose/data/samplers/__init__.py similarity index 100% rename from gradio_demo/densepose/data/samplers/__init__.py rename to densepose/data/samplers/__init__.py diff --git a/gradio_demo/densepose/data/samplers/densepose_base.py b/densepose/data/samplers/densepose_base.py similarity index 100% rename from gradio_demo/densepose/data/samplers/densepose_base.py rename to densepose/data/samplers/densepose_base.py diff --git a/gradio_demo/densepose/data/samplers/densepose_confidence_based.py b/densepose/data/samplers/densepose_confidence_based.py similarity index 100% rename from gradio_demo/densepose/data/samplers/densepose_confidence_based.py rename to densepose/data/samplers/densepose_confidence_based.py diff --git a/gradio_demo/densepose/data/samplers/densepose_cse_base.py b/densepose/data/samplers/densepose_cse_base.py similarity index 100% rename from gradio_demo/densepose/data/samplers/densepose_cse_base.py 
rename to densepose/data/samplers/densepose_cse_base.py diff --git a/gradio_demo/densepose/data/samplers/densepose_cse_confidence_based.py b/densepose/data/samplers/densepose_cse_confidence_based.py similarity index 100% rename from gradio_demo/densepose/data/samplers/densepose_cse_confidence_based.py rename to densepose/data/samplers/densepose_cse_confidence_based.py diff --git a/gradio_demo/densepose/data/samplers/densepose_cse_uniform.py b/densepose/data/samplers/densepose_cse_uniform.py similarity index 100% rename from gradio_demo/densepose/data/samplers/densepose_cse_uniform.py rename to densepose/data/samplers/densepose_cse_uniform.py diff --git a/gradio_demo/densepose/data/samplers/densepose_uniform.py b/densepose/data/samplers/densepose_uniform.py similarity index 100% rename from gradio_demo/densepose/data/samplers/densepose_uniform.py rename to densepose/data/samplers/densepose_uniform.py diff --git a/gradio_demo/densepose/data/samplers/mask_from_densepose.py b/densepose/data/samplers/mask_from_densepose.py similarity index 100% rename from gradio_demo/densepose/data/samplers/mask_from_densepose.py rename to densepose/data/samplers/mask_from_densepose.py diff --git a/gradio_demo/densepose/data/samplers/prediction_to_gt.py b/densepose/data/samplers/prediction_to_gt.py similarity index 100% rename from gradio_demo/densepose/data/samplers/prediction_to_gt.py rename to densepose/data/samplers/prediction_to_gt.py diff --git a/gradio_demo/densepose/data/transform/__init__.py b/densepose/data/transform/__init__.py similarity index 100% rename from gradio_demo/densepose/data/transform/__init__.py rename to densepose/data/transform/__init__.py diff --git a/gradio_demo/densepose/data/transform/image.py b/densepose/data/transform/image.py similarity index 100% rename from gradio_demo/densepose/data/transform/image.py rename to densepose/data/transform/image.py diff --git a/gradio_demo/densepose/data/utils.py b/densepose/data/utils.py similarity index 100% rename from gradio_demo/densepose/data/utils.py rename to densepose/data/utils.py diff --git a/gradio_demo/densepose/data/video/__init__.py b/densepose/data/video/__init__.py similarity index 100% rename from gradio_demo/densepose/data/video/__init__.py rename to densepose/data/video/__init__.py diff --git a/gradio_demo/densepose/data/video/frame_selector.py b/densepose/data/video/frame_selector.py similarity index 100% rename from gradio_demo/densepose/data/video/frame_selector.py rename to densepose/data/video/frame_selector.py diff --git a/gradio_demo/densepose/data/video/video_keyframe_dataset.py b/densepose/data/video/video_keyframe_dataset.py similarity index 100% rename from gradio_demo/densepose/data/video/video_keyframe_dataset.py rename to densepose/data/video/video_keyframe_dataset.py diff --git a/gradio_demo/densepose/engine/__init__.py b/densepose/engine/__init__.py similarity index 100% rename from gradio_demo/densepose/engine/__init__.py rename to densepose/engine/__init__.py diff --git a/gradio_demo/densepose/engine/trainer.py b/densepose/engine/trainer.py similarity index 100% rename from gradio_demo/densepose/engine/trainer.py rename to densepose/engine/trainer.py diff --git a/gradio_demo/densepose/evaluation/__init__.py b/densepose/evaluation/__init__.py similarity index 100% rename from gradio_demo/densepose/evaluation/__init__.py rename to densepose/evaluation/__init__.py diff --git a/gradio_demo/densepose/evaluation/d2_evaluator_adapter.py b/densepose/evaluation/d2_evaluator_adapter.py similarity index 100% rename 
from gradio_demo/densepose/evaluation/d2_evaluator_adapter.py rename to densepose/evaluation/d2_evaluator_adapter.py diff --git a/gradio_demo/densepose/evaluation/densepose_coco_evaluation.py b/densepose/evaluation/densepose_coco_evaluation.py similarity index 100% rename from gradio_demo/densepose/evaluation/densepose_coco_evaluation.py rename to densepose/evaluation/densepose_coco_evaluation.py diff --git a/gradio_demo/densepose/evaluation/evaluator.py b/densepose/evaluation/evaluator.py similarity index 100% rename from gradio_demo/densepose/evaluation/evaluator.py rename to densepose/evaluation/evaluator.py diff --git a/gradio_demo/densepose/evaluation/mesh_alignment_evaluator.py b/densepose/evaluation/mesh_alignment_evaluator.py similarity index 100% rename from gradio_demo/densepose/evaluation/mesh_alignment_evaluator.py rename to densepose/evaluation/mesh_alignment_evaluator.py diff --git a/gradio_demo/densepose/evaluation/tensor_storage.py b/densepose/evaluation/tensor_storage.py similarity index 100% rename from gradio_demo/densepose/evaluation/tensor_storage.py rename to densepose/evaluation/tensor_storage.py diff --git a/gradio_demo/densepose/modeling/__init__.py b/densepose/modeling/__init__.py similarity index 100% rename from gradio_demo/densepose/modeling/__init__.py rename to densepose/modeling/__init__.py diff --git a/gradio_demo/densepose/modeling/build.py b/densepose/modeling/build.py similarity index 100% rename from gradio_demo/densepose/modeling/build.py rename to densepose/modeling/build.py diff --git a/gradio_demo/densepose/modeling/confidence.py b/densepose/modeling/confidence.py similarity index 100% rename from gradio_demo/densepose/modeling/confidence.py rename to densepose/modeling/confidence.py diff --git a/gradio_demo/densepose/modeling/cse/__init__.py b/densepose/modeling/cse/__init__.py similarity index 100% rename from gradio_demo/densepose/modeling/cse/__init__.py rename to densepose/modeling/cse/__init__.py diff --git a/gradio_demo/densepose/modeling/cse/embedder.py b/densepose/modeling/cse/embedder.py similarity index 100% rename from gradio_demo/densepose/modeling/cse/embedder.py rename to densepose/modeling/cse/embedder.py diff --git a/gradio_demo/densepose/modeling/cse/utils.py b/densepose/modeling/cse/utils.py similarity index 100% rename from gradio_demo/densepose/modeling/cse/utils.py rename to densepose/modeling/cse/utils.py diff --git a/gradio_demo/densepose/modeling/cse/vertex_direct_embedder.py b/densepose/modeling/cse/vertex_direct_embedder.py similarity index 100% rename from gradio_demo/densepose/modeling/cse/vertex_direct_embedder.py rename to densepose/modeling/cse/vertex_direct_embedder.py diff --git a/gradio_demo/densepose/modeling/cse/vertex_feature_embedder.py b/densepose/modeling/cse/vertex_feature_embedder.py similarity index 100% rename from gradio_demo/densepose/modeling/cse/vertex_feature_embedder.py rename to densepose/modeling/cse/vertex_feature_embedder.py diff --git a/gradio_demo/densepose/modeling/densepose_checkpoint.py b/densepose/modeling/densepose_checkpoint.py similarity index 100% rename from gradio_demo/densepose/modeling/densepose_checkpoint.py rename to densepose/modeling/densepose_checkpoint.py diff --git a/gradio_demo/densepose/modeling/filter.py b/densepose/modeling/filter.py similarity index 100% rename from gradio_demo/densepose/modeling/filter.py rename to densepose/modeling/filter.py diff --git a/gradio_demo/densepose/modeling/hrfpn.py b/densepose/modeling/hrfpn.py similarity index 100% rename from 
gradio_demo/densepose/modeling/hrfpn.py rename to densepose/modeling/hrfpn.py diff --git a/gradio_demo/densepose/modeling/hrnet.py b/densepose/modeling/hrnet.py similarity index 100% rename from gradio_demo/densepose/modeling/hrnet.py rename to densepose/modeling/hrnet.py diff --git a/gradio_demo/densepose/modeling/inference.py b/densepose/modeling/inference.py similarity index 100% rename from gradio_demo/densepose/modeling/inference.py rename to densepose/modeling/inference.py diff --git a/gradio_demo/densepose/modeling/losses/__init__.py b/densepose/modeling/losses/__init__.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/__init__.py rename to densepose/modeling/losses/__init__.py diff --git a/gradio_demo/densepose/modeling/losses/chart.py b/densepose/modeling/losses/chart.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/chart.py rename to densepose/modeling/losses/chart.py diff --git a/gradio_demo/densepose/modeling/losses/chart_with_confidences.py b/densepose/modeling/losses/chart_with_confidences.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/chart_with_confidences.py rename to densepose/modeling/losses/chart_with_confidences.py diff --git a/gradio_demo/densepose/modeling/losses/cse.py b/densepose/modeling/losses/cse.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/cse.py rename to densepose/modeling/losses/cse.py diff --git a/gradio_demo/densepose/modeling/losses/cycle_pix2shape.py b/densepose/modeling/losses/cycle_pix2shape.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/cycle_pix2shape.py rename to densepose/modeling/losses/cycle_pix2shape.py diff --git a/gradio_demo/densepose/modeling/losses/cycle_shape2shape.py b/densepose/modeling/losses/cycle_shape2shape.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/cycle_shape2shape.py rename to densepose/modeling/losses/cycle_shape2shape.py diff --git a/gradio_demo/densepose/modeling/losses/embed.py b/densepose/modeling/losses/embed.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/embed.py rename to densepose/modeling/losses/embed.py diff --git a/gradio_demo/densepose/modeling/losses/embed_utils.py b/densepose/modeling/losses/embed_utils.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/embed_utils.py rename to densepose/modeling/losses/embed_utils.py diff --git a/gradio_demo/densepose/modeling/losses/mask.py b/densepose/modeling/losses/mask.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/mask.py rename to densepose/modeling/losses/mask.py diff --git a/gradio_demo/densepose/modeling/losses/mask_or_segm.py b/densepose/modeling/losses/mask_or_segm.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/mask_or_segm.py rename to densepose/modeling/losses/mask_or_segm.py diff --git a/gradio_demo/densepose/modeling/losses/registry.py b/densepose/modeling/losses/registry.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/registry.py rename to densepose/modeling/losses/registry.py diff --git a/gradio_demo/densepose/modeling/losses/segm.py b/densepose/modeling/losses/segm.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/segm.py rename to densepose/modeling/losses/segm.py diff --git a/gradio_demo/densepose/modeling/losses/soft_embed.py b/densepose/modeling/losses/soft_embed.py similarity index 100% rename from 
gradio_demo/densepose/modeling/losses/soft_embed.py rename to densepose/modeling/losses/soft_embed.py diff --git a/gradio_demo/densepose/modeling/losses/utils.py b/densepose/modeling/losses/utils.py similarity index 100% rename from gradio_demo/densepose/modeling/losses/utils.py rename to densepose/modeling/losses/utils.py diff --git a/gradio_demo/densepose/modeling/predictors/__init__.py b/densepose/modeling/predictors/__init__.py similarity index 100% rename from gradio_demo/densepose/modeling/predictors/__init__.py rename to densepose/modeling/predictors/__init__.py diff --git a/gradio_demo/densepose/modeling/predictors/chart.py b/densepose/modeling/predictors/chart.py similarity index 100% rename from gradio_demo/densepose/modeling/predictors/chart.py rename to densepose/modeling/predictors/chart.py diff --git a/gradio_demo/densepose/modeling/predictors/chart_confidence.py b/densepose/modeling/predictors/chart_confidence.py similarity index 100% rename from gradio_demo/densepose/modeling/predictors/chart_confidence.py rename to densepose/modeling/predictors/chart_confidence.py diff --git a/gradio_demo/densepose/modeling/predictors/chart_with_confidence.py b/densepose/modeling/predictors/chart_with_confidence.py similarity index 100% rename from gradio_demo/densepose/modeling/predictors/chart_with_confidence.py rename to densepose/modeling/predictors/chart_with_confidence.py diff --git a/gradio_demo/densepose/modeling/predictors/cse.py b/densepose/modeling/predictors/cse.py similarity index 100% rename from gradio_demo/densepose/modeling/predictors/cse.py rename to densepose/modeling/predictors/cse.py diff --git a/gradio_demo/densepose/modeling/predictors/cse_confidence.py b/densepose/modeling/predictors/cse_confidence.py similarity index 100% rename from gradio_demo/densepose/modeling/predictors/cse_confidence.py rename to densepose/modeling/predictors/cse_confidence.py diff --git a/gradio_demo/densepose/modeling/predictors/cse_with_confidence.py b/densepose/modeling/predictors/cse_with_confidence.py similarity index 100% rename from gradio_demo/densepose/modeling/predictors/cse_with_confidence.py rename to densepose/modeling/predictors/cse_with_confidence.py diff --git a/gradio_demo/densepose/modeling/predictors/registry.py b/densepose/modeling/predictors/registry.py similarity index 100% rename from gradio_demo/densepose/modeling/predictors/registry.py rename to densepose/modeling/predictors/registry.py diff --git a/gradio_demo/densepose/modeling/roi_heads/__init__.py b/densepose/modeling/roi_heads/__init__.py similarity index 100% rename from gradio_demo/densepose/modeling/roi_heads/__init__.py rename to densepose/modeling/roi_heads/__init__.py diff --git a/gradio_demo/densepose/modeling/roi_heads/deeplab.py b/densepose/modeling/roi_heads/deeplab.py similarity index 100% rename from gradio_demo/densepose/modeling/roi_heads/deeplab.py rename to densepose/modeling/roi_heads/deeplab.py diff --git a/gradio_demo/densepose/modeling/roi_heads/registry.py b/densepose/modeling/roi_heads/registry.py similarity index 100% rename from gradio_demo/densepose/modeling/roi_heads/registry.py rename to densepose/modeling/roi_heads/registry.py diff --git a/gradio_demo/densepose/modeling/roi_heads/roi_head.py b/densepose/modeling/roi_heads/roi_head.py similarity index 100% rename from gradio_demo/densepose/modeling/roi_heads/roi_head.py rename to densepose/modeling/roi_heads/roi_head.py diff --git a/gradio_demo/densepose/modeling/roi_heads/v1convx.py b/densepose/modeling/roi_heads/v1convx.py similarity 
index 100% rename from gradio_demo/densepose/modeling/roi_heads/v1convx.py rename to densepose/modeling/roi_heads/v1convx.py diff --git a/gradio_demo/densepose/modeling/test_time_augmentation.py b/densepose/modeling/test_time_augmentation.py similarity index 100% rename from gradio_demo/densepose/modeling/test_time_augmentation.py rename to densepose/modeling/test_time_augmentation.py diff --git a/gradio_demo/densepose/modeling/utils.py b/densepose/modeling/utils.py similarity index 100% rename from gradio_demo/densepose/modeling/utils.py rename to densepose/modeling/utils.py diff --git a/gradio_demo/densepose/structures/__init__.py b/densepose/structures/__init__.py similarity index 100% rename from gradio_demo/densepose/structures/__init__.py rename to densepose/structures/__init__.py diff --git a/gradio_demo/densepose/structures/chart.py b/densepose/structures/chart.py similarity index 100% rename from gradio_demo/densepose/structures/chart.py rename to densepose/structures/chart.py diff --git a/gradio_demo/densepose/structures/chart_confidence.py b/densepose/structures/chart_confidence.py similarity index 100% rename from gradio_demo/densepose/structures/chart_confidence.py rename to densepose/structures/chart_confidence.py diff --git a/gradio_demo/densepose/structures/chart_result.py b/densepose/structures/chart_result.py similarity index 100% rename from gradio_demo/densepose/structures/chart_result.py rename to densepose/structures/chart_result.py diff --git a/gradio_demo/densepose/structures/cse.py b/densepose/structures/cse.py similarity index 100% rename from gradio_demo/densepose/structures/cse.py rename to densepose/structures/cse.py diff --git a/gradio_demo/densepose/structures/cse_confidence.py b/densepose/structures/cse_confidence.py similarity index 100% rename from gradio_demo/densepose/structures/cse_confidence.py rename to densepose/structures/cse_confidence.py diff --git a/gradio_demo/densepose/structures/data_relative.py b/densepose/structures/data_relative.py similarity index 100% rename from gradio_demo/densepose/structures/data_relative.py rename to densepose/structures/data_relative.py diff --git a/gradio_demo/densepose/structures/list.py b/densepose/structures/list.py similarity index 100% rename from gradio_demo/densepose/structures/list.py rename to densepose/structures/list.py diff --git a/gradio_demo/densepose/structures/mesh.py b/densepose/structures/mesh.py similarity index 100% rename from gradio_demo/densepose/structures/mesh.py rename to densepose/structures/mesh.py diff --git a/gradio_demo/densepose/structures/transform_data.py b/densepose/structures/transform_data.py similarity index 100% rename from gradio_demo/densepose/structures/transform_data.py rename to densepose/structures/transform_data.py diff --git a/gradio_demo/densepose/utils/__init__.py b/densepose/utils/__init__.py similarity index 100% rename from gradio_demo/densepose/utils/__init__.py rename to densepose/utils/__init__.py diff --git a/gradio_demo/densepose/utils/dbhelper.py b/densepose/utils/dbhelper.py similarity index 100% rename from gradio_demo/densepose/utils/dbhelper.py rename to densepose/utils/dbhelper.py diff --git a/gradio_demo/densepose/utils/logger.py b/densepose/utils/logger.py similarity index 100% rename from gradio_demo/densepose/utils/logger.py rename to densepose/utils/logger.py diff --git a/gradio_demo/densepose/utils/transform.py b/densepose/utils/transform.py similarity index 100% rename from gradio_demo/densepose/utils/transform.py rename to 
densepose/utils/transform.py diff --git a/gradio_demo/densepose/vis/__init__.py b/densepose/vis/__init__.py similarity index 100% rename from gradio_demo/densepose/vis/__init__.py rename to densepose/vis/__init__.py diff --git a/gradio_demo/densepose/vis/base.py b/densepose/vis/base.py similarity index 100% rename from gradio_demo/densepose/vis/base.py rename to densepose/vis/base.py diff --git a/gradio_demo/densepose/vis/bounding_box.py b/densepose/vis/bounding_box.py similarity index 100% rename from gradio_demo/densepose/vis/bounding_box.py rename to densepose/vis/bounding_box.py diff --git a/gradio_demo/densepose/vis/densepose_data_points.py b/densepose/vis/densepose_data_points.py similarity index 100% rename from gradio_demo/densepose/vis/densepose_data_points.py rename to densepose/vis/densepose_data_points.py diff --git a/gradio_demo/densepose/vis/densepose_outputs_iuv.py b/densepose/vis/densepose_outputs_iuv.py similarity index 100% rename from gradio_demo/densepose/vis/densepose_outputs_iuv.py rename to densepose/vis/densepose_outputs_iuv.py diff --git a/gradio_demo/densepose/vis/densepose_outputs_vertex.py b/densepose/vis/densepose_outputs_vertex.py similarity index 100% rename from gradio_demo/densepose/vis/densepose_outputs_vertex.py rename to densepose/vis/densepose_outputs_vertex.py diff --git a/gradio_demo/densepose/vis/densepose_results.py b/densepose/vis/densepose_results.py similarity index 100% rename from gradio_demo/densepose/vis/densepose_results.py rename to densepose/vis/densepose_results.py diff --git a/gradio_demo/densepose/vis/densepose_results_textures.py b/densepose/vis/densepose_results_textures.py similarity index 100% rename from gradio_demo/densepose/vis/densepose_results_textures.py rename to densepose/vis/densepose_results_textures.py diff --git a/gradio_demo/densepose/vis/extractor.py b/densepose/vis/extractor.py similarity index 100% rename from gradio_demo/densepose/vis/extractor.py rename to densepose/vis/extractor.py diff --git a/gradio_demo/detectron2/_C.cpython-39-x86_64-linux-gnu.so b/detectron2/_C.cpython-39-x86_64-linux-gnu.so old mode 100644 new mode 100755 similarity index 100% rename from gradio_demo/detectron2/_C.cpython-39-x86_64-linux-gnu.so rename to detectron2/_C.cpython-39-x86_64-linux-gnu.so diff --git a/gradio_demo/detectron2/__init__.py b/detectron2/__init__.py similarity index 100% rename from gradio_demo/detectron2/__init__.py rename to detectron2/__init__.py diff --git a/gradio_demo/detectron2/checkpoint/__init__.py b/detectron2/checkpoint/__init__.py similarity index 100% rename from gradio_demo/detectron2/checkpoint/__init__.py rename to detectron2/checkpoint/__init__.py diff --git a/gradio_demo/detectron2/checkpoint/c2_model_loading.py b/detectron2/checkpoint/c2_model_loading.py similarity index 100% rename from gradio_demo/detectron2/checkpoint/c2_model_loading.py rename to detectron2/checkpoint/c2_model_loading.py diff --git a/gradio_demo/detectron2/checkpoint/catalog.py b/detectron2/checkpoint/catalog.py similarity index 100% rename from gradio_demo/detectron2/checkpoint/catalog.py rename to detectron2/checkpoint/catalog.py diff --git a/gradio_demo/detectron2/checkpoint/detection_checkpoint.py b/detectron2/checkpoint/detection_checkpoint.py similarity index 100% rename from gradio_demo/detectron2/checkpoint/detection_checkpoint.py rename to detectron2/checkpoint/detection_checkpoint.py diff --git a/gradio_demo/detectron2/config/__init__.py b/detectron2/config/__init__.py similarity index 100% rename from 
gradio_demo/detectron2/config/__init__.py rename to detectron2/config/__init__.py diff --git a/gradio_demo/detectron2/config/compat.py b/detectron2/config/compat.py similarity index 100% rename from gradio_demo/detectron2/config/compat.py rename to detectron2/config/compat.py diff --git a/gradio_demo/detectron2/config/config.py b/detectron2/config/config.py similarity index 100% rename from gradio_demo/detectron2/config/config.py rename to detectron2/config/config.py diff --git a/gradio_demo/detectron2/config/defaults.py b/detectron2/config/defaults.py similarity index 100% rename from gradio_demo/detectron2/config/defaults.py rename to detectron2/config/defaults.py diff --git a/gradio_demo/detectron2/config/instantiate.py b/detectron2/config/instantiate.py similarity index 100% rename from gradio_demo/detectron2/config/instantiate.py rename to detectron2/config/instantiate.py diff --git a/gradio_demo/detectron2/config/lazy.py b/detectron2/config/lazy.py similarity index 100% rename from gradio_demo/detectron2/config/lazy.py rename to detectron2/config/lazy.py diff --git a/gradio_demo/detectron2/data/__init__.py b/detectron2/data/__init__.py similarity index 100% rename from gradio_demo/detectron2/data/__init__.py rename to detectron2/data/__init__.py diff --git a/gradio_demo/detectron2/data/benchmark.py b/detectron2/data/benchmark.py similarity index 100% rename from gradio_demo/detectron2/data/benchmark.py rename to detectron2/data/benchmark.py diff --git a/gradio_demo/detectron2/data/build.py b/detectron2/data/build.py similarity index 100% rename from gradio_demo/detectron2/data/build.py rename to detectron2/data/build.py diff --git a/gradio_demo/detectron2/data/catalog.py b/detectron2/data/catalog.py similarity index 100% rename from gradio_demo/detectron2/data/catalog.py rename to detectron2/data/catalog.py diff --git a/gradio_demo/detectron2/data/common.py b/detectron2/data/common.py similarity index 100% rename from gradio_demo/detectron2/data/common.py rename to detectron2/data/common.py diff --git a/gradio_demo/detectron2/data/dataset_mapper.py b/detectron2/data/dataset_mapper.py similarity index 100% rename from gradio_demo/detectron2/data/dataset_mapper.py rename to detectron2/data/dataset_mapper.py diff --git a/gradio_demo/detectron2/data/datasets/README.md b/detectron2/data/datasets/README.md similarity index 100% rename from gradio_demo/detectron2/data/datasets/README.md rename to detectron2/data/datasets/README.md diff --git a/gradio_demo/detectron2/data/datasets/__init__.py b/detectron2/data/datasets/__init__.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/__init__.py rename to detectron2/data/datasets/__init__.py diff --git a/gradio_demo/detectron2/data/datasets/builtin.py b/detectron2/data/datasets/builtin.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/builtin.py rename to detectron2/data/datasets/builtin.py diff --git a/gradio_demo/detectron2/data/datasets/builtin_meta.py b/detectron2/data/datasets/builtin_meta.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/builtin_meta.py rename to detectron2/data/datasets/builtin_meta.py diff --git a/gradio_demo/detectron2/data/datasets/cityscapes.py b/detectron2/data/datasets/cityscapes.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/cityscapes.py rename to detectron2/data/datasets/cityscapes.py diff --git a/gradio_demo/detectron2/data/datasets/cityscapes_panoptic.py b/detectron2/data/datasets/cityscapes_panoptic.py similarity 
index 100% rename from gradio_demo/detectron2/data/datasets/cityscapes_panoptic.py rename to detectron2/data/datasets/cityscapes_panoptic.py diff --git a/gradio_demo/detectron2/data/datasets/coco.py b/detectron2/data/datasets/coco.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/coco.py rename to detectron2/data/datasets/coco.py diff --git a/gradio_demo/detectron2/data/datasets/coco_panoptic.py b/detectron2/data/datasets/coco_panoptic.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/coco_panoptic.py rename to detectron2/data/datasets/coco_panoptic.py diff --git a/gradio_demo/detectron2/data/datasets/lvis.py b/detectron2/data/datasets/lvis.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/lvis.py rename to detectron2/data/datasets/lvis.py diff --git a/gradio_demo/detectron2/data/datasets/lvis_v0_5_categories.py b/detectron2/data/datasets/lvis_v0_5_categories.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/lvis_v0_5_categories.py rename to detectron2/data/datasets/lvis_v0_5_categories.py diff --git a/gradio_demo/detectron2/data/datasets/lvis_v1_categories.py b/detectron2/data/datasets/lvis_v1_categories.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/lvis_v1_categories.py rename to detectron2/data/datasets/lvis_v1_categories.py diff --git a/gradio_demo/detectron2/data/datasets/lvis_v1_category_image_count.py b/detectron2/data/datasets/lvis_v1_category_image_count.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/lvis_v1_category_image_count.py rename to detectron2/data/datasets/lvis_v1_category_image_count.py diff --git a/gradio_demo/detectron2/data/datasets/pascal_voc.py b/detectron2/data/datasets/pascal_voc.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/pascal_voc.py rename to detectron2/data/datasets/pascal_voc.py diff --git a/gradio_demo/detectron2/data/datasets/register_coco.py b/detectron2/data/datasets/register_coco.py similarity index 100% rename from gradio_demo/detectron2/data/datasets/register_coco.py rename to detectron2/data/datasets/register_coco.py diff --git a/gradio_demo/detectron2/data/detection_utils.py b/detectron2/data/detection_utils.py similarity index 100% rename from gradio_demo/detectron2/data/detection_utils.py rename to detectron2/data/detection_utils.py diff --git a/gradio_demo/detectron2/data/samplers/__init__.py b/detectron2/data/samplers/__init__.py similarity index 100% rename from gradio_demo/detectron2/data/samplers/__init__.py rename to detectron2/data/samplers/__init__.py diff --git a/gradio_demo/detectron2/data/samplers/distributed_sampler.py b/detectron2/data/samplers/distributed_sampler.py similarity index 100% rename from gradio_demo/detectron2/data/samplers/distributed_sampler.py rename to detectron2/data/samplers/distributed_sampler.py diff --git a/gradio_demo/detectron2/data/samplers/grouped_batch_sampler.py b/detectron2/data/samplers/grouped_batch_sampler.py similarity index 100% rename from gradio_demo/detectron2/data/samplers/grouped_batch_sampler.py rename to detectron2/data/samplers/grouped_batch_sampler.py diff --git a/gradio_demo/detectron2/data/transforms/__init__.py b/detectron2/data/transforms/__init__.py similarity index 100% rename from gradio_demo/detectron2/data/transforms/__init__.py rename to detectron2/data/transforms/__init__.py diff --git a/gradio_demo/detectron2/data/transforms/augmentation.py b/detectron2/data/transforms/augmentation.py similarity index 
100% rename from gradio_demo/detectron2/data/transforms/augmentation.py rename to detectron2/data/transforms/augmentation.py diff --git a/gradio_demo/detectron2/data/transforms/augmentation_impl.py b/detectron2/data/transforms/augmentation_impl.py similarity index 100% rename from gradio_demo/detectron2/data/transforms/augmentation_impl.py rename to detectron2/data/transforms/augmentation_impl.py diff --git a/gradio_demo/detectron2/data/transforms/transform.py b/detectron2/data/transforms/transform.py similarity index 100% rename from gradio_demo/detectron2/data/transforms/transform.py rename to detectron2/data/transforms/transform.py diff --git a/gradio_demo/detectron2/engine/__init__.py b/detectron2/engine/__init__.py similarity index 100% rename from gradio_demo/detectron2/engine/__init__.py rename to detectron2/engine/__init__.py diff --git a/gradio_demo/detectron2/engine/defaults.py b/detectron2/engine/defaults.py similarity index 100% rename from gradio_demo/detectron2/engine/defaults.py rename to detectron2/engine/defaults.py diff --git a/gradio_demo/detectron2/engine/hooks.py b/detectron2/engine/hooks.py similarity index 100% rename from gradio_demo/detectron2/engine/hooks.py rename to detectron2/engine/hooks.py diff --git a/gradio_demo/detectron2/engine/launch.py b/detectron2/engine/launch.py similarity index 100% rename from gradio_demo/detectron2/engine/launch.py rename to detectron2/engine/launch.py diff --git a/gradio_demo/detectron2/engine/train_loop.py b/detectron2/engine/train_loop.py similarity index 100% rename from gradio_demo/detectron2/engine/train_loop.py rename to detectron2/engine/train_loop.py diff --git a/gradio_demo/detectron2/evaluation/__init__.py b/detectron2/evaluation/__init__.py similarity index 100% rename from gradio_demo/detectron2/evaluation/__init__.py rename to detectron2/evaluation/__init__.py diff --git a/gradio_demo/detectron2/evaluation/cityscapes_evaluation.py b/detectron2/evaluation/cityscapes_evaluation.py similarity index 100% rename from gradio_demo/detectron2/evaluation/cityscapes_evaluation.py rename to detectron2/evaluation/cityscapes_evaluation.py diff --git a/gradio_demo/detectron2/evaluation/coco_evaluation.py b/detectron2/evaluation/coco_evaluation.py similarity index 100% rename from gradio_demo/detectron2/evaluation/coco_evaluation.py rename to detectron2/evaluation/coco_evaluation.py diff --git a/gradio_demo/detectron2/evaluation/evaluator.py b/detectron2/evaluation/evaluator.py similarity index 100% rename from gradio_demo/detectron2/evaluation/evaluator.py rename to detectron2/evaluation/evaluator.py diff --git a/gradio_demo/detectron2/evaluation/fast_eval_api.py b/detectron2/evaluation/fast_eval_api.py similarity index 100% rename from gradio_demo/detectron2/evaluation/fast_eval_api.py rename to detectron2/evaluation/fast_eval_api.py diff --git a/gradio_demo/detectron2/evaluation/lvis_evaluation.py b/detectron2/evaluation/lvis_evaluation.py similarity index 100% rename from gradio_demo/detectron2/evaluation/lvis_evaluation.py rename to detectron2/evaluation/lvis_evaluation.py diff --git a/gradio_demo/detectron2/evaluation/panoptic_evaluation.py b/detectron2/evaluation/panoptic_evaluation.py similarity index 100% rename from gradio_demo/detectron2/evaluation/panoptic_evaluation.py rename to detectron2/evaluation/panoptic_evaluation.py diff --git a/gradio_demo/detectron2/evaluation/pascal_voc_evaluation.py b/detectron2/evaluation/pascal_voc_evaluation.py similarity index 100% rename from 
gradio_demo/detectron2/evaluation/pascal_voc_evaluation.py rename to detectron2/evaluation/pascal_voc_evaluation.py diff --git a/gradio_demo/detectron2/evaluation/rotated_coco_evaluation.py b/detectron2/evaluation/rotated_coco_evaluation.py similarity index 100% rename from gradio_demo/detectron2/evaluation/rotated_coco_evaluation.py rename to detectron2/evaluation/rotated_coco_evaluation.py diff --git a/gradio_demo/detectron2/evaluation/sem_seg_evaluation.py b/detectron2/evaluation/sem_seg_evaluation.py similarity index 100% rename from gradio_demo/detectron2/evaluation/sem_seg_evaluation.py rename to detectron2/evaluation/sem_seg_evaluation.py diff --git a/gradio_demo/detectron2/evaluation/testing.py b/detectron2/evaluation/testing.py similarity index 100% rename from gradio_demo/detectron2/evaluation/testing.py rename to detectron2/evaluation/testing.py diff --git a/gradio_demo/detectron2/export/README.md b/detectron2/export/README.md similarity index 100% rename from gradio_demo/detectron2/export/README.md rename to detectron2/export/README.md diff --git a/gradio_demo/detectron2/export/__init__.py b/detectron2/export/__init__.py similarity index 100% rename from gradio_demo/detectron2/export/__init__.py rename to detectron2/export/__init__.py diff --git a/gradio_demo/detectron2/export/api.py b/detectron2/export/api.py similarity index 100% rename from gradio_demo/detectron2/export/api.py rename to detectron2/export/api.py diff --git a/gradio_demo/detectron2/export/c10.py b/detectron2/export/c10.py similarity index 100% rename from gradio_demo/detectron2/export/c10.py rename to detectron2/export/c10.py diff --git a/gradio_demo/detectron2/export/caffe2_export.py b/detectron2/export/caffe2_export.py similarity index 100% rename from gradio_demo/detectron2/export/caffe2_export.py rename to detectron2/export/caffe2_export.py diff --git a/gradio_demo/detectron2/export/caffe2_inference.py b/detectron2/export/caffe2_inference.py similarity index 100% rename from gradio_demo/detectron2/export/caffe2_inference.py rename to detectron2/export/caffe2_inference.py diff --git a/gradio_demo/detectron2/export/caffe2_modeling.py b/detectron2/export/caffe2_modeling.py similarity index 100% rename from gradio_demo/detectron2/export/caffe2_modeling.py rename to detectron2/export/caffe2_modeling.py diff --git a/gradio_demo/detectron2/export/caffe2_patch.py b/detectron2/export/caffe2_patch.py similarity index 100% rename from gradio_demo/detectron2/export/caffe2_patch.py rename to detectron2/export/caffe2_patch.py diff --git a/gradio_demo/detectron2/export/flatten.py b/detectron2/export/flatten.py similarity index 100% rename from gradio_demo/detectron2/export/flatten.py rename to detectron2/export/flatten.py diff --git a/gradio_demo/detectron2/export/shared.py b/detectron2/export/shared.py similarity index 100% rename from gradio_demo/detectron2/export/shared.py rename to detectron2/export/shared.py diff --git a/gradio_demo/detectron2/export/torchscript.py b/detectron2/export/torchscript.py similarity index 100% rename from gradio_demo/detectron2/export/torchscript.py rename to detectron2/export/torchscript.py diff --git a/gradio_demo/detectron2/export/torchscript_patch.py b/detectron2/export/torchscript_patch.py similarity index 100% rename from gradio_demo/detectron2/export/torchscript_patch.py rename to detectron2/export/torchscript_patch.py diff --git a/gradio_demo/detectron2/layers/__init__.py b/detectron2/layers/__init__.py similarity index 100% rename from gradio_demo/detectron2/layers/__init__.py 
rename to detectron2/layers/__init__.py diff --git a/gradio_demo/detectron2/layers/aspp.py b/detectron2/layers/aspp.py similarity index 100% rename from gradio_demo/detectron2/layers/aspp.py rename to detectron2/layers/aspp.py diff --git a/gradio_demo/detectron2/layers/batch_norm.py b/detectron2/layers/batch_norm.py similarity index 100% rename from gradio_demo/detectron2/layers/batch_norm.py rename to detectron2/layers/batch_norm.py diff --git a/gradio_demo/detectron2/layers/blocks.py b/detectron2/layers/blocks.py similarity index 100% rename from gradio_demo/detectron2/layers/blocks.py rename to detectron2/layers/blocks.py diff --git a/gradio_demo/detectron2/layers/csrc/README.md b/detectron2/layers/csrc/README.md similarity index 100% rename from gradio_demo/detectron2/layers/csrc/README.md rename to detectron2/layers/csrc/README.md diff --git a/gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h b/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h similarity index 100% rename from gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h rename to detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h diff --git a/gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp b/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp similarity index 100% rename from gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp rename to detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp diff --git a/gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu b/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu similarity index 100% rename from gradio_demo/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu rename to detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu diff --git a/gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h b/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h similarity index 100% rename from gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h rename to detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h diff --git a/gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp b/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp similarity index 100% rename from gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp rename to detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp diff --git a/gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu b/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu similarity index 100% rename from gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu rename to detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu diff --git a/gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h b/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h similarity index 100% rename from gradio_demo/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h rename to detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h diff --git a/gradio_demo/detectron2/layers/csrc/cocoeval/cocoeval.cpp b/detectron2/layers/csrc/cocoeval/cocoeval.cpp similarity index 100% rename from gradio_demo/detectron2/layers/csrc/cocoeval/cocoeval.cpp rename to detectron2/layers/csrc/cocoeval/cocoeval.cpp diff --git a/gradio_demo/detectron2/layers/csrc/cocoeval/cocoeval.h b/detectron2/layers/csrc/cocoeval/cocoeval.h similarity index 100% rename from 
gradio_demo/detectron2/layers/csrc/cocoeval/cocoeval.h rename to detectron2/layers/csrc/cocoeval/cocoeval.h diff --git a/gradio_demo/detectron2/layers/csrc/cuda_version.cu b/detectron2/layers/csrc/cuda_version.cu similarity index 100% rename from gradio_demo/detectron2/layers/csrc/cuda_version.cu rename to detectron2/layers/csrc/cuda_version.cu diff --git a/gradio_demo/detectron2/layers/csrc/deformable/deform_conv.h b/detectron2/layers/csrc/deformable/deform_conv.h similarity index 100% rename from gradio_demo/detectron2/layers/csrc/deformable/deform_conv.h rename to detectron2/layers/csrc/deformable/deform_conv.h diff --git a/gradio_demo/detectron2/layers/csrc/deformable/deform_conv_cuda.cu b/detectron2/layers/csrc/deformable/deform_conv_cuda.cu similarity index 100% rename from gradio_demo/detectron2/layers/csrc/deformable/deform_conv_cuda.cu rename to detectron2/layers/csrc/deformable/deform_conv_cuda.cu diff --git a/gradio_demo/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu b/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu similarity index 100% rename from gradio_demo/detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu rename to detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu diff --git a/gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated.h b/detectron2/layers/csrc/nms_rotated/nms_rotated.h similarity index 100% rename from gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated.h rename to detectron2/layers/csrc/nms_rotated/nms_rotated.h diff --git a/gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp b/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp similarity index 100% rename from gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp rename to detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp diff --git a/gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu b/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu similarity index 100% rename from gradio_demo/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu rename to detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu diff --git a/gradio_demo/detectron2/layers/csrc/vision.cpp b/detectron2/layers/csrc/vision.cpp similarity index 100% rename from gradio_demo/detectron2/layers/csrc/vision.cpp rename to detectron2/layers/csrc/vision.cpp diff --git a/gradio_demo/detectron2/layers/deform_conv.py b/detectron2/layers/deform_conv.py similarity index 100% rename from gradio_demo/detectron2/layers/deform_conv.py rename to detectron2/layers/deform_conv.py diff --git a/gradio_demo/detectron2/layers/losses.py b/detectron2/layers/losses.py similarity index 100% rename from gradio_demo/detectron2/layers/losses.py rename to detectron2/layers/losses.py diff --git a/gradio_demo/detectron2/layers/mask_ops.py b/detectron2/layers/mask_ops.py similarity index 100% rename from gradio_demo/detectron2/layers/mask_ops.py rename to detectron2/layers/mask_ops.py diff --git a/gradio_demo/detectron2/layers/nms.py b/detectron2/layers/nms.py similarity index 100% rename from gradio_demo/detectron2/layers/nms.py rename to detectron2/layers/nms.py diff --git a/gradio_demo/detectron2/layers/roi_align.py b/detectron2/layers/roi_align.py similarity index 100% rename from gradio_demo/detectron2/layers/roi_align.py rename to detectron2/layers/roi_align.py diff --git a/gradio_demo/detectron2/layers/roi_align_rotated.py b/detectron2/layers/roi_align_rotated.py similarity index 100% rename from gradio_demo/detectron2/layers/roi_align_rotated.py rename to 
detectron2/layers/roi_align_rotated.py diff --git a/gradio_demo/detectron2/layers/rotated_boxes.py b/detectron2/layers/rotated_boxes.py similarity index 100% rename from gradio_demo/detectron2/layers/rotated_boxes.py rename to detectron2/layers/rotated_boxes.py diff --git a/gradio_demo/detectron2/layers/shape_spec.py b/detectron2/layers/shape_spec.py similarity index 100% rename from gradio_demo/detectron2/layers/shape_spec.py rename to detectron2/layers/shape_spec.py diff --git a/gradio_demo/detectron2/layers/wrappers.py b/detectron2/layers/wrappers.py similarity index 100% rename from gradio_demo/detectron2/layers/wrappers.py rename to detectron2/layers/wrappers.py diff --git a/gradio_demo/detectron2/model_zoo/__init__.py b/detectron2/model_zoo/__init__.py similarity index 100% rename from gradio_demo/detectron2/model_zoo/__init__.py rename to detectron2/model_zoo/__init__.py diff --git a/gradio_demo/detectron2/model_zoo/model_zoo.py b/detectron2/model_zoo/model_zoo.py similarity index 100% rename from gradio_demo/detectron2/model_zoo/model_zoo.py rename to detectron2/model_zoo/model_zoo.py diff --git a/gradio_demo/detectron2/modeling/__init__.py b/detectron2/modeling/__init__.py similarity index 100% rename from gradio_demo/detectron2/modeling/__init__.py rename to detectron2/modeling/__init__.py diff --git a/gradio_demo/detectron2/modeling/anchor_generator.py b/detectron2/modeling/anchor_generator.py similarity index 100% rename from gradio_demo/detectron2/modeling/anchor_generator.py rename to detectron2/modeling/anchor_generator.py diff --git a/gradio_demo/detectron2/modeling/backbone/__init__.py b/detectron2/modeling/backbone/__init__.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/__init__.py rename to detectron2/modeling/backbone/__init__.py diff --git a/gradio_demo/detectron2/modeling/backbone/backbone.py b/detectron2/modeling/backbone/backbone.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/backbone.py rename to detectron2/modeling/backbone/backbone.py diff --git a/gradio_demo/detectron2/modeling/backbone/build.py b/detectron2/modeling/backbone/build.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/build.py rename to detectron2/modeling/backbone/build.py diff --git a/gradio_demo/detectron2/modeling/backbone/fpn.py b/detectron2/modeling/backbone/fpn.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/fpn.py rename to detectron2/modeling/backbone/fpn.py diff --git a/gradio_demo/detectron2/modeling/backbone/mvit.py b/detectron2/modeling/backbone/mvit.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/mvit.py rename to detectron2/modeling/backbone/mvit.py diff --git a/gradio_demo/detectron2/modeling/backbone/regnet.py b/detectron2/modeling/backbone/regnet.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/regnet.py rename to detectron2/modeling/backbone/regnet.py diff --git a/gradio_demo/detectron2/modeling/backbone/resnet.py b/detectron2/modeling/backbone/resnet.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/resnet.py rename to detectron2/modeling/backbone/resnet.py diff --git a/gradio_demo/detectron2/modeling/backbone/swin.py b/detectron2/modeling/backbone/swin.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/swin.py rename to detectron2/modeling/backbone/swin.py diff --git a/gradio_demo/detectron2/modeling/backbone/utils.py 
b/detectron2/modeling/backbone/utils.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/utils.py rename to detectron2/modeling/backbone/utils.py diff --git a/gradio_demo/detectron2/modeling/backbone/vit.py b/detectron2/modeling/backbone/vit.py similarity index 100% rename from gradio_demo/detectron2/modeling/backbone/vit.py rename to detectron2/modeling/backbone/vit.py diff --git a/gradio_demo/detectron2/modeling/box_regression.py b/detectron2/modeling/box_regression.py similarity index 100% rename from gradio_demo/detectron2/modeling/box_regression.py rename to detectron2/modeling/box_regression.py diff --git a/gradio_demo/detectron2/modeling/matcher.py b/detectron2/modeling/matcher.py similarity index 100% rename from gradio_demo/detectron2/modeling/matcher.py rename to detectron2/modeling/matcher.py diff --git a/gradio_demo/detectron2/modeling/meta_arch/__init__.py b/detectron2/modeling/meta_arch/__init__.py similarity index 100% rename from gradio_demo/detectron2/modeling/meta_arch/__init__.py rename to detectron2/modeling/meta_arch/__init__.py diff --git a/gradio_demo/detectron2/modeling/meta_arch/build.py b/detectron2/modeling/meta_arch/build.py similarity index 100% rename from gradio_demo/detectron2/modeling/meta_arch/build.py rename to detectron2/modeling/meta_arch/build.py diff --git a/gradio_demo/detectron2/modeling/meta_arch/dense_detector.py b/detectron2/modeling/meta_arch/dense_detector.py similarity index 100% rename from gradio_demo/detectron2/modeling/meta_arch/dense_detector.py rename to detectron2/modeling/meta_arch/dense_detector.py diff --git a/gradio_demo/detectron2/modeling/meta_arch/fcos.py b/detectron2/modeling/meta_arch/fcos.py similarity index 100% rename from gradio_demo/detectron2/modeling/meta_arch/fcos.py rename to detectron2/modeling/meta_arch/fcos.py diff --git a/gradio_demo/detectron2/modeling/meta_arch/panoptic_fpn.py b/detectron2/modeling/meta_arch/panoptic_fpn.py similarity index 100% rename from gradio_demo/detectron2/modeling/meta_arch/panoptic_fpn.py rename to detectron2/modeling/meta_arch/panoptic_fpn.py diff --git a/gradio_demo/detectron2/modeling/meta_arch/rcnn.py b/detectron2/modeling/meta_arch/rcnn.py similarity index 100% rename from gradio_demo/detectron2/modeling/meta_arch/rcnn.py rename to detectron2/modeling/meta_arch/rcnn.py diff --git a/gradio_demo/detectron2/modeling/meta_arch/retinanet.py b/detectron2/modeling/meta_arch/retinanet.py similarity index 100% rename from gradio_demo/detectron2/modeling/meta_arch/retinanet.py rename to detectron2/modeling/meta_arch/retinanet.py diff --git a/gradio_demo/detectron2/modeling/meta_arch/semantic_seg.py b/detectron2/modeling/meta_arch/semantic_seg.py similarity index 100% rename from gradio_demo/detectron2/modeling/meta_arch/semantic_seg.py rename to detectron2/modeling/meta_arch/semantic_seg.py diff --git a/gradio_demo/detectron2/modeling/mmdet_wrapper.py b/detectron2/modeling/mmdet_wrapper.py similarity index 100% rename from gradio_demo/detectron2/modeling/mmdet_wrapper.py rename to detectron2/modeling/mmdet_wrapper.py diff --git a/gradio_demo/detectron2/modeling/poolers.py b/detectron2/modeling/poolers.py similarity index 100% rename from gradio_demo/detectron2/modeling/poolers.py rename to detectron2/modeling/poolers.py diff --git a/gradio_demo/detectron2/modeling/postprocessing.py b/detectron2/modeling/postprocessing.py similarity index 100% rename from gradio_demo/detectron2/modeling/postprocessing.py rename to detectron2/modeling/postprocessing.py diff --git 
a/gradio_demo/detectron2/modeling/proposal_generator/__init__.py b/detectron2/modeling/proposal_generator/__init__.py similarity index 100% rename from gradio_demo/detectron2/modeling/proposal_generator/__init__.py rename to detectron2/modeling/proposal_generator/__init__.py diff --git a/gradio_demo/detectron2/modeling/proposal_generator/build.py b/detectron2/modeling/proposal_generator/build.py similarity index 100% rename from gradio_demo/detectron2/modeling/proposal_generator/build.py rename to detectron2/modeling/proposal_generator/build.py diff --git a/gradio_demo/detectron2/modeling/proposal_generator/proposal_utils.py b/detectron2/modeling/proposal_generator/proposal_utils.py similarity index 100% rename from gradio_demo/detectron2/modeling/proposal_generator/proposal_utils.py rename to detectron2/modeling/proposal_generator/proposal_utils.py diff --git a/gradio_demo/detectron2/modeling/proposal_generator/rpn.py b/detectron2/modeling/proposal_generator/rpn.py similarity index 100% rename from gradio_demo/detectron2/modeling/proposal_generator/rpn.py rename to detectron2/modeling/proposal_generator/rpn.py diff --git a/gradio_demo/detectron2/modeling/proposal_generator/rrpn.py b/detectron2/modeling/proposal_generator/rrpn.py similarity index 100% rename from gradio_demo/detectron2/modeling/proposal_generator/rrpn.py rename to detectron2/modeling/proposal_generator/rrpn.py diff --git a/gradio_demo/detectron2/modeling/roi_heads/__init__.py b/detectron2/modeling/roi_heads/__init__.py similarity index 100% rename from gradio_demo/detectron2/modeling/roi_heads/__init__.py rename to detectron2/modeling/roi_heads/__init__.py diff --git a/gradio_demo/detectron2/modeling/roi_heads/box_head.py b/detectron2/modeling/roi_heads/box_head.py similarity index 100% rename from gradio_demo/detectron2/modeling/roi_heads/box_head.py rename to detectron2/modeling/roi_heads/box_head.py diff --git a/gradio_demo/detectron2/modeling/roi_heads/cascade_rcnn.py b/detectron2/modeling/roi_heads/cascade_rcnn.py similarity index 100% rename from gradio_demo/detectron2/modeling/roi_heads/cascade_rcnn.py rename to detectron2/modeling/roi_heads/cascade_rcnn.py diff --git a/gradio_demo/detectron2/modeling/roi_heads/fast_rcnn.py b/detectron2/modeling/roi_heads/fast_rcnn.py similarity index 100% rename from gradio_demo/detectron2/modeling/roi_heads/fast_rcnn.py rename to detectron2/modeling/roi_heads/fast_rcnn.py diff --git a/gradio_demo/detectron2/modeling/roi_heads/keypoint_head.py b/detectron2/modeling/roi_heads/keypoint_head.py similarity index 100% rename from gradio_demo/detectron2/modeling/roi_heads/keypoint_head.py rename to detectron2/modeling/roi_heads/keypoint_head.py diff --git a/gradio_demo/detectron2/modeling/roi_heads/mask_head.py b/detectron2/modeling/roi_heads/mask_head.py similarity index 100% rename from gradio_demo/detectron2/modeling/roi_heads/mask_head.py rename to detectron2/modeling/roi_heads/mask_head.py diff --git a/gradio_demo/detectron2/modeling/roi_heads/roi_heads.py b/detectron2/modeling/roi_heads/roi_heads.py similarity index 100% rename from gradio_demo/detectron2/modeling/roi_heads/roi_heads.py rename to detectron2/modeling/roi_heads/roi_heads.py diff --git a/gradio_demo/detectron2/modeling/roi_heads/rotated_fast_rcnn.py b/detectron2/modeling/roi_heads/rotated_fast_rcnn.py similarity index 100% rename from gradio_demo/detectron2/modeling/roi_heads/rotated_fast_rcnn.py rename to detectron2/modeling/roi_heads/rotated_fast_rcnn.py diff --git a/gradio_demo/detectron2/modeling/sampling.py 
b/detectron2/modeling/sampling.py similarity index 100% rename from gradio_demo/detectron2/modeling/sampling.py rename to detectron2/modeling/sampling.py diff --git a/gradio_demo/detectron2/modeling/test_time_augmentation.py b/detectron2/modeling/test_time_augmentation.py similarity index 100% rename from gradio_demo/detectron2/modeling/test_time_augmentation.py rename to detectron2/modeling/test_time_augmentation.py diff --git a/gradio_demo/detectron2/projects/README.md b/detectron2/projects/README.md similarity index 100% rename from gradio_demo/detectron2/projects/README.md rename to detectron2/projects/README.md diff --git a/gradio_demo/detectron2/projects/__init__.py b/detectron2/projects/__init__.py similarity index 100% rename from gradio_demo/detectron2/projects/__init__.py rename to detectron2/projects/__init__.py diff --git a/gradio_demo/detectron2/solver/__init__.py b/detectron2/solver/__init__.py similarity index 100% rename from gradio_demo/detectron2/solver/__init__.py rename to detectron2/solver/__init__.py diff --git a/gradio_demo/detectron2/solver/build.py b/detectron2/solver/build.py similarity index 100% rename from gradio_demo/detectron2/solver/build.py rename to detectron2/solver/build.py diff --git a/gradio_demo/detectron2/solver/lr_scheduler.py b/detectron2/solver/lr_scheduler.py similarity index 100% rename from gradio_demo/detectron2/solver/lr_scheduler.py rename to detectron2/solver/lr_scheduler.py diff --git a/gradio_demo/detectron2/structures/__init__.py b/detectron2/structures/__init__.py similarity index 100% rename from gradio_demo/detectron2/structures/__init__.py rename to detectron2/structures/__init__.py diff --git a/gradio_demo/detectron2/structures/boxes.py b/detectron2/structures/boxes.py similarity index 100% rename from gradio_demo/detectron2/structures/boxes.py rename to detectron2/structures/boxes.py diff --git a/gradio_demo/detectron2/structures/image_list.py b/detectron2/structures/image_list.py similarity index 100% rename from gradio_demo/detectron2/structures/image_list.py rename to detectron2/structures/image_list.py diff --git a/gradio_demo/detectron2/structures/instances.py b/detectron2/structures/instances.py similarity index 100% rename from gradio_demo/detectron2/structures/instances.py rename to detectron2/structures/instances.py diff --git a/gradio_demo/detectron2/structures/keypoints.py b/detectron2/structures/keypoints.py similarity index 100% rename from gradio_demo/detectron2/structures/keypoints.py rename to detectron2/structures/keypoints.py diff --git a/gradio_demo/detectron2/structures/masks.py b/detectron2/structures/masks.py similarity index 100% rename from gradio_demo/detectron2/structures/masks.py rename to detectron2/structures/masks.py diff --git a/gradio_demo/detectron2/structures/rotated_boxes.py b/detectron2/structures/rotated_boxes.py similarity index 100% rename from gradio_demo/detectron2/structures/rotated_boxes.py rename to detectron2/structures/rotated_boxes.py diff --git a/gradio_demo/detectron2/tracking/__init__.py b/detectron2/tracking/__init__.py similarity index 100% rename from gradio_demo/detectron2/tracking/__init__.py rename to detectron2/tracking/__init__.py diff --git a/gradio_demo/detectron2/tracking/base_tracker.py b/detectron2/tracking/base_tracker.py similarity index 100% rename from gradio_demo/detectron2/tracking/base_tracker.py rename to detectron2/tracking/base_tracker.py diff --git a/gradio_demo/detectron2/tracking/bbox_iou_tracker.py b/detectron2/tracking/bbox_iou_tracker.py similarity index 
100% rename from gradio_demo/detectron2/tracking/bbox_iou_tracker.py rename to detectron2/tracking/bbox_iou_tracker.py diff --git a/gradio_demo/detectron2/tracking/hungarian_tracker.py b/detectron2/tracking/hungarian_tracker.py similarity index 100% rename from gradio_demo/detectron2/tracking/hungarian_tracker.py rename to detectron2/tracking/hungarian_tracker.py diff --git a/gradio_demo/detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py b/detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py similarity index 100% rename from gradio_demo/detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py rename to detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py diff --git a/gradio_demo/detectron2/tracking/utils.py b/detectron2/tracking/utils.py similarity index 100% rename from gradio_demo/detectron2/tracking/utils.py rename to detectron2/tracking/utils.py diff --git a/gradio_demo/detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py b/detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py similarity index 100% rename from gradio_demo/detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py rename to detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py diff --git a/gradio_demo/detectron2/utils/README.md b/detectron2/utils/README.md similarity index 100% rename from gradio_demo/detectron2/utils/README.md rename to detectron2/utils/README.md diff --git a/gradio_demo/detectron2/utils/__init__.py b/detectron2/utils/__init__.py similarity index 100% rename from gradio_demo/detectron2/utils/__init__.py rename to detectron2/utils/__init__.py diff --git a/gradio_demo/detectron2/utils/analysis.py b/detectron2/utils/analysis.py similarity index 100% rename from gradio_demo/detectron2/utils/analysis.py rename to detectron2/utils/analysis.py diff --git a/gradio_demo/detectron2/utils/collect_env.py b/detectron2/utils/collect_env.py similarity index 100% rename from gradio_demo/detectron2/utils/collect_env.py rename to detectron2/utils/collect_env.py diff --git a/gradio_demo/detectron2/utils/colormap.py b/detectron2/utils/colormap.py similarity index 100% rename from gradio_demo/detectron2/utils/colormap.py rename to detectron2/utils/colormap.py diff --git a/gradio_demo/detectron2/utils/comm.py b/detectron2/utils/comm.py similarity index 100% rename from gradio_demo/detectron2/utils/comm.py rename to detectron2/utils/comm.py diff --git a/gradio_demo/detectron2/utils/develop.py b/detectron2/utils/develop.py similarity index 100% rename from gradio_demo/detectron2/utils/develop.py rename to detectron2/utils/develop.py diff --git a/gradio_demo/detectron2/utils/env.py b/detectron2/utils/env.py similarity index 100% rename from gradio_demo/detectron2/utils/env.py rename to detectron2/utils/env.py diff --git a/gradio_demo/detectron2/utils/events.py b/detectron2/utils/events.py similarity index 100% rename from gradio_demo/detectron2/utils/events.py rename to detectron2/utils/events.py diff --git a/gradio_demo/detectron2/utils/file_io.py b/detectron2/utils/file_io.py similarity index 100% rename from gradio_demo/detectron2/utils/file_io.py rename to detectron2/utils/file_io.py diff --git a/gradio_demo/detectron2/utils/logger.py b/detectron2/utils/logger.py similarity index 100% rename from gradio_demo/detectron2/utils/logger.py rename to detectron2/utils/logger.py diff --git a/gradio_demo/detectron2/utils/memory.py b/detectron2/utils/memory.py similarity index 100% rename from gradio_demo/detectron2/utils/memory.py rename to detectron2/utils/memory.py diff --git 
a/gradio_demo/detectron2/utils/registry.py b/detectron2/utils/registry.py similarity index 100% rename from gradio_demo/detectron2/utils/registry.py rename to detectron2/utils/registry.py diff --git a/gradio_demo/detectron2/utils/serialize.py b/detectron2/utils/serialize.py similarity index 100% rename from gradio_demo/detectron2/utils/serialize.py rename to detectron2/utils/serialize.py diff --git a/gradio_demo/detectron2/utils/testing.py b/detectron2/utils/testing.py similarity index 100% rename from gradio_demo/detectron2/utils/testing.py rename to detectron2/utils/testing.py diff --git a/gradio_demo/detectron2/utils/tracing.py b/detectron2/utils/tracing.py similarity index 100% rename from gradio_demo/detectron2/utils/tracing.py rename to detectron2/utils/tracing.py diff --git a/gradio_demo/detectron2/utils/video_visualizer.py b/detectron2/utils/video_visualizer.py similarity index 100% rename from gradio_demo/detectron2/utils/video_visualizer.py rename to detectron2/utils/video_visualizer.py diff --git a/gradio_demo/detectron2/utils/visualizer.py b/detectron2/utils/visualizer.py similarity index 100% rename from gradio_demo/detectron2/utils/visualizer.py rename to detectron2/utils/visualizer.py diff --git a/environment.yaml b/environment.yaml deleted file mode 100644 index 094229efacee8fb7b354eac7e4e340fe30b5fecc..0000000000000000000000000000000000000000 --- a/environment.yaml +++ /dev/null @@ -1,32 +0,0 @@ -name: idm -channels: - - pytorch - - nvidia - - defaults -dependencies: - - python=3.10.0=h12debd9_5 - - pytorch=2.0.1=py3.10_cuda11.8_cudnn8.7.0_0 - - pytorch-cuda=11.8=h7e8668a_5 - - torchaudio=2.0.2=py310_cu118 - - torchtriton=2.0.0=py310 - - torchvision=0.15.2=py310_cu118 - - pip=23.3.1=py310h06a4308_0 - - - pip: - - accelerate==0.25.0 - - torchmetrics==1.2.1 - - tqdm==4.66.1 - - transformers==4.36.2 - - diffusers==0.25.0 - - einops==0.7.0 - - bitsandbytes==0.39.0 - - scipy==1.11.1 - - opencv-python - - gradio==4.24.0 - - fvcore - - cloudpickle - - omegaconf - - pycocotools - - basicsr - - av - - onnxruntime==1.16.2 diff --git a/gradio_demo/example/cloth/04469_00.jpg b/example/cloth/04469_00.jpg similarity index 100% rename from gradio_demo/example/cloth/04469_00.jpg rename to example/cloth/04469_00.jpg diff --git a/gradio_demo/example/cloth/04743_00.jpg b/example/cloth/04743_00.jpg similarity index 100% rename from gradio_demo/example/cloth/04743_00.jpg rename to example/cloth/04743_00.jpg diff --git a/gradio_demo/example/cloth/09133_00.jpg b/example/cloth/09133_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09133_00.jpg rename to example/cloth/09133_00.jpg diff --git a/gradio_demo/example/cloth/09163_00.jpg b/example/cloth/09163_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09163_00.jpg rename to example/cloth/09163_00.jpg diff --git a/gradio_demo/example/cloth/09164_00.jpg b/example/cloth/09164_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09164_00.jpg rename to example/cloth/09164_00.jpg diff --git a/gradio_demo/example/cloth/09166_00.jpg b/example/cloth/09166_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09166_00.jpg rename to example/cloth/09166_00.jpg diff --git a/gradio_demo/example/cloth/09176_00.jpg b/example/cloth/09176_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09176_00.jpg rename to example/cloth/09176_00.jpg diff --git a/gradio_demo/example/cloth/09236_00.jpg b/example/cloth/09236_00.jpg similarity index 100% rename from 
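The `environment.yaml` removed above pinned the runtime stack for local use (PyTorch 2.0.1 with CUDA 11.8, accelerate 0.25.0, transformers 4.36.2, diffusers 0.25.0, gradio 4.24.0, onnxruntime 1.16.2, among others). With the conda spec gone, dependency management presumably moves to whatever the Space environment provides; the sketch below is only an illustrative way to confirm that the same pins are still satisfied, using nothing beyond the standard library.

```python
# Illustrative sanity check that the pins from the removed environment.yaml are still met.
# Package names and versions are copied from the deleted file; the check itself is a
# sketch added for reference, not part of the repository.
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "accelerate": "0.25.0",
    "transformers": "4.36.2",
    "diffusers": "0.25.0",
    "gradio": "4.24.0",
    "onnxruntime": "1.16.2",
}

for package, expected in PINS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: missing (expected {expected})")
        continue
    status = "ok" if installed == expected else f"differs from pinned {expected}"
    print(f"{package}: {installed} ({status})")
```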
gradio_demo/example/cloth/09236_00.jpg rename to example/cloth/09236_00.jpg diff --git a/gradio_demo/example/cloth/09256_00.jpg b/example/cloth/09256_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09256_00.jpg rename to example/cloth/09256_00.jpg diff --git a/gradio_demo/example/cloth/09263_00.jpg b/example/cloth/09263_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09263_00.jpg rename to example/cloth/09263_00.jpg diff --git a/gradio_demo/example/cloth/09266_00.jpg b/example/cloth/09266_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09266_00.jpg rename to example/cloth/09266_00.jpg diff --git a/gradio_demo/example/cloth/09290_00.jpg b/example/cloth/09290_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09290_00.jpg rename to example/cloth/09290_00.jpg diff --git a/gradio_demo/example/cloth/09305_00.jpg b/example/cloth/09305_00.jpg similarity index 100% rename from gradio_demo/example/cloth/09305_00.jpg rename to example/cloth/09305_00.jpg diff --git a/gradio_demo/example/cloth/10165_00.jpg b/example/cloth/10165_00.jpg similarity index 100% rename from gradio_demo/example/cloth/10165_00.jpg rename to example/cloth/10165_00.jpg diff --git a/gradio_demo/example/cloth/14627_00.jpg b/example/cloth/14627_00.jpg similarity index 100% rename from gradio_demo/example/cloth/14627_00.jpg rename to example/cloth/14627_00.jpg diff --git a/gradio_demo/example/cloth/14673_00.jpg b/example/cloth/14673_00.jpg similarity index 100% rename from gradio_demo/example/cloth/14673_00.jpg rename to example/cloth/14673_00.jpg diff --git a/gradio_demo/example/human/00034_00.jpg b/example/human/00034_00.jpg similarity index 100% rename from gradio_demo/example/human/00034_00.jpg rename to example/human/00034_00.jpg diff --git a/gradio_demo/example/human/00035_00.jpg b/example/human/00035_00.jpg similarity index 100% rename from gradio_demo/example/human/00035_00.jpg rename to example/human/00035_00.jpg diff --git a/gradio_demo/example/human/00055_00.jpg b/example/human/00055_00.jpg similarity index 100% rename from gradio_demo/example/human/00055_00.jpg rename to example/human/00055_00.jpg diff --git a/gradio_demo/example/human/00121_00.jpg b/example/human/00121_00.jpg similarity index 100% rename from gradio_demo/example/human/00121_00.jpg rename to example/human/00121_00.jpg diff --git a/gradio_demo/example/human/01992_00.jpg b/example/human/01992_00.jpg similarity index 100% rename from gradio_demo/example/human/01992_00.jpg rename to example/human/01992_00.jpg diff --git a/gradio_demo/example/human/Jensen.jpeg b/example/human/Jensen.jpeg similarity index 100% rename from gradio_demo/example/human/Jensen.jpeg rename to example/human/Jensen.jpeg diff --git a/gradio_demo/example/human/sam1 (1).jpg b/example/human/sam1 (1).jpg similarity index 100% rename from gradio_demo/example/human/sam1 (1).jpg rename to example/human/sam1 (1).jpg diff --git a/gradio_demo/example/human/taylor-.jpg b/example/human/taylor-.jpg similarity index 100% rename from gradio_demo/example/human/taylor-.jpg rename to example/human/taylor-.jpg diff --git a/gradio_demo/example/human/will1 (1).jpg b/example/human/will1 (1).jpg similarity index 100% rename from gradio_demo/example/human/will1 (1).jpg rename to example/human/will1 (1).jpg diff --git a/hf b/hf deleted file mode 100644 index c3b9c41e2dc0597f8cea3f1a01519861900e95a9..0000000000000000000000000000000000000000 --- a/hf +++ /dev/null @@ -1,7 +0,0 @@ ------BEGIN OPENSSH PRIVATE KEY----- 
-b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW -QyNTUxOQAAACAnrBvCQqFPKBjg50yOVYY4CRZFOXYtf87WU1C1+176WwAAAJjRr9Rv0a/U -bwAAAAtzc2gtZWQyNTUxOQAAACAnrBvCQqFPKBjg50yOVYY4CRZFOXYtf87WU1C1+176Ww -AAAEDgMBd8z9N6Q1/05T0A5KiyVisusVzLpxDYpBMZEq6WfiesG8JCoU8oGODnTI5VhjgJ -FkU5di1/ztZTULX7XvpbAAAAFWlxdGVjaC4yMDIyQGdtYWlsLmNvbQ== ------END OPENSSH PRIVATE KEY----- diff --git a/hf.pub b/hf.pub deleted file mode 100644 index e792b00fefbdc1e37c9c3587f2e4cd107611e373..0000000000000000000000000000000000000000 --- a/hf.pub +++ /dev/null @@ -1 +0,0 @@ -ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICesG8JCoU8oGODnTI5VhjgJFkU5di1/ztZTULX7Xvpb iqtech.2022@gmail.com diff --git a/inference.py b/inference.py deleted file mode 100644 index 23e44ba8bd0f53dd0e3ed21bf6285e385a3c521e..0000000000000000000000000000000000000000 --- a/inference.py +++ /dev/null @@ -1,425 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal -from ip_adapter.ip_adapter import Resampler - -import argparse -import logging -import os -import torch.utils.data as data -import torchvision -import json -import accelerate -import numpy as np -import torch -from PIL import Image -import torch.nn.functional as F -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed -from packaging import version -from torchvision import transforms -import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLControlNetInpaintPipeline -from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer - -from diffusers.utils.import_utils import is_xformers_available - -from src.unet_hacked_tryon import UNet2DConditionModel -from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref -from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline - - - -logger = get_logger(__name__, log_level="INFO") - - - -def parse_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--pretrained_model_name_or_path",type=str,default= "yisol/IDM-VTON",required=False,) - parser.add_argument("--width",type=int,default=768,) - parser.add_argument("--height",type=int,default=1024,) - parser.add_argument("--num_inference_steps",type=int,default=30,) - parser.add_argument("--output_dir",type=str,default="result",) - parser.add_argument("--unpaired",action="store_true",) - parser.add_argument("--data_dir",type=str,default="/home/omnious/workspace/yisol/Dataset/zalando") - parser.add_argument("--seed", type=int, default=42,) - parser.add_argument("--test_batch_size", type=int, default=2,) - parser.add_argument("--guidance_scale",type=float,default=2.0,) - 
parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],) - parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") - args = parser.parse_args() - - - return args - -def pil_to_tensor(images): - images = np.array(images).astype(np.float32) / 255.0 - images = torch.from_numpy(images.transpose(2, 0, 1)) - return images - - -class VitonHDTestDataset(data.Dataset): - def __init__( - self, - dataroot_path: str, - phase: Literal["train", "test"], - order: Literal["paired", "unpaired"] = "paired", - size: Tuple[int, int] = (512, 384), - ): - super(VitonHDTestDataset, self).__init__() - self.dataroot = dataroot_path - self.phase = phase - self.height = size[0] - self.width = size[1] - self.size = size - self.transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - self.toTensor = transforms.ToTensor() - - with open( - os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r" - ) as file1: - data1 = json.load(file1) - - annotation_list = [ - "sleeveLength", - "neckLine", - "item", - ] - - self.annotation_pair = {} - for k, v in data1.items(): - for elem in v: - annotation_str = "" - for template in annotation_list: - for tag in elem["tag_info"]: - if ( - tag["tag_name"] == template - and tag["tag_category"] is not None - ): - annotation_str += tag["tag_category"] - annotation_str += " " - self.annotation_pair[elem["file_name"]] = annotation_str - - self.order = order - self.toTensor = transforms.ToTensor() - - im_names = [] - c_names = [] - dataroot_names = [] - - - if phase == "train": - filename = os.path.join(dataroot_path, f"{phase}_pairs.txt") - else: - filename = os.path.join(dataroot_path, f"{phase}_pairs.txt") - - with open(filename, "r") as f: - for line in f.readlines(): - if phase == "train": - im_name, _ = line.strip().split() - c_name = im_name - else: - if order == "paired": - im_name, _ = line.strip().split() - c_name = im_name - else: - im_name, c_name = line.strip().split() - - im_names.append(im_name) - c_names.append(c_name) - dataroot_names.append(dataroot_path) - - self.im_names = im_names - self.c_names = c_names - self.dataroot_names = dataroot_names - self.clip_processor = CLIPImageProcessor() - def __getitem__(self, index): - c_name = self.c_names[index] - im_name = self.im_names[index] - if c_name in self.annotation_pair: - cloth_annotation = self.annotation_pair[c_name] - else: - cloth_annotation = "shirts" - cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name)) - - im_pil_big = Image.open( - os.path.join(self.dataroot, self.phase, "image", im_name) - ).resize((self.width,self.height)) - image = self.transform(im_pil_big) - - mask = Image.open(os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace('.jpg','_mask.png'))).resize((self.width,self.height)) - mask = self.toTensor(mask) - mask = mask[:1] - mask = 1-mask - im_mask = image * mask - - pose_img = Image.open( - os.path.join(self.dataroot, self.phase, "image-densepose", im_name) - ) - pose_img = self.transform(pose_img) # [-1,1] - - result = {} - result["c_name"] = c_name - result["im_name"] = im_name - result["image"] = image - result["cloth_pure"] = self.transform(cloth) - result["cloth"] = self.clip_processor(images=cloth, return_tensors="pt").pixel_values - result["inpaint_mask"] =1-mask - result["im_mask"] = im_mask - result["caption_cloth"] = "a photo of " + cloth_annotation - result["caption"] 
= "model is wearing a " + cloth_annotation - result["pose_img"] = pose_img - - return result - - def __len__(self): - # model images + cloth image - return len(self.im_names) - - - - -def main(): - args = parse_args() - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir) - accelerator = Accelerator( - mixed_precision=args.mixed_precision, - project_config=accelerator_project_config, - ) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - weight_dtype = torch.float16 - # if accelerator.mixed_precision == "fp16": - # weight_dtype = torch.float16 - # args.mixed_precision = accelerator.mixed_precision - # elif accelerator.mixed_precision == "bf16": - # weight_dtype = torch.bfloat16 - # args.mixed_precision = accelerator.mixed_precision - - # Load scheduler, tokenizer and models. - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", - torch_dtype=torch.float16, - ) - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="unet", - torch_dtype=torch.float16, - ) - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="image_encoder", - torch_dtype=torch.float16, - ) - unet_encoder = UNet2DConditionModel_ref.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="unet_encoder", - torch_dtype=torch.float16, - ) - text_encoder_one = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder", - torch_dtype=torch.float16, - ) - text_encoder_two = CLIPTextModelWithProjection.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="text_encoder_2", - torch_dtype=torch.float16, - ) - tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - revision=None, - use_fast=False, - ) - tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer_2", - revision=None, - use_fast=False, - ) - - - # Freeze vae and text_encoder and set unet to trainable - unet.requires_grad_(False) - vae.requires_grad_(False) - image_encoder.requires_grad_(False) - unet_encoder.requires_grad_(False) - text_encoder_one.requires_grad_(False) - text_encoder_two.requires_grad_(False) - unet_encoder.to(accelerator.device, weight_dtype) - unet.eval() - unet_encoder.eval() - - - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - xformers_version = version.parse(xformers.__version__) - if xformers_version == version.parse("0.0.16"): - logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." - ) - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. 
Make sure it is installed correctly") - - test_dataset = VitonHDTestDataset( - dataroot_path=args.data_dir, - phase="test", - order="unpaired" if args.unpaired else "paired", - size=(args.height, args.width), - ) - test_dataloader = torch.utils.data.DataLoader( - test_dataset, - shuffle=False, - batch_size=args.test_batch_size, - num_workers=4, - ) - - pipe = TryonPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=unet, - vae=vae, - feature_extractor= CLIPImageProcessor(), - text_encoder = text_encoder_one, - text_encoder_2 = text_encoder_two, - tokenizer = tokenizer_one, - tokenizer_2 = tokenizer_two, - scheduler = noise_scheduler, - image_encoder=image_encoder, - unet_encoder = unet_encoder, - torch_dtype=torch.float16, - ).to(accelerator.device) - - # pipe.enable_sequential_cpu_offload() - # pipe.enable_model_cpu_offload() - # pipe.enable_vae_slicing() - - - - with torch.no_grad(): - # Extract the images - with torch.cuda.amp.autocast(): - with torch.no_grad(): - for sample in test_dataloader: - img_emb_list = [] - for i in range(sample['cloth'].shape[0]): - img_emb_list.append(sample['cloth'][i]) - - prompt = sample["caption"] - - num_prompts = sample['cloth'].shape[0] - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - image_embeds = torch.cat(img_emb_list,dim=0) - - with torch.inference_mode(): - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = pipe.encode_prompt( - prompt, - num_images_per_prompt=1, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - - - prompt = sample["caption_cloth"] - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - - with torch.inference_mode(): - ( - prompt_embeds_c, - _, - _, - _, - ) = pipe.encode_prompt( - prompt, - num_images_per_prompt=1, - do_classifier_free_guidance=False, - negative_prompt=negative_prompt, - ) - - - - generator = torch.Generator(pipe.device).manual_seed(args.seed) if args.seed is not None else None - images = pipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - num_inference_steps=args.num_inference_steps, - generator=generator, - strength = 1.0, - pose_img = sample['pose_img'], - text_embeds_cloth=prompt_embeds_c, - cloth = sample["cloth_pure"].to(accelerator.device), - mask_image=sample['inpaint_mask'], - image=(sample['image']+1.0)/2.0, - height=args.height, - width=args.width, - guidance_scale=args.guidance_scale, - ip_adapter_image = image_embeds, - )[0] - - - for i in range(len(images)): - x_sample = pil_to_tensor(images[i]) - torchvision.utils.save_image(x_sample,os.path.join(args.output_dir,sample['im_name'][i])) - - - - -if __name__ == "__main__": - main() diff --git a/inference.sh b/inference.sh deleted file mode 100644 index 89d6747bcee011997a59f070dfba03b51958cad3..0000000000000000000000000000000000000000 --- a/inference.sh +++ /dev/null @@ -1,34 +0,0 @@ -#VITON-HD -##paired setting -accelerate launch inference.py --pretrained_model_name_or_path "yisol/IDM-VTON" \ - --width 768 --height 
1024 --num_inference_steps 30 \ - --output_dir "result" --data_dir "/home/omnious/workspace/yisol/Dataset/zalando" \ - --seed 42 --test_batch_size 2 --guidance_scale 2.0 - - -##unpaired setting -accelerate launch inference.py --pretrained_model_name_or_path "yisol/IDM-VTON" \ - --width 768 --height 1024 --num_inference_steps 30 \ - --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/Dataset/zalando" \ - --seed 42 --test_batch_size 2 --guidance_scale 2.0 - - - -#DressCode -##upper_body -accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \ - --width 768 --height 1024 --num_inference_steps 30 \ - --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \ - --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "upper_body" - -##lower_body -accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \ - --width 768 --height 1024 --num_inference_steps 30 \ - --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \ - --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "lower_body" - -##dresses -accelerate launch inference_dc.py --pretrained_model_name_or_path "yisol/IDM-VTON" \ - --width 768 --height 1024 --num_inference_steps 30 \ - --output_dir "result" --unpaired --data_dir "/home/omnious/workspace/yisol/DressCode" \ - --seed 42 --test_batch_size 2 --guidance_scale 2.0 --category "dresses" \ No newline at end of file diff --git a/inference_dc.py b/inference_dc.py deleted file mode 100644 index fd8e78cd6a206b554189068f95cad3d2dd0eab34..0000000000000000000000000000000000000000 --- a/inference_dc.py +++ /dev/null @@ -1,578 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
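The `inference.py` and `inference.sh` deletions above (and the `inference_dc.py` deletion that continues below) remove the offline batch try-on entry points: `accelerate launch` runs over the VITON-HD test split in paired or unpaired order, and over the three DressCode categories. The paired/unpaired distinction comes from how the deleted `VitonHDTestDataset` reads the pairs file: in the paired order the garment is the one the model is already wearing (`c_name = im_name`), while the unpaired order takes the garment from the second column. A minimal sketch of just that selection step, with a placeholder pairs-file path:

```python
# Sketch of the pair-selection logic from the deleted VitonHDTestDataset: "paired"
# keeps the garment originally worn by the model, "unpaired" uses the second column
# of the pairs file. The file path in the usage comment is a placeholder.
from typing import List, Tuple

def read_test_pairs(pairs_file: str, order: str = "paired") -> List[Tuple[str, str]]:
    pairs = []
    with open(pairs_file, "r") as f:
        for line in f:
            im_name, c_name = line.strip().split()
            if order == "paired":
                c_name = im_name  # try the model's own garment back on
            pairs.append((im_name, c_name))
    return pairs

# e.g. read_test_pairs("zalando/test_pairs.txt", order="unpaired")
```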
-# See the License for the specific language governing permissions and -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal -from ip_adapter.ip_adapter import Resampler - -import argparse -import logging -import os -import torch.utils.data as data -import torchvision -import json -import accelerate -import numpy as np -import torch -from PIL import Image, ImageDraw -import torch.nn.functional as F -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed -from packaging import version -from torchvision import transforms -import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLControlNetInpaintPipeline -from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer -import cv2 -from diffusers.utils.import_utils import is_xformers_available -from numpy.linalg import lstsq - -from src.unet_hacked_tryon import UNet2DConditionModel -from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref -from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline - - - -logger = get_logger(__name__, log_level="INFO") - -label_map={ - "background": 0, - "hat": 1, - "hair": 2, - "sunglasses": 3, - "upper_clothes": 4, - "skirt": 5, - "pants": 6, - "dress": 7, - "belt": 8, - "left_shoe": 9, - "right_shoe": 10, - "head": 11, - "left_leg": 12, - "right_leg": 13, - "left_arm": 14, - "right_arm": 15, - "bag": 16, - "scarf": 17, -} - -def parse_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--pretrained_model_name_or_path",type=str,default= "yisol/IDM-VTON",required=False,) - parser.add_argument("--width",type=int,default=768,) - parser.add_argument("--height",type=int,default=1024,) - parser.add_argument("--num_inference_steps",type=int,default=30,) - parser.add_argument("--output_dir",type=str,default="result",) - parser.add_argument("--category",type=str,default="upper_body",choices=["upper_body", "lower_body", "dresses"]) - parser.add_argument("--unpaired",action="store_true",) - parser.add_argument("--data_dir",type=str,default="/home/omnious/workspace/yisol/Dataset/zalando") - parser.add_argument("--seed", type=int, default=42,) - parser.add_argument("--test_batch_size", type=int, default=2,) - parser.add_argument("--guidance_scale",type=float,default=2.0,) - parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],) - parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") - args = parser.parse_args() - - - return args - -def pil_to_tensor(images): - images = np.array(images).astype(np.float32) / 255.0 - images = torch.from_numpy(images.transpose(2, 0, 1)) - return images - - -class DresscodeTestDataset(data.Dataset): - def __init__( - self, - dataroot_path: str, - phase: Literal["train", "test"], - order: Literal["paired", "unpaired"] = "paired", - category = "upper_body", - size: Tuple[int, int] = (512, 384), - ): - super(DresscodeTestDataset, self).__init__() - self.dataroot = os.path.join(dataroot_path,category) - self.phase = phase - self.height = size[0] - self.width = size[1] - self.size = size - self.transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), 
- ] - ) - self.toTensor = transforms.ToTensor() - self.order = order - self.radius = 5 - self.category = category - im_names = [] - c_names = [] - - - if phase == "train": - filename = os.path.join(dataroot_path,category, f"{phase}_pairs.txt") - else: - filename = os.path.join(dataroot_path,category, f"{phase}_pairs_{order}.txt") - - with open(filename, "r") as f: - for line in f.readlines(): - im_name, c_name = line.strip().split() - - im_names.append(im_name) - c_names.append(c_name) - - - file_path = os.path.join(dataroot_path,category,"dc_caption.txt") - - self.annotation_pair = {} - with open(file_path, "r") as file: - for line in file: - parts = line.strip().split(" ") - self.annotation_pair[parts[0]] = ' '.join(parts[1:]) - - - self.im_names = im_names - self.c_names = c_names - self.clip_processor = CLIPImageProcessor() - def __getitem__(self, index): - c_name = self.c_names[index] - im_name = self.im_names[index] - if c_name in self.annotation_pair: - cloth_annotation = self.annotation_pair[c_name] - else: - cloth_annotation = self.category - cloth = Image.open(os.path.join(self.dataroot, "images", c_name)) - - im_pil_big = Image.open( - os.path.join(self.dataroot, "images", im_name) - ).resize((self.width,self.height)) - image = self.transform(im_pil_big) - - - - - skeleton = Image.open(os.path.join(self.dataroot, 'skeletons', im_name.replace("_0", "_5"))) - skeleton = skeleton.resize((self.width, self.height)) - skeleton = self.transform(skeleton) - - # Label Map - parse_name = im_name.replace('_0.jpg', '_4.png') - im_parse = Image.open(os.path.join(self.dataroot, 'label_maps', parse_name)) - im_parse = im_parse.resize((self.width, self.height), Image.NEAREST) - parse_array = np.array(im_parse) - - # Load pose points - pose_name = im_name.replace('_0.jpg', '_2.json') - with open(os.path.join(self.dataroot, 'keypoints', pose_name), 'r') as f: - pose_label = json.load(f) - pose_data = pose_label['keypoints'] - pose_data = np.array(pose_data) - pose_data = pose_data.reshape((-1, 4)) - - point_num = pose_data.shape[0] - pose_map = torch.zeros(point_num, self.height, self.width) - r = self.radius * (self.height / 512.0) - for i in range(point_num): - one_map = Image.new('L', (self.width, self.height)) - draw = ImageDraw.Draw(one_map) - point_x = np.multiply(pose_data[i, 0], self.width / 384.0) - point_y = np.multiply(pose_data[i, 1], self.height / 512.0) - if point_x > 1 and point_y > 1: - draw.rectangle((point_x - r, point_y - r, point_x + r, point_y + r), 'white', 'white') - one_map = self.toTensor(one_map) - pose_map[i] = one_map[0] - - agnostic_mask = self.get_agnostic(parse_array, pose_data, self.category, (self.width,self.height)) - # agnostic_mask = transforms.functional.resize(agnostic_mask, (self.height, self.width), - # interpolation=transforms.InterpolationMode.NEAREST) - - mask = 1 - agnostic_mask - im_mask = image * agnostic_mask - - pose_img = Image.open( - os.path.join(self.dataroot, "image-densepose", im_name) - ) - pose_img = self.transform(pose_img) # [-1,1] - - result = {} - result["c_name"] = c_name - result["im_name"] = im_name - result["image"] = image - result["cloth_pure"] = self.transform(cloth) - result["cloth"] = self.clip_processor(images=cloth, return_tensors="pt").pixel_values - result["inpaint_mask"] =mask - result["im_mask"] = im_mask - result["caption_cloth"] = "a photo of " + cloth_annotation - result["caption"] = "model is wearing a " + cloth_annotation - result["pose_img"] = pose_img - - return result - - def __len__(self): - # model images + cloth 
image - return len(self.im_names) - - - - - def get_agnostic(self,parse_array, pose_data, category, size): - parse_shape = (parse_array > 0).astype(np.float32) - - parse_head = (parse_array == 1).astype(np.float32) + \ - (parse_array == 2).astype(np.float32) + \ - (parse_array == 3).astype(np.float32) + \ - (parse_array == 11).astype(np.float32) - - parser_mask_fixed = (parse_array == label_map["hair"]).astype(np.float32) + \ - (parse_array == label_map["left_shoe"]).astype(np.float32) + \ - (parse_array == label_map["right_shoe"]).astype(np.float32) + \ - (parse_array == label_map["hat"]).astype(np.float32) + \ - (parse_array == label_map["sunglasses"]).astype(np.float32) + \ - (parse_array == label_map["scarf"]).astype(np.float32) + \ - (parse_array == label_map["bag"]).astype(np.float32) - - parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32) - - arms = (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32) - - if category == 'dresses': - label_cat = 7 - parse_mask = (parse_array == 7).astype(np.float32) + \ - (parse_array == 12).astype(np.float32) + \ - (parse_array == 13).astype(np.float32) - parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed)) - - elif category == 'upper_body': - label_cat = 4 - parse_mask = (parse_array == 4).astype(np.float32) - - parser_mask_fixed += (parse_array == label_map["skirt"]).astype(np.float32) + \ - (parse_array == label_map["pants"]).astype(np.float32) - - parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed)) - elif category == 'lower_body': - label_cat = 6 - parse_mask = (parse_array == 6).astype(np.float32) + \ - (parse_array == 12).astype(np.float32) + \ - (parse_array == 13).astype(np.float32) - - parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \ - (parse_array == 14).astype(np.float32) + \ - (parse_array == 15).astype(np.float32) - parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed)) - - parse_head = torch.from_numpy(parse_head) # [0,1] - parse_mask = torch.from_numpy(parse_mask) # [0,1] - parser_mask_fixed = torch.from_numpy(parser_mask_fixed) - parser_mask_changeable = torch.from_numpy(parser_mask_changeable) - - # dilation - parse_without_cloth = np.logical_and(parse_shape, np.logical_not(parse_mask)) - parse_mask = parse_mask.cpu().numpy() - - width = size[0] - height = size[1] - - im_arms = Image.new('L', (width, height)) - arms_draw = ImageDraw.Draw(im_arms) - if category == 'dresses' or category == 'upper_body': - shoulder_right = tuple(np.multiply(pose_data[2, :2], height / 512.0)) - shoulder_left = tuple(np.multiply(pose_data[5, :2], height / 512.0)) - elbow_right = tuple(np.multiply(pose_data[3, :2], height / 512.0)) - elbow_left = tuple(np.multiply(pose_data[6, :2], height / 512.0)) - wrist_right = tuple(np.multiply(pose_data[4, :2], height / 512.0)) - wrist_left = tuple(np.multiply(pose_data[7, :2], height / 512.0)) - if wrist_right[0] <= 1. and wrist_right[1] <= 1.: - if elbow_right[0] <= 1. and elbow_right[1] <= 1.: - arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right], 'white', 30, 'curve') - else: - arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right], 'white', 30, - 'curve') - elif wrist_left[0] <= 1. and wrist_left[1] <= 1.: - if elbow_left[0] <= 1. 
and elbow_left[1] <= 1.: - arms_draw.line([shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', 30, 'curve') - else: - arms_draw.line([elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', 30, - 'curve') - else: - arms_draw.line([wrist_left, elbow_left, shoulder_left, shoulder_right, elbow_right, wrist_right], 'white', - 30, 'curve') - - if height > 512: - im_arms = cv2.dilate(np.float32(im_arms), np.ones((10, 10), np.uint16), iterations=5) - elif height > 256: - im_arms = cv2.dilate(np.float32(im_arms), np.ones((5, 5), np.uint16), iterations=5) - hands = np.logical_and(np.logical_not(im_arms), arms) - parse_mask += im_arms - parser_mask_fixed += hands - - # delete neck - parse_head_2 = torch.clone(parse_head) - if category == 'dresses' or category == 'upper_body': - points = [] - points.append(np.multiply(pose_data[2, :2], height / 512.0)) - points.append(np.multiply(pose_data[5, :2], height / 512.0)) - x_coords, y_coords = zip(*points) - A = np.vstack([x_coords, np.ones(len(x_coords))]).T - m, c = lstsq(A, y_coords, rcond=None)[0] - for i in range(parse_array.shape[1]): - y = i * m + c - parse_head_2[int(y - 20 * (height / 512.0)):, i] = 0 - - parser_mask_fixed = np.logical_or(parser_mask_fixed, np.array(parse_head_2, dtype=np.uint16)) - parse_mask += np.logical_or(parse_mask, np.logical_and(np.array(parse_head, dtype=np.uint16), - np.logical_not(np.array(parse_head_2, dtype=np.uint16)))) - - if height > 512: - parse_mask = cv2.dilate(parse_mask, np.ones((20, 20), np.uint16), iterations=5) - elif height > 256: - parse_mask = cv2.dilate(parse_mask, np.ones((10, 10), np.uint16), iterations=5) - else: - parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5) - parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask)) - parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed) - agnostic_mask = parse_mask_total.unsqueeze(0) - return agnostic_mask - - - - -def main(): - args = parse_args() - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir) - accelerator = Accelerator( - mixed_precision=args.mixed_precision, - project_config=accelerator_project_config, - ) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - # If passed along, set the training seed now. - if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - weight_dtype = torch.float16 - # if accelerator.mixed_precision == "fp16": - # weight_dtype = torch.float16 - # args.mixed_precision = accelerator.mixed_precision - # elif accelerator.mixed_precision == "bf16": - # weight_dtype = torch.bfloat16 - # args.mixed_precision = accelerator.mixed_precision - - # Load scheduler, tokenizer and models. 
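The `get_agnostic()` routine deleted above derives the inpainting region per DressCode category from the parsing label map: `upper_body` masks label 4 (upper_clothes), `lower_body` masks 6, 12 and 13 (pants plus both legs), and `dresses` masks 7, 12 and 13, while hair, shoes, hat, sunglasses, scarf and bag stay fixed. The sketch below reproduces only the category-to-label selection with the same ids; the arm drawing, neck removal and pose-dependent dilation from the deleted code are left out.

```python
import numpy as np

# Category -> label ids to inpaint, as selected in the deleted get_agnostic()
# (ids follow the label_map defined at the top of the deleted inference_dc.py).
CATEGORY_LABELS = {
    "upper_body": (4,),          # upper_clothes
    "lower_body": (6, 12, 13),   # pants + left/right leg
    "dresses": (7, 12, 13),      # dress + left/right leg
}

def category_parse_mask(parse_array: np.ndarray, category: str) -> np.ndarray:
    """1.0 where the garment region for `category` should be regenerated."""
    mask = np.zeros(parse_array.shape, dtype=np.float32)
    for label in CATEGORY_LABELS[category]:
        mask[parse_array == label] = 1.0
    return mask

toy = np.array([[0, 4, 6],
                [7, 12, 13]])    # tiny label map for illustration
print(category_parse_mask(toy, "upper_body"))
print(category_parse_mask(toy, "dresses"))
```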
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        torch_dtype=torch.float16,
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        "yisol/IDM-VTON-DC",
-        subfolder="unet",
-        torch_dtype=torch.float16,
-    )
-    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="image_encoder",
-        torch_dtype=torch.float16,
-    )
-    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet_encoder",
-        torch_dtype=torch.float16,
-    )
-    text_encoder_one = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        torch_dtype=torch.float16,
-    )
-    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder_2",
-        torch_dtype=torch.float16,
-    )
-    tokenizer_one = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="tokenizer",
-        revision=None,
-        use_fast=False,
-    )
-    tokenizer_two = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="tokenizer_2",
-        revision=None,
-        use_fast=False,
-    )
-
-
-    # Freeze vae and text_encoder and set unet to trainable
-    unet.requires_grad_(False)
-    vae.requires_grad_(False)
-    image_encoder.requires_grad_(False)
-    UNet_Encoder.requires_grad_(False)
-    text_encoder_one.requires_grad_(False)
-    text_encoder_two.requires_grad_(False)
-    UNet_Encoder.to(accelerator.device, weight_dtype)
-    unet.eval()
-    UNet_Encoder.eval()
-
-
-
-    if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
-            import xformers
-
-            xformers_version = version.parse(xformers.__version__)
-            if xformers_version == version.parse("0.0.16"):
-                logger.warn(
-                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
-                )
-            unet.enable_xformers_memory_efficient_attention()
-        else:
-            raise ValueError("xformers is not available. Make sure it is installed correctly")
-
-    test_dataset = DresscodeTestDataset(
-        dataroot_path=args.data_dir,
-        phase="test",
-        order="unpaired" if args.unpaired else "paired",
-        category = args.category,
-        size=(args.height, args.width),
-    )
-    test_dataloader = torch.utils.data.DataLoader(
-        test_dataset,
-        shuffle=False,
-        batch_size=args.test_batch_size,
-        num_workers=4,
-    )
-
-    pipe = TryonPipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        unet=unet,
-        vae=vae,
-        feature_extractor= CLIPImageProcessor(),
-        text_encoder = text_encoder_one,
-        text_encoder_2 = text_encoder_two,
-        tokenizer = tokenizer_one,
-        tokenizer_2 = tokenizer_two,
-        scheduler = noise_scheduler,
-        image_encoder=image_encoder,
-        torch_dtype=torch.float16,
-    ).to(accelerator.device)
-    pipe.unet_encoder = UNet_Encoder
-
-    # pipe.enable_sequential_cpu_offload()
-    # pipe.enable_model_cpu_offload()
-    # pipe.enable_vae_slicing()
-
-
-
-    with torch.no_grad():
-        # Extract the images
-        with torch.cuda.amp.autocast():
-            with torch.no_grad():
-                for sample in test_dataloader:
-                    img_emb_list = []
-                    for i in range(sample['cloth'].shape[0]):
-                        img_emb_list.append(sample['cloth'][i])
-
-                    prompt = sample["caption"]
-
-                    num_prompts = sample['cloth'].shape[0]
-                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-                    if not isinstance(prompt, List):
-                        prompt = [prompt] * num_prompts
-                    if not isinstance(negative_prompt, List):
-                        negative_prompt = [negative_prompt] * num_prompts
-
-                    image_embeds = torch.cat(img_emb_list,dim=0)
-
-                    with torch.inference_mode():
-                        (
-                            prompt_embeds,
-                            negative_prompt_embeds,
-                            pooled_prompt_embeds,
-                            negative_pooled_prompt_embeds,
-                        ) = pipe.encode_prompt(
-                            prompt,
-                            num_images_per_prompt=1,
-                            do_classifier_free_guidance=True,
-                            negative_prompt=negative_prompt,
-                        )
-
-
-                    prompt = sample["caption_cloth"]
-                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-                    if not isinstance(prompt, List):
-                        prompt = [prompt] * num_prompts
-                    if not isinstance(negative_prompt, List):
-                        negative_prompt = [negative_prompt] * num_prompts
-
-
-                    with torch.inference_mode():
-                        (
-                            prompt_embeds_c,
-                            _,
-                            _,
-                            _,
-                        ) = pipe.encode_prompt(
-                            prompt,
-                            num_images_per_prompt=1,
-                            do_classifier_free_guidance=False,
-                            negative_prompt=negative_prompt,
-                        )
-
-
-
-                    generator = torch.Generator(pipe.device).manual_seed(args.seed) if args.seed is not None else None
-                    images = pipe(
-                        prompt_embeds=prompt_embeds,
-                        negative_prompt_embeds=negative_prompt_embeds,
-                        pooled_prompt_embeds=pooled_prompt_embeds,
-                        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                        num_inference_steps=args.num_inference_steps,
-                        generator=generator,
-                        strength = 1.0,
-                        pose_img = sample['pose_img'],
-                        text_embeds_cloth=prompt_embeds_c,
-                        cloth = sample["cloth_pure"].to(accelerator.device),
-                        mask_image=sample['inpaint_mask'],
-                        image=(sample['image']+1.0)/2.0,
-                        height=args.height,
-                        width=args.width,
-                        guidance_scale=args.guidance_scale,
-                        ip_adapter_image = image_embeds,
-                    )[0]
-
-
-                    for i in range(len(images)):
-                        x_sample = pil_to_tensor(images[i])
-                        torchvision.utils.save_image(x_sample,os.path.join(args.output_dir,sample['im_name'][i]))
-
-
-
-
-if __name__ == "__main__":
-    main()
diff --git a/ip_adapter/__init__.py b/ip_adapter/__init__.py
index b275952105f50616770a83609ee1eada68bffd90..3b1f1ff4e54e93ada7e85abc0f6687c5ecd3a338 100644
--- a/ip_adapter/__init__.py
+++ b/ip_adapter/__init__.py
@@ -1,4 +1,4 @@
-from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull,IPAdapterPlus_Lora,IPAdapterPlus_Lora_up
+from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
 
 __all__ = [
     "IPAdapter",
@@ -6,6 +6,4 @@ __all__ = [
     "IPAdapterPlusXL",
     "IPAdapterXL",
     "IPAdapterFull",
-    "IPAdapterPlus_Lora",
-    'IPAdapterPlus_Lora_up',
 ]
diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py
index b6c40c41570a0ce05a4a373645747f3ee29275e7..07fd6d8e2fc622b1e5a2bdd52119566b8b5219ef 100644
--- a/ip_adapter/attention_processor.py
+++ b/ip_adapter/attention_processor.py
@@ -2,11 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from diffusers.models.lora import LoRACompatibleLinear
-from diffusers.models.lora import LoRALinearLayer,LoRAConv2dLayer
-from einops import rearrange
-from diffusers.models.transformer_2d import Transformer2DModel
 
 
 class AttnProcessor(nn.Module):
     r"""
@@ -97,1673 +93,6 @@ class IPAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
         super().__init__()
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.scale = scale
-        self.num_tokens = num_tokens
-
-        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # get encoder_hidden_states, ip_hidden_states
-            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
-            encoder_hidden_states, ip_hidden_states = (
-                encoder_hidden_states[:, :end_pos, :],
-                encoder_hidden_states[:, end_pos:, :],
-            )
-            if attn.norm_cross:
-                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        query = attn.head_to_batch_dim(query)
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # for ip-adapter
-        ip_key = self.to_k_ip(ip_hidden_states)
-        ip_value = self.to_v_ip(ip_hidden_states)
-
-        ip_key = attn.head_to_batch_dim(ip_key)
-        ip_value = attn.head_to_batch_dim(ip_value)
-
-        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-
-        hidden_states = hidden_states + self.scale * ip_hidden_states
-
-        # linear proj
-        hidden_states = 
attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class AttnProcessor2_0(torch.nn.Module): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__( - self, - hidden_size=None, - cross_attention_dim=None, - ): - super().__init__() - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale= 1.0, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - # args = (scale, ) - args = () - query = attn.to_q(hidden_states, *args) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class AttnProcessor2_0_attn(torch.nn.Module): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
- """ - - def __init__( - self, - hidden_size=None, - cross_attention_dim=None, - ): - super().__init__() - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - is_cloth_pass=False, - cloth = None, - up_cnt=None, - mid_cnt=None, - down_cnt=None, - inside_up=None, - inside_down=None, - cloth_text=None, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class AttnProcessor2_0_Lora(torch.nn.Module): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
- """ - - def __init__( - self, - scale_lora =1.0, - hidden_size=None, - cross_attention_dim=None, - ): - super().__init__() - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - self.scale_lora = scale_lora - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - if hasattr(attn,'q_lora'): - query = attn.to_q(hidden_states) - q_lora = attn.q_lora(hidden_states) - query = query + self.scale_lora * q_lora - else: - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - - if hasattr(attn,'k_lora'): - key = attn.to_k(hidden_states) - k_lora = attn.k_lora(hidden_states) - key = key + self.scale_lora * k_lora - else: - key = attn.to_k(hidden_states) - - if hasattr(attn,'v_lora'): - value = attn.to_v(encoder_hidden_states) - v_lora = attn.v_lora(hidden_states) - value = value + self.scale_lora * v_lora - else: - value = attn.to_v(encoder_hidden_states) - - - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - - - if hasattr(attn,'out_lora'): - hidden_states = attn.to_out[0](hidden_states) - out_lora = attn.out_lora(hidden_states) - hidden_states = hidden_states+ self.scale_lora*out_lora - else: - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class IPAttnProcessor_clothpass_noip(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. 
- Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. - """ - - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - is_cloth_pass=False, - cloth = None, - up_cnt=None, - mid_cnt=None, - down_cnt=None, - inside=None, - ): - - if is_cloth_pass or up_cnt is None: - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - else: - residual = 
hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - if attention_mask is not None: - print('!!!!attention_mask is not NoNE!!!!') - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - # print(up_cnt*3 + inside) - cloth_feature = cloth[up_cnt*3 + inside-1] - cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous() - # print(cloth_feature.shape) - # print(self.hidden_size) - c_key = self.to_k_c(cloth_feature) - c_value = self.to_v_c(cloth_feature) - - - c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # print(ip_value.shape) - #$$ attn_mask? - hidden_states_cloth = F.scaled_dot_product_attention( - query, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - hidden_states_cloth = hidden_states_cloth.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_cloth = hidden_states_cloth.to(query.dtype) - - hidden_states = hidden_states + self.scale * hidden_states_cloth - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states - - -class IPAttnProcessor_clothpass(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. 
- scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. - """ - - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - is_cloth_pass=False, - cloth = None, - up_cnt=None, - mid_cnt=None, - down_cnt=None, - inside=None, - ): - - if is_cloth_pass : - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - elif up_cnt is None: - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - if input_ndim == 4: 
- batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - - - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - else: - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if 
encoder_hidden_states is None else encoder_hidden_states.shape - ) - if attention_mask is not None: - print('!!!!attention_mask is not NoNE!!!!') - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - # print(up_cnt*3 + inside) - cloth_feature = cloth[up_cnt*3 + inside-1] - cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous() - - - - - - - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - - - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - - c_key = self.to_k_c(cloth_feature) - c_value = self.to_v_c(cloth_feature) - - c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - - - ip_hidden_states = F.scaled_dot_product_attention( - ip_hidden_states, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states - - - - - - 
-class IPAttnProcessor_clothpass_extend(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. - """ - - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - is_cloth_pass=False, - cloth = None, - up_cnt=None, - mid_cnt=None, - down_cnt=None, - inside_up=None, - inside_down=None, - ): - if is_cloth_pass : - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states 
+ residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - # elif up_cnt is None or down_cnt is None: - # residual = hidden_states - - # if attn.spatial_norm is not None: - # hidden_states = attn.spatial_norm(hidden_states, temb) - - # input_ndim = hidden_states.ndim - # if input_ndim == 4: - # batch_size, channel, height, width = hidden_states.shape - # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - # batch_size, sequence_length, _ = ( - # hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - # ) - - # if attention_mask is not None: - # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # # scaled_dot_product_attention expects attention_mask shape to be - # # (batch, heads, source_length, target_length) - # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - # if attn.group_norm is not None: - # hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - # query = attn.to_q(hidden_states) - # if encoder_hidden_states is None: - # encoder_hidden_states = hidden_states - # else: - # # get encoder_hidden_states, ip_hidden_states - # end_pos = encoder_hidden_states.shape[1] - self.num_tokens - # encoder_hidden_states, ip_hidden_states = ( - # encoder_hidden_states[:, :end_pos, :], - # encoder_hidden_states[:, end_pos:, :], - # ) - # if attn.norm_cross: - # encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - # key = attn.to_k(encoder_hidden_states) - # value = attn.to_v(encoder_hidden_states) - - # inner_dim = key.shape[-1] - # head_dim = inner_dim // attn.heads - - # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # # the output of sdp = (batch, num_heads, seq_len, head_dim) - # # TODO: add support for attn.scale when we move to Torch 2.1 - # hidden_states = F.scaled_dot_product_attention( - # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - # ) - - # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - # hidden_states = hidden_states.to(query.dtype) - - # # for ip-adapter - # ip_key = self.to_k_ip(ip_hidden_states) - # ip_value = self.to_v_ip(ip_hidden_states) - - # ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - - - - # # the output of sdp = (batch, num_heads, seq_len, head_dim) - # # TODO: add support for attn.scale when we move to Torch 2.1 - # ip_hidden_states = F.scaled_dot_product_attention( - # query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - # ) - - # ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - # ip_hidden_states = ip_hidden_states.to(query.dtype) - - - # hidden_states = hidden_states + self.scale * ip_hidden_states - - # # linear proj - # hidden_states = attn.to_out[0](hidden_states) - # # dropout - # hidden_states = attn.to_out[1](hidden_states) - - # if input_ndim == 4: - # hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - # if attn.residual_connection: - # hidden_states = hidden_states + residual - - # hidden_states = hidden_states / 
attn.rescale_output_factor - - # return hidden_states - elif down_cnt is not None or up_cnt is not None or mid_cnt is not None: - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - print('!!!!attention_mask is not NoNE!!!!') - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # # for ip-adapter - # print(self.hidden_size) - # # print(up_cnt*3 + inside) - # print(inside_down) - cloth_feature = cloth[inside_down] - # print(cloth_feature.shape) - # if down_cnt is not None: - # # print("up_index") - # cloth_feature = cloth[down_cnt*3 + inside_down+1] - # # print(up_cnt*3 + inside_up) - # elif mid_cnt is not None: - # cloth_feature = cloth[9] - # else: - # cloth_feature = cloth[11+up_cnt*3 + inside_up] - # print("down_index") - # print(down_cnt*3 + inside_down) - - cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous() - # print(cloth_feature.shape) - # print(self.hidden_size) - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - - - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, 
dropout_p=0.0, is_causal=False - ) - - - c_key = self.to_k_c(cloth_feature) - c_value = self.to_v_c(cloth_feature) - - c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - - - ip_hidden_states = F.scaled_dot_product_attention( - ip_hidden_states, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states - else: - assert(False) - - - -class IPAttnProcessorMulti2_0_2(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. - """ - - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - is_cloth_pass=False, - cloth = None, - up_cnt=None, - mid_cnt=None, - down_cnt=None, - inside=None, - cloth_text=None, - ): - - if is_cloth_pass or up_cnt is None: - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - 
self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - - - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - else: - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - # print(up_cnt) - # print("hidden_states.shape") - # print(hidden_states.shape) - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, 
:], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - # print(up_cnt*3 + inside) - cloth_feature = cloth[up_cnt*3 + inside-1] - cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous() - # print(cloth_feature.shape) - # print(self.hidden_size) - - # print("cloth_feature.shape") - # print(cloth_feature.shape) - query_cloth = self.q_additional(cloth_feature) - query_cloth = query_cloth.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_key = self.to_k_ip(cloth_text) - ip_value = self.to_v_ip(cloth_text) - - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # print(ip_value.shape) - #$$ attn_mask? - hidden_states_cloth = F.scaled_dot_product_attention( - query_cloth, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - hidden_states_cloth = hidden_states_cloth.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states_cloth = hidden_states_cloth.to(query.dtype) - - ip_key = self.k_additional(hidden_states_cloth) - ip_value = self.v_additional(hidden_states_cloth) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states - - - - - - - - - -class IPAttnProcessor2_0_paint(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. 
- """ - - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - if cross_attention_dim==None: - print("cross_attention_dim is none") - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - - # ####### - - # # for ip-adapter - ip_key = self.to_k_ip(encoder_hidden_states) - ip_value = self.to_v_ip(encoder_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - - - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - - hidden_states = hidden_states + self.scale * ip_hidden_states - - - # ####### - - - - # linear proj - hidden_states = 
attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - - - - -class IPAttnProcessor2_0_variant(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. - """ - - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim @@ -1795,12 +124,7 @@ class IPAttnProcessor2_0_variant(torch.nn.Module): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) @@ -1822,42 +146,27 @@ class IPAttnProcessor2_0_variant(torch.nn.Module): key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) - # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - # hidden_states = hidden_states.to(query.dtype) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) # for ip-adapter ip_key = self.to_k_ip(ip_hidden_states) ip_value = self.to_v_ip(ip_hidden_states) - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - hidden_states, ip_key, 
ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - with torch.no_grad(): - self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - hidden_states = ip_hidden_states + hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -1875,177 +184,20 @@ class IPAttnProcessor2_0_variant(torch.nn.Module): return hidden_states - -class IPAttnProcessor2_0(torch.nn.Module): +class AttnProcessor2_0(torch.nn.Module): r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( + def __init__( self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale=1.0 + hidden_size=None, + cross_attention_dim=None, ): - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - # args = (scale, ) - args = () - - query = attn.to_q(hidden_states, *args) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - 
encoder_hidden_states[:, end_pos:, :], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - with torch.no_grad(): - self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - - - - - - - -class IPAttnProcessor_referencenet_2_0(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. 
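All of the IP-Adapter processors touched above follow the same pattern: the cross-attention context is split into the text tokens plus the trailing `num_tokens` image-prompt tokens, a second pair of projections (`to_k_ip`/`to_v_ip`) lets the same queries attend over the image tokens, and the two results are blended with the adapter `scale`. A minimal, self-contained sketch of that split-and-merge, using plain tensors and illustrative dimensions rather than a diffusers `Attention` module:

```python
# Minimal sketch of the IP-Adapter cross-attention split/merge, assuming plain
# tensors; dimensions and projection handles are illustrative, not the
# repository's actual configuration.
import torch
import torch.nn as nn
import torch.nn.functional as F

def ip_adapter_attention(hidden_states, encoder_hidden_states, num_tokens,
                         to_q, to_k, to_v, to_k_ip, to_v_ip, heads, scale=1.0):
    b = hidden_states.shape[0]
    # Split text tokens from the trailing image-prompt tokens.
    end_pos = encoder_hidden_states.shape[1] - num_tokens
    text_tokens = encoder_hidden_states[:, :end_pos, :]
    ip_tokens = encoder_hidden_states[:, end_pos:, :]

    def split_heads(x):
        return x.view(b, -1, heads, x.shape[-1] // heads).transpose(1, 2)

    def merge_heads(x):
        return x.transpose(1, 2).reshape(b, -1, heads * x.shape[-1])

    query = split_heads(to_q(hidden_states))

    # Text branch: ordinary scaled dot-product attention.
    text_out = F.scaled_dot_product_attention(
        query, split_heads(to_k(text_tokens)), split_heads(to_v(text_tokens)))

    # Image branch: same queries, separate to_k_ip / to_v_ip projections.
    ip_out = F.scaled_dot_product_attention(
        query, split_heads(to_k_ip(ip_tokens)), split_heads(to_v_ip(ip_tokens)))

    # Blend the image branch back in with the adapter scale.
    return merge_heads(text_out) + scale * merge_heads(ip_out)

if __name__ == "__main__":
    dim, heads, num_tokens = 64, 8, 4
    proj = lambda: nn.Linear(dim, dim, bias=False)
    x = torch.randn(2, 16, dim)                 # latent tokens
    ctx = torch.randn(2, 77 + num_tokens, dim)  # text + image-prompt tokens
    out = ip_adapter_attention(x, ctx, num_tokens, proj(), proj(), proj(),
                               proj(), proj(), heads)
    print(out.shape)  # torch.Size([2, 16, 64])
```

The `head_to_batch_dim`/`batch_to_head_dim` variant in the non-2.0 processors computes the same quantity with explicit attention probabilities instead of `F.scaled_dot_product_attention`.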
- """ - - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4,attn_head_dim=10): super().__init__() - if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - self.attn_head_dim=attn_head_dim - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( self, attn, @@ -2082,15 +234,8 @@ class IPAttnProcessor_referencenet_2_0(torch.nn.Module): if encoder_hidden_states is None: encoder_hidden_states = hidden_states - else: - # get encoder_hidden_states, ip_hidden_states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - if attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -2112,26 +257,6 @@ class IPAttnProcessor_referencenet_2_0(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - with torch.no_grad(): - self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + self.scale * ip_hidden_states - # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout @@ -2148,8 +273,7 @@ class IPAttnProcessor_referencenet_2_0(torch.nn.Module): return hidden_states - -class IPAttnProcessor2_0_Lora(torch.nn.Module): +class IPAttnProcessor2_0(torch.nn.Module): r""" Attention processor for IP-Adapater for PyTorch 2.0. Args: @@ -2163,7 +287,7 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module): The context length of the image features. 
""" - def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, scale_lora=1.0, rank = 4,num_tokens=4): + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): @@ -2172,14 +296,10 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module): self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.scale = scale - self.scale_lora = scale_lora self.num_tokens = num_tokens self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_k_ip_lora = LoRALinearLayer(in_features=self.to_k_ip.in_features, out_features=self.to_k_ip.out_features, rank=rank) - self.to_v_ip_lora =LoRALinearLayer(in_features=self.to_v_ip.in_features, out_features=self.to_v_ip.out_features, rank=rank) - + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) def __call__( self, @@ -2213,12 +333,7 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module): if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if hasattr(attn,'q_lora'): - query = attn.to_q(hidden_states) - q_lora = attn.q_lora(hidden_states) - query = query + self.scale_lora * q_lora - else: - query = attn.to_q(hidden_states) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states @@ -2232,21 +347,8 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module): if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - if hasattr(attn,'k_lora'): - key = attn.to_k(encoder_hidden_states) - k_lora = attn.k_lora(encoder_hidden_states) - key = key + self.scale_lora * k_lora - else: - key = attn.to_k(encoder_hidden_states) - - if hasattr(attn,'v_lora'): - value = attn.to_v(encoder_hidden_states) - v_lora = attn.v_lora(encoder_hidden_states) - value = value + self.scale_lora * v_lora - else: - value = attn.to_v(encoder_hidden_states) - - + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -2266,14 +368,8 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module): hidden_states = hidden_states.to(query.dtype) # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_key_lora = self.to_k_ip_lora(ip_hidden_states) - ip_key = ip_key + self.scale_lora * ip_key_lora ip_value = self.to_v_ip(ip_hidden_states) - ip_value_lora = self.to_v_ip_lora(ip_hidden_states) - ip_value = ip_value + self.scale_lora * ip_value_lora - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) @@ -2283,6 +379,9 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module): ip_hidden_states = F.scaled_dot_product_attention( query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + #print(self.attn_map.shape) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) @@ -2290,14 +389,7 @@ class IPAttnProcessor2_0_Lora(torch.nn.Module): hidden_states = hidden_states + self.scale * ip_hidden_states # linear proj - - if hasattr(attn,'out_lora'): - hidden_states = attn.to_out[0](hidden_states) - out_lora = 
attn.out_lora(hidden_states) - hidden_states = hidden_states+ self.scale_lora*out_lora - else: - hidden_states = attn.to_out[0](hidden_states) - + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/ip_adapter/attention_processor_faceid.py b/ip_adapter/attention_processor_faceid.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc1ca1e5f45e1da21be543f8d28d5de2bb0eba9 --- /dev/null +++ b/ip_adapter/attention_processor_faceid.py @@ -0,0 +1,427 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.models.lora import LoRALinearLayer + + +class LoRAAttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAttnProcessor(nn.Module): + r""" + 
Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states 
= hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAAttnProcessor2_0(nn.Module): + + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + 
hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAttnProcessor2_0(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + self.num_tokens = num_tokens + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + #query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # for text + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, 
attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/ip_adapter/custom_pipelines.py b/ip_adapter/custom_pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..7d43d2c34db9b83f6148fac53425a9fd4c60fc93 --- /dev/null +++ b/ip_adapter/custom_pipelines.py @@ -0,0 +1,394 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from diffusers import StableDiffusionXLPipeline +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg + +from .utils import is_torch2_available + +if is_torch2_available(): + from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor +else: + from .attention_processor import IPAttnProcessor + + +class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline): + def set_scale(self, scale): + for attn_processor in self.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + @torch.no_grad() + def __call__( # noqa: C901 + self, + prompt: Optional[Union[str, List[str]]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + 
guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. 
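Relative to a stock SDXL pipeline, the custom `__call__` documented here adds only `control_guidance_start`/`control_guidance_end`, which gate the IP-Adapter scale to a window of the denoising schedule; the rest is standard classifier-free guidance plus the optional rescale. A compact sketch of that per-step logic, pulled out of the pipeline for clarity (function names are illustrative):

```python
# Illustrative per-step guidance logic mirroring the custom SDXL pipeline:
# classifier-free guidance, optional rescale, and the IP-Adapter on/off window.
import torch

def ip_scale_for_step(i, num_steps, base_scale, start=0.0, end=1.0):
    """Return 0.0 outside the [start, end] fraction of the schedule, else base_scale."""
    if (i / num_steps) < start or ((i + 1) / num_steps) > end:
        return 0.0
    return base_scale

def guided_noise(noise_pred_uncond, noise_pred_text, guidance_scale, guidance_rescale=0.0):
    # Classifier-free guidance: add w * (cond - uncond) to the unconditional branch.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    if guidance_rescale > 0.0:
        # Rescale toward the conditional prediction's std (Sec. 3.4 of arXiv:2305.08891).
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        std_cfg = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
        rescaled = noise_pred * (std_text / std_cfg)
        noise_pred = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_pred
    return noise_pred

if __name__ == "__main__":
    uncond, text = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
    print(guided_noise(uncond, text, guidance_scale=5.0, guidance_rescale=0.7).shape)
    print([ip_scale_for_step(i, 30, 1.0, start=0.1, end=0.9) for i in (0, 15, 29)])
```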
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # get init conditioning scale + for attn_processor in self.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + conditioning_scale = attn_processor.scale + break + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end): + self.set_scale(0.0) + else: + self.set_scale(conditioning_scale) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if output_type != "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/ip_adapter/ip_adapter.py b/ip_adapter/ip_adapter.py index d092feb1878aa29829af968091c61990002898ed..dcb8824aebae0554ee51b363d56a40d022edacdc 100644 --- a/ip_adapter/ip_adapter.py +++ b/ip_adapter/ip_adapter.py @@ -8,7 +8,7 @@ from PIL import Image from safetensors import safe_open from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from .utils import is_torch2_available +from .utils import is_torch2_available, get_generator if is_torch2_available(): from .attention_processor import ( @@ -20,11 +20,9 @@ if is_torch2_available(): from .attention_processor import ( IPAttnProcessor2_0 as IPAttnProcessor, ) - from .attention_processor import IPAttnProcessor2_0_Lora -# else: -# from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor +else: + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor from .resampler import Resampler -from diffusers.models.lora import LoRALinearLayer class ImageProjModel(torch.nn.Module): @@ -33,6 +31,7 @@ class ImageProjModel(torch.nn.Module): def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): super().__init__() + self.generator = None self.cross_attention_dim = cross_attention_dim self.clip_extra_context_tokens = clip_extra_context_tokens self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) @@ -123,49 +122,19 @@ class IPAdapter: self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) def load_ip_adapter(self): - if self.ip_ckpt is not None: - if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": - state_dict = {"image_proj": {}, "ip_adapter": {}} - with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: - for key in f.keys(): - if key.startswith("image_proj."): - state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) - elif key.startswith("ip_adapter."): - state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) - else: - state_dict = torch.load(self.ip_ckpt, map_location="cpu") - 
self.image_proj_model.load_state_dict(state_dict["image_proj"]) - ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) - ip_layers.load_state_dict(state_dict["ip_adapter"]) - - - # def load_ip_adapter(self): - # if self.ip_ckpt is not None: - # if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": - # state_dict = {"image_proj_model": {}, "ip_adapter": {}} - # with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: - # for key in f.keys(): - # if key.startswith("image_proj_model."): - # state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = f.get_tensor(key) - # elif key.startswith("ip_adapter."): - # state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) - # else: - # state_dict = torch.load(self.ip_ckpt, map_location="cpu") - - # tmp1 = {} - # for k,v in state_dict.items(): - # if 'image_proj_model' in k: - # tmp1[k.replace('image_proj_model.','')] = v - # self.image_proj_model.load_state_dict(tmp1, strict=True) - # # ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) - # tmp2 = {} - # for k,v in state_dict.ites(): - # if 'adapter_mode' in k: - # tmp1[k] = v - - # print(ip_layers.state_dict()) - # ip_layers.load_state_dict(state_dict,strict=False) - + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) @torch.inference_mode() def get_image_embeds(self, pil_image=None, clip_image_embeds=None): @@ -180,19 +149,6 @@ class IPAdapter: uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) return image_prompt_embeds, uncond_image_prompt_embeds - def get_image_embeds_train(self, pil_image=None, clip_image_embeds=None): - if pil_image is not None: - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds - else: - clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32) - image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) - return image_prompt_embeds, uncond_image_prompt_embeds - - def set_scale(self, scale): for attn_processor in self.pipe.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): @@ -208,7 +164,7 @@ class IPAdapter: num_samples=4, seed=None, guidance_scale=7.5, - num_inference_steps=50, + num_inference_steps=30, **kwargs, ): self.set_scale(scale) @@ -248,7 +204,8 @@ class IPAdapter: prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + 
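Seeding now goes through `get_generator` from `.utils`, whose implementation is not part of this diff; presumably it reproduces the old `torch.Generator(device).manual_seed(seed)` behaviour and additionally accepts a list of seeds for batched generation. A sketch of what such a helper might look like, offered as an assumption rather than the repository's actual code:

```python
# Assumed shape of the get_generator helper imported from .utils; the real
# implementation is not shown in this diff, so treat this as a sketch only.
from typing import List, Optional, Union
import torch

def get_generator(seed: Optional[Union[int, List[int]]],
                  device: Union[str, torch.device] = "cpu"):
    """Return None for no seed, one Generator for an int, or a list for a list of seeds."""
    if seed is None:
        return None
    if isinstance(seed, int):
        return torch.Generator(device).manual_seed(seed)
    # A list of seeds yields one generator per sample, which diffusers pipelines accept.
    return [torch.Generator(device).manual_seed(s) for s in seed]

# Example: generator = get_generator(42, "cuda" if torch.cuda.is_available() else "cpu")
```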
images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -264,71 +221,6 @@ class IPAdapter: class IPAdapterXL(IPAdapter): """SDXL""" - def generate_test( - self, - pil_image, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - num_inference_steps=30, - **kwargs, - ): - self.set_scale(scale) - - num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) - - if prompt is None: - prompt = "best quality, high quality" - if negative_prompt is None: - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - - with torch.inference_mode(): - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.pipe.encode_prompt( - prompt, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None - images = self.pipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - num_inference_steps=num_inference_steps, - generator=generator, - **kwargs, - ).images - - - # with torch.autocast("cuda"): - # images = self.pipe( - # prompt_embeds=prompt_embeds, - # negative_prompt_embeds=negative_prompt_embeds, - # pooled_prompt_embeds=pooled_prompt_embeds, - # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - # num_inference_steps=num_inference_steps, - # generator=generator, - # **kwargs, - # ).images - - return images - - def generate( self, pil_image, @@ -376,98 +268,24 @@ class IPAdapterXL(IPAdapter): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + self.generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_inference_steps=num_inference_steps, - generator=generator, + generator=self.generator, **kwargs, ).images - - # with torch.autocast("cuda"): - # images = self.pipe( - # prompt_embeds=prompt_embeds, - # negative_prompt_embeds=negative_prompt_embeds, - # pooled_prompt_embeds=pooled_prompt_embeds, - # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - # num_inference_steps=num_inference_steps, - # generator=generator, - # **kwargs, - # ).images - return images class IPAdapterPlus(IPAdapter): """IP-Adapter with fine-grained features""" - def generate( - self, - pil_image=None, - clip_image_embeds=None, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - guidance_scale=7.5, - num_inference_steps=50, - **kwargs, - ): - self.set_scale(scale) - - if pil_image is not None: - num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) - else: - num_prompts = clip_image_embeds.size(0) - - if prompt is None: - prompt = "best quality, high quality" - if negative_prompt is None: - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, 
low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( - pil_image=pil_image, clip_image=clip_image_embeds - ) - bs_embed, seq_len, _ = image_prompt_embeds.shape - image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) - image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - - with torch.inference_mode(): - prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( - prompt, - device=self.device, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) - negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None - images = self.pipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - generator=generator, - **kwargs, - ).images - - return images - - def init_proj(self): image_proj_model = Resampler( dim=self.pipe.unet.config.cross_attention_dim, @@ -482,269 +300,12 @@ class IPAdapterPlus(IPAdapter): return image_proj_model @torch.inference_mode() - def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): - if pil_image is not None: - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - else: - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self.image_encoder( - torch.zeros_like(clip_image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) - return image_prompt_embeds, uncond_image_prompt_embeds - - - - -class IPAdapterPlus_Lora(IPAdapter): - """IP-Adapter with fine-grained features""" - - def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32): - self.rank = rank - super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens) - - - def generate( - self, - pil_image=None, - clip_image_embeds=None, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - guidance_scale=7.5, - num_inference_steps=50, - **kwargs, - ): - self.set_scale(scale) - - if pil_image is not None: - num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) - else: - num_prompts = clip_image_embeds.size(0) - - if prompt is None: - prompt = "best quality, high quality" - if negative_prompt is None: - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, 
List): - negative_prompt = [negative_prompt] * num_prompts - - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( - pil_image=pil_image, clip_image=clip_image_embeds - ) - bs_embed, seq_len, _ = image_prompt_embeds.shape - image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) - image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - - with torch.inference_mode(): - prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( - prompt, - device=self.device, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) - negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None - images = self.pipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - generator=generator, - **kwargs, - ).images - - return images - - - def init_proj(self): - image_proj_model = Resampler( - dim=self.pipe.unet.config.cross_attention_dim, - depth=4, - dim_head=64, - heads=12, - num_queries=self.num_tokens, - embedding_dim=self.image_encoder.config.hidden_size, - output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4, - ).to(self.device, dtype=torch.float16) - return image_proj_model - - @torch.inference_mode() - def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): - if pil_image is not None: - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - else: - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - image_prompt_embeds = self.image_proj_model(clip_image_embeds) - uncond_clip_image_embeds = self.image_encoder( - torch.zeros_like(clip_image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) - return image_prompt_embeds, uncond_image_prompt_embeds - - def set_ip_adapter(self): - unet = self.pipe.unet - attn_procs = {} - unet_sd = unet.state_dict() - - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. 
- cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim - if attn_processor_name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif attn_processor_name.startswith("up_blocks"): - block_id = int(attn_processor_name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif attn_processor_name.startswith("down_blocks"): - block_id = int(attn_processor_name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - if cross_attention_dim is None: - attn_procs[attn_processor_name] = AttnProcessor() - else: - layer_name = attn_processor_name.split(".processor")[0] - weights = { - "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], - "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], - } - attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens) - attn_procs[attn_processor_name].load_state_dict(weights,strict=False) - - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank) - attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank) - attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank) - attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank) - - unet.set_attn_processor(attn_procs) - if hasattr(self.pipe, "controlnet"): - if isinstance(self.pipe.controlnet, MultiControlNetModel): - for controlnet in self.pipe.controlnet.nets: - controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) - else: - self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) - - - -class IPAdapterPlus_Lora_up(IPAdapter): - """IP-Adapter with fine-grained features""" - - def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32): - self.rank = rank - super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens) - - - def generate( - self, - pil_image=None, - clip_image_embeds=None, - prompt=None, - negative_prompt=None, - scale=1.0, - num_samples=4, - seed=None, - guidance_scale=7.5, - num_inference_steps=50, - **kwargs, - ): - self.set_scale(scale) - - if pil_image is not None: - num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) - else: - num_prompts = clip_image_embeds.size(0) - - if prompt is None: - prompt = "best quality, high quality" - if negative_prompt is None: - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( - pil_image=pil_image, clip_image=clip_image_embeds - ) - bs_embed, seq_len, _ = image_prompt_embeds.shape - image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) - image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - uncond_image_prompt_embeds = 
uncond_image_prompt_embeds.repeat(1, num_samples, 1) - uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) - - with torch.inference_mode(): - prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( - prompt, - device=self.device, - num_images_per_prompt=num_samples, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) - negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) - - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None - images = self.pipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - generator=generator, - **kwargs, - ).images - - return images - - - def init_proj(self): - image_proj_model = Resampler( - dim=self.pipe.unet.config.cross_attention_dim, - depth=4, - dim_head=64, - heads=12, - num_queries=self.num_tokens, - embedding_dim=self.image_encoder.config.hidden_size, - output_dim=self.pipe.unet.config.cross_attention_dim, - ff_mult=4, - ).to(self.device, dtype=torch.float16) - return image_proj_model - - @torch.inference_mode() - def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): - if pil_image is not None: - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - else: - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self.image_encoder( torch.zeros_like(clip_image), output_hidden_states=True @@ -752,55 +313,6 @@ class IPAdapterPlus_Lora_up(IPAdapter): uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) return image_prompt_embeds, uncond_image_prompt_embeds - def set_ip_adapter(self): - unet = self.pipe.unet - attn_procs = {} - unet_sd = unet.state_dict() - - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. 
- cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim - if attn_processor_name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif attn_processor_name.startswith("up_blocks"): - block_id = int(attn_processor_name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif attn_processor_name.startswith("down_blocks"): - block_id = int(attn_processor_name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - if cross_attention_dim is None: - attn_procs[attn_processor_name] = AttnProcessor() - else: - layer_name = attn_processor_name.split(".processor")[0] - weights = { - "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], - "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], - } - attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens) - attn_procs[attn_processor_name].load_state_dict(weights,strict=False) - - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - - if "up_blocks" in attn_processor_name: - attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank) - attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank) - attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank) - attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank) - - - - unet.set_attn_processor(attn_procs) - if hasattr(self.pipe, "controlnet"): - if isinstance(self.pipe.controlnet, MultiControlNetModel): - for controlnet in self.pipe.controlnet.nets: - controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) - else: - self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) - - class IPAdapterFull(IPAdapterPlus): """IP-Adapter with full features""" @@ -830,15 +342,12 @@ class IPAdapterPlusXL(IPAdapter): return image_proj_model @torch.inference_mode() - def get_image_embeds(self, pil_image=None, clip_image_embeds=None): - if pil_image is not None: - if isinstance(pil_image, Image.Image): - pil_image = [pil_image] - clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) - clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] - else: - clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] image_prompt_embeds = self.image_proj_model(clip_image_embeds) uncond_clip_image_embeds = self.image_encoder( torch.zeros_like(clip_image), output_hidden_states=True @@ -893,7 +402,8 @@ class IPAdapterPlusXL(IPAdapter): prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = 
torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) - generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None + generator = get_generator(seed, self.device) + images = self.pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, diff --git a/ip_adapter/ip_adapter_faceid.py b/ip_adapter/ip_adapter_faceid.py new file mode 100644 index 0000000000000000000000000000000000000000..fe98ad540648d429a3a227733c8607626fe0784e --- /dev/null +++ b/ip_adapter/ip_adapter_faceid.py @@ -0,0 +1,542 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor +from .utils import is_torch2_available, get_generator + +USE_DAFAULT_ATTN = False # should be True for visualization_attnmap +if is_torch2_available() and (not USE_DAFAULT_ATTN): + from .attention_processor_faceid import ( + LoRAAttnProcessor2_0 as LoRAAttnProcessor, + ) + from .attention_processor_faceid import ( + LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor, + ) +else: + from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor +from .resampler import PerceiverAttention, FeedForward + + +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + + +class ProjPlusModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=cross_attention_dim // 64, + 
embedding_dim=clip_embeddings_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): + + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + out = self.perceiver_resampler(x, clip_embeds) + if shortcut: + out = x + scale * out + return out + + +class IPAdapterFaceID: + def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16): + self.device = device + self.ip_ckpt = ip_ckpt + self.lora_rank = lora_rank + self.num_tokens = num_tokens + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + ).to(self.device, dtype=self.torch_dtype) + else: + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) + + @torch.inference_mode() + def get_image_embeds(self, faceid_embeds): + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, LoRAIPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + 
**kwargs, + ): + self.set_scale(scale) + + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlus: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.lora_rank = lora_rank + self.num_tokens = num_tokens + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=self.torch_dtype + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ProjPlusModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + ).to(self.device, dtype=self.torch_dtype) + else: + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, 
cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) + + @torch.inference_mode() + def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut): + if isinstance(face_image, Image.Image): + pil_image = [face_image] + clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=self.torch_dtype) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, LoRAIPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=False, + **kwargs, + ): + self.set_scale(scale) + + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + 
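+            # the negative branch mirrors this: the unconditional (zero-image) tokens are appended after the negative text tokens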
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDXL(IPAdapterFaceID): + """SDXL""" + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): + """SDXL""" + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=True, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * 
num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + **kwargs, + ).images + + return images diff --git a/ip_adapter/ip_adapter_faceid_separate.py b/ip_adapter/ip_adapter_faceid_separate.py new file mode 100644 index 0000000000000000000000000000000000000000..80ca84c4fdb56e9b0dfb56195426427b31f07fb8 --- /dev/null +++ b/ip_adapter/ip_adapter_faceid_separate.py @@ -0,0 +1,547 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .utils import is_torch2_available, get_generator + +USE_DAFAULT_ATTN = False # should be True for visualization_attnmap +if is_torch2_available() and (not USE_DAFAULT_ATTN): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, IPAttnProcessor +from .resampler import PerceiverAttention, FeedForward + + +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + + +class ProjPlusModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): + 
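+        # projects a face-ID embedding to num_tokens cross-attention tokens, then refines them against CLIP image features with a small perceiver resampler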
super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=cross_attention_dim // 64, + embedding_dim=clip_embeddings_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): + + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + out = self.perceiver_resampler(x, clip_embeds) + if shortcut: + out = x + scale * out + return out + + +class IPAdapterFaceID: + def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, n_cond=1, torch_dtype=torch.float16): + self.device = device + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + self.n_cond = n_cond + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens*self.n_cond, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, faceid_embeds): + + multi_face = False + if faceid_embeds.dim() == 3: + multi_face = True + b, n, c = faceid_embeds.shape + faceid_embeds = faceid_embeds.reshape(b*n, c) + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds) + 
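+        # a zeroed face-ID embedding yields the unconditional image tokens used for classifier-free guidance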
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds)) + if multi_face: + c = image_prompt_embeds.size(-1) + image_prompt_embeds = image_prompt_embeds.reshape(b, -1, c) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.reshape(b, -1, c) + + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlus: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=self.torch_dtype + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ProjPlusModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + 
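+                # the mid block attends at the deepest channel width, i.e. the last entry of block_out_channels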
hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut): + if isinstance(face_image, Image.Image): + pil_image = [face_image] + clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=self.torch_dtype) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, LoRAIPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=False, + **kwargs, + ): + self.set_scale(scale) + + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = 
uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDXL(IPAdapterFaceID): + """SDXL""" + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): + """SDXL""" + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=True, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + 
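+        # combine the face-ID embedding with CLIP features of the face image; s_scale and shortcut control how strongly the CLIP branch is mixed in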
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + **kwargs, + ).images + + return images diff --git a/ip_adapter/resampler.py b/ip_adapter/resampler.py index 708740b4a3abbe08030e4428c39acd7438c1006d..24266671d02092438ae6576336a59659fef9c054 100644 --- a/ip_adapter/resampler.py +++ b/ip_adapter/resampler.py @@ -78,54 +78,6 @@ class PerceiverAttention(nn.Module): return self.to_out(out) -class CrossAttention(nn.Module): - def __init__(self, *, dim, dim_head=64, heads=8): - super().__init__() - self.scale = dim_head**-0.5 - self.dim_head = dim_head - self.heads = heads - inner_dim = dim_head * heads - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_k = nn.Linear(dim, inner_dim, bias=False) - self.to_v = nn.Linear(dim, inner_dim, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - - def forward(self, x, x2): - """ - Args: - x (torch.Tensor): image features - shape (b, n1, D) - latent (torch.Tensor): latent features - shape (b, n2, D) - """ - x = self.norm1(x) - x2 = self.norm2(x2) - - b, l, _ = x2.shape - - q = self.to_q(x) - k = self.to_k(x2) - v = self.to_v(x2) - - q = reshape_tensor(q, self.heads) - k = reshape_tensor(k, self.heads) - v = reshape_tensor(v, self.heads) - - # attention - scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v - - out = out.permute(0, 2, 1, 3).reshape(b, l, -1) - return self.to_out(out) - - class Resampler(nn.Module): def __init__( self, @@ -142,6 +94,7 @@ class Resampler(nn.Module): num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence ): super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) @@ -150,6 +103,16 @@ class Resampler(nn.Module): self.proj_out = nn.Linear(dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * 
num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( @@ -162,11 +125,19 @@ class Resampler(nn.Module): ) def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x = x + pos_emb latents = self.latents.repeat(x.size(0), 1, 1) x = self.proj_in(x) + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) for attn, ff in self.layers: latents = attn(x, latents) + latents @@ -176,7 +147,6 @@ class Resampler(nn.Module): return self.norm_out(latents) - def masked_mean(t, *, dim, mask=None): if mask is None: return t.mean(dim=dim) diff --git a/ip_adapter/utils.py b/ip_adapter/utils.py index 9a105f3701c15e8d3bbf838d79bacc51e91d0696..6a273358585962fdf383d0bb7a0e1c654b4999b8 100644 --- a/ip_adapter/utils.py +++ b/ip_adapter/utils.py @@ -1,5 +1,93 @@ +import torch import torch.nn.functional as F +import numpy as np +from PIL import Image +attn_maps = {} +def hook_fn(name): + def forward_hook(module, input, output): + if hasattr(module.processor, "attn_map"): + attn_maps[name] = module.processor.attn_map + del module.processor.attn_map + return forward_hook + +def register_cross_attention_hook(unet): + for name, module in unet.named_modules(): + if name.split('.')[-1].startswith('attn2'): + module.register_forward_hook(hook_fn(name)) + + return unet + +def upscale(attn_map, target_size): + attn_map = torch.mean(attn_map, dim=0) + attn_map = attn_map.permute(1,0) + temp_size = None + + for i in range(0,5): + scale = 2 ** i + if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64: + temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8)) + break + + assert temp_size is not None, "temp_size cannot is None" + + attn_map = attn_map.view(attn_map.shape[0], *temp_size) + + attn_map = F.interpolate( + attn_map.unsqueeze(0).to(dtype=torch.float32), + size=target_size, + mode='bilinear', + align_corners=False + )[0] + + attn_map = torch.softmax(attn_map, dim=0) + return attn_map +def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True): + + idx = 0 if instance_or_negative else 1 + net_attn_maps = [] + + for name, attn_map in attn_maps.items(): + attn_map = attn_map.cpu() if detach else attn_map + attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() + attn_map = upscale(attn_map, image_size) + net_attn_maps.append(attn_map) + + net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) + + return net_attn_maps + +def attnmaps2images(net_attn_maps): + + #total_attn_scores = 0 + images = [] + + for attn_map in net_attn_maps: + attn_map = attn_map.cpu().numpy() + #total_attn_scores += attn_map.mean().item() + + normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255 + normalized_attn_map = normalized_attn_map.astype(np.uint8) + #print("norm: ", normalized_attn_map.shape) + image = Image.fromarray(normalized_attn_map) + + #image = fix_save_attn_map(attn_map) + images.append(image) + + #print(total_attn_scores) + return images def is_torch2_available(): return hasattr(F, "scaled_dot_product_attention") + +def 
get_generator(seed, device): + + if seed is not None: + if isinstance(seed, list): + generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] + else: + generator = torch.Generator(device).manual_seed(seed) + else: + generator = None + + return generator \ No newline at end of file diff --git a/preprocess/humanparsing/run_parsing.py b/preprocess/humanparsing/run_parsing.py index db9a629eb6c16ecceb005d5610229cad91f71a6b..14028467468d280139329e1f197e63df886d1fb3 100644 --- a/preprocess/humanparsing/run_parsing.py +++ b/preprocess/humanparsing/run_parsing.py @@ -11,12 +11,12 @@ import torch class Parsing: def __init__(self, gpu_id: int): - self.gpu_id = gpu_id - torch.cuda.set_device(gpu_id) + # self.gpu_id = gpu_id + # torch.cuda.set_device(gpu_id) session_options = ort.SessionOptions() session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - session_options.add_session_config_entry('gpu_id', str(gpu_id)) + # session_options.add_session_config_entry('gpu_id', str(gpu_id)) self.session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'ckpt/humanparsing/parsing_atr.onnx'), sess_options=session_options, providers=['CPUExecutionProvider']) self.lip_session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'ckpt/humanparsing/parsing_lip.onnx'), diff --git a/preprocess/openpose/annotator/openpose/body.py b/preprocess/openpose/annotator/openpose/body.py index 7af77f0409f9aa5195e668a7eb782267ea0889c3..27012f28ad0e736f7e7bd877f35db3b64fe8bd9c 100644 --- a/preprocess/openpose/annotator/openpose/body.py +++ b/preprocess/openpose/annotator/openpose/body.py @@ -20,9 +20,9 @@ from .model import bodypose_model class Body(object): def __init__(self, model_path): self.model = bodypose_model() - if torch.cuda.is_available(): - self.model = self.model.cuda() - # print('cuda') + # if torch.cuda.is_available(): + # self.model = self.model.cuda() + # print('cuda') model_dict = util.transfer(self.model, torch.load(model_path)) self.model.load_state_dict(model_dict) self.model.eval() diff --git a/preprocess/openpose/run_openpose.py b/preprocess/openpose/run_openpose.py index 37e8ee0d2411a218b0ffba01a190ab0e56dae06e..fa0ed1fe39e8c871726555f184f1dcc3c8dd51bc 100644 --- a/preprocess/openpose/run_openpose.py +++ b/preprocess/openpose/run_openpose.py @@ -1,6 +1,6 @@ import pdb -# import config +import config from pathlib import Path import sys @@ -28,12 +28,12 @@ import pdb class OpenPose: def __init__(self, gpu_id: int): - self.gpu_id = gpu_id - torch.cuda.set_device(gpu_id) + # self.gpu_id = gpu_id + # torch.cuda.set_device(gpu_id) self.preprocessor = OpenposeDetector() def __call__(self, input_image, resolution=384): - torch.cuda.set_device(self.gpu_id) + # torch.cuda.set_device(self.gpu_id) if isinstance(input_image, Image.Image): input_image = np.asarray(input_image) elif type(input_image) == str: diff --git a/requirements.txt b/requirements.txt index e0182cb3ac961882217447e9a53550f5562207b5..a9243fb4b897c3cc398fda901f58f4978a0eba28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,23 @@ -huggingface_hub==0.20.2 -accelerate==0.25.0 -torchmetrics==1.2.1 -tqdm==4.66.1 -transformers==4.36.2 -diffusers==0.25.0 -einops==0.7.0 -bitsandbytes==0.39.0 -scipy==1.11.1 -opencv-python -gradio==4.24.0 -fvcore -cloudpickle -omegaconf -pycocotools -basicsr -av +transformers==4.36.2 +torch==2.0.1 +torchvision==0.15.2 
+torchaudio==2.0.2 +numpy==1.24.4 +scipy==1.10.1 +scikit-image==0.21.0 +opencv-python==4.7.0.72 +pillow==9.4.0 +diffusers==0.25.0 +transformers==4.36.2 +accelerate==0.26.1 +matplotlib==3.7.4 +tqdm==4.64.1 +config==0.5.1 +einops==0.7.0 onnxruntime==1.16.2 +basicsr +av +fvcore +cloudpickle +omegaconf +pycocotools \ No newline at end of file diff --git a/src/attentionhacked_tryon.py b/src/attentionhacked_tryon.py index 1e5123ca4e07786bcbb28769b0fc1323c6db3571..9947e7c878f9f1fd3701b939dfc243628fa54652 100644 --- a/src/attentionhacked_tryon.py +++ b/src/attentionhacked_tryon.py @@ -331,6 +331,7 @@ class BasicTransformerBlock(nn.Module): gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + #type2 modify_norm_hidden_states = torch.cat([norm_hidden_states,garment_features[curr_garment_feat_idx]], dim=1) curr_garment_feat_idx +=1 attn_output = self.attn1( @@ -345,6 +346,8 @@ class BasicTransformerBlock(nn.Module): elif self.use_ada_layer_norm_single: attn_output = gate_msa * attn_output + + #type2 hidden_states = attn_output[:,:hidden_states.shape[-2],:] + hidden_states diff --git a/src/tryon_pipeline.py b/src/tryon_pipeline.py index 397ad18e462ac687858229fdab2a992ab5a722b7..78f22375101a3312a6e3f992126f92fa91839ed4 100644 --- a/src/tryon_pipeline.py +++ b/src/tryon_pipeline.py @@ -370,7 +370,6 @@ class StableDiffusionXLInpaintPipeline( "text_encoder_2", "image_encoder", "feature_extractor", - "unet_encoder", ] _callback_tensor_inputs = [ "latents", @@ -392,7 +391,6 @@ class StableDiffusionXLInpaintPipeline( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, - unet_encoder: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, @@ -409,7 +407,6 @@ class StableDiffusionXLInpaintPipeline( tokenizer_2=tokenizer_2, unet=unet, image_encoder=image_encoder, - unet_encoder=unet_encoder, feature_extractor=feature_extractor, scheduler=scheduler, ) @@ -1722,8 +1719,8 @@ class StableDiffusionXLInpaintPipeline( ip_adapter_image, device, batch_size * num_images_per_prompt ) - #project outside for loop - image_embeds = self.unet.encoder_hid_proj(image_embeds).to(prompt_embeds.dtype) + #project outside for loop + image_embeds = self.unet.encoder_hid_proj(image_embeds).to(prompt_embeds.dtype) # 11. 
Denoising loop diff --git a/train_xl.py b/train_xl.py deleted file mode 100644 index 970bf47e3f50bcda97b5132eb39f407659df643d..0000000000000000000000000000000000000000 --- a/train_xl.py +++ /dev/null @@ -1,797 +0,0 @@ -import os -import random -import argparse -import json -import itertools -import torch -import torch.nn.functional as F -from torchvision import transforms -from PIL import Image -from transformers import CLIPImageProcessor -from accelerate import Accelerator -from accelerate.utils import ProjectConfiguration -from diffusers import AutoencoderKL, DDPMScheduler -from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection - -from src.unet_hacked_tryon import UNet2DConditionModel -from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref -from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline - -from ip_adapter.ip_adapter import Resampler -from diffusers.utils.import_utils import is_xformers_available -from typing import Literal, Tuple,List -import torch.utils.data as data -import math -from tqdm.auto import tqdm -from diffusers.training_utils import compute_snr -import torchvision.transforms.functional as TF - - - -class VitonHDDataset(data.Dataset): - def __init__( - self, - dataroot_path: str, - phase: Literal["train", "test"], - order: Literal["paired", "unpaired"] = "paired", - size: Tuple[int, int] = (512, 384), - ): - super(VitonHDDataset, self).__init__() - self.dataroot = dataroot_path - self.phase = phase - self.height = size[0] - self.width = size[1] - self.size = size - - - self.norm = transforms.Normalize([0.5], [0.5]) - self.transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - self.transform2D = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] - ) - self.toTensor = transforms.ToTensor() - - with open( - os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r" - ) as file1: - data1 = json.load(file1) - - annotation_list = [ - # "colors", - # "textures", - "sleeveLength", - "neckLine", - "item", - ] - - self.annotation_pair = {} - for k, v in data1.items(): - for elem in v: - annotation_str = "" - for template in annotation_list: - for tag in elem["tag_info"]: - if ( - tag["tag_name"] == template - and tag["tag_category"] is not None - ): - annotation_str += tag["tag_category"] - annotation_str += " " - self.annotation_pair[elem["file_name"]] = annotation_str - - - self.order = order - - self.toTensor = transforms.ToTensor() - - im_names = [] - c_names = [] - dataroot_names = [] - - - if phase == "train": - filename = os.path.join(dataroot_path, f"{phase}_pairs.txt") - else: - filename = os.path.join(dataroot_path, f"{phase}_pairs.txt") - - with open(filename, "r") as f: - for line in f.readlines(): - if phase == "train": - im_name, _ = line.strip().split() - c_name = im_name - else: - if order == "paired": - im_name, _ = line.strip().split() - c_name = im_name - else: - im_name, c_name = line.strip().split() - - im_names.append(im_name) - c_names.append(c_name) - dataroot_names.append(dataroot_path) - - self.im_names = im_names - self.c_names = c_names - self.dataroot_names = dataroot_names - self.flip_transform = transforms.RandomHorizontalFlip(p=1) - self.clip_processor = CLIPImageProcessor() - def __getitem__(self, index): - c_name = self.c_names[index] - im_name = self.im_names[index] - # subject_txt = self.txt_preprocess['train']("shirt") - if 
c_name in self.annotation_pair: - cloth_annotation = self.annotation_pair[c_name] - else: - cloth_annotation = "shirts" - - cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name)) - - im_pil_big = Image.open( - os.path.join(self.dataroot, self.phase, "image", im_name) - ).resize((self.width,self.height)) - - image = self.transform(im_pil_big) - # load parsing image - - - mask = Image.open(os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace('.jpg','_mask.png'))).resize((self.width,self.height)) - mask = self.toTensor(mask) - mask = mask[:1] - densepose_name = im_name - densepose_map = Image.open( - os.path.join(self.dataroot, self.phase, "image-densepose", densepose_name) - ) - pose_img = self.toTensor(densepose_map) # [-1,1] - - - - if self.phase == "train": - if random.random() > 0.5: - cloth = self.flip_transform(cloth) - mask = self.flip_transform(mask) - image = self.flip_transform(image) - pose_img = self.flip_transform(pose_img) - - - - if random.random()>0.5: - color_jitter = transforms.ColorJitter(brightness=0.5, contrast=0.3, saturation=0.5, hue=0.5) - fn_idx, b, c, s, h = transforms.ColorJitter.get_params(color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,color_jitter.hue) - - image = TF.adjust_contrast(image, c) - image = TF.adjust_brightness(image, b) - image = TF.adjust_hue(image, h) - image = TF.adjust_saturation(image, s) - - cloth = TF.adjust_contrast(cloth, c) - cloth = TF.adjust_brightness(cloth, b) - cloth = TF.adjust_hue(cloth, h) - cloth = TF.adjust_saturation(cloth, s) - - - if random.random() > 0.5: - scale_val = random.uniform(0.8, 1.2) - image = transforms.functional.affine( - image, angle=0, translate=[0, 0], scale=scale_val, shear=0 - ) - mask = transforms.functional.affine( - mask, angle=0, translate=[0, 0], scale=scale_val, shear=0 - ) - pose_img = transforms.functional.affine( - pose_img, angle=0, translate=[0, 0], scale=scale_val, shear=0 - ) - - - - if random.random() > 0.5: - shift_valx = random.uniform(-0.2, 0.2) - shift_valy = random.uniform(-0.2, 0.2) - image = transforms.functional.affine( - image, - angle=0, - translate=[shift_valx * image.shape[-1], shift_valy * image.shape[-2]], - scale=1, - shear=0, - ) - mask = transforms.functional.affine( - mask, - angle=0, - translate=[shift_valx * mask.shape[-1], shift_valy * mask.shape[-2]], - scale=1, - shear=0, - ) - pose_img = transforms.functional.affine( - pose_img, - angle=0, - translate=[ - shift_valx * pose_img.shape[-1], - shift_valy * pose_img.shape[-2], - ], - scale=1, - shear=0, - ) - - - - - mask = 1-mask - - cloth_trim = self.clip_processor(images=cloth, return_tensors="pt").pixel_values - - - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - im_mask = image * mask - - pose_img = self.norm(pose_img) - - - result = {} - result["c_name"] = c_name - result["image"] = image - result["cloth"] = cloth_trim - result["cloth_pure"] = self.transform(cloth) - result["inpaint_mask"] = 1-mask - result["im_mask"] = im_mask - result["caption"] = "model is wearing " + cloth_annotation - result["caption_cloth"] = "a photo of " + cloth_annotation - result["annotation"] = cloth_annotation - result["pose_img"] = pose_img - - - return result - - def __len__(self): - return len(self.im_names) - - - - -def parse_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--pretrained_model_name_or_path",type=str,default="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",required=False,help="Path to 
pretrained model or model identifier from huggingface.co/models.",) - parser.add_argument("--pretrained_garmentnet_path",type=str,default="stabilityai/stable-diffusion-xl-base-1.0",required=False,help="Path to pretrained model or model identifier from huggingface.co/models.",) - parser.add_argument("--checkpointing_epoch",type=int,default=10,help=("Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"" training using `--resume_from_checkpoint`."),) - parser.add_argument("--pretrained_ip_adapter_path",type=str,default="ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin",help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",) - parser.add_argument("--image_encoder_path",type=str,default="ckpt/image_encoder",required=False,help="Path to CLIP image encoder",) - parser.add_argument("--gradient_checkpointing",action="store_true",help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",) - parser.add_argument("--width",type=int,default=768,) - parser.add_argument("--height",type=int,default=1024,) - parser.add_argument("--gradient_accumulation_steps",type=int,default=1,help="Number of updates steps to accumulate before performing a backward/update pass.",) - parser.add_argument("--logging_steps",type=int,default=1000,help=("Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"" training using `--resume_from_checkpoint`."),) - parser.add_argument("--output_dir",type=str,default="output",help="The output directory where the model predictions and checkpoints will be written.",) - parser.add_argument("--snr_gamma",type=float,default=None,help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. ""More details here: https://arxiv.org/abs/2303.09556.",) - parser.add_argument("--num_tokens",type=int,default=16,help=("IP adapter token nums"),) - parser.add_argument("--learning_rate",type=float,default=1e-5,help="Learning rate to use.",) - parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--train_batch_size", type=int, default=6, help="Batch size (per device) for the training dataloader.") - parser.add_argument("--test_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.") - parser.add_argument("--num_train_epochs", type=int, default=130) - parser.add_argument("--max_train_steps",type=int,default=None,help="Total number of training steps to perform. If provided, overrides num_train_epochs.",) - parser.add_argument("--noise_offset", type=float, default=None, help="noise offset") - parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes.") - parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") - parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],help=("Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config."),) - parser.add_argument("--guidance_scale",type=float,default=2.0,) - parser.add_argument("--seed", type=int, default=42,) - parser.add_argument("--num_inference_steps",type=int,default=30,) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - parser.add_argument("--data_dir", type=str, default="/home/omnious/workspace/yisol/Dataset/VITON-HD/zalando", help="For distributed training: local_rank") - - args = parser.parse_args() - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - - return args - - - - - -def main(): - - - args = parse_args() - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir) - accelerator = Accelerator( - mixed_precision=args.mixed_precision, - gradient_accumulation_steps=args.gradient_accumulation_steps, - project_config=accelerator_project_config, - ) - - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - # Load scheduler, tokenizer and models. - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler",rescale_betas_zero_snr=True) - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") - text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") - tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2") - text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2") - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,subfolder="vae",torch_dtype=torch.float16,) - unet_encoder = UNet2DConditionModel_ref.from_pretrained(args.pretrained_garmentnet_path, subfolder="unet") - unet_encoder.config.addition_embed_type = None - unet_encoder.config["addition_embed_type"] = None - image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) - - #customize unet start - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",low_cpu_mem_usage=False, device_map=None) - unet.config.encoder_hid_dim = image_encoder.config.hidden_size - unet.config.encoder_hid_dim_type = "ip_image_proj" - unet.config["encoder_hid_dim"] = image_encoder.config.hidden_size - unet.config["encoder_hid_dim_type"] = "ip_image_proj" - - - state_dict = torch.load(args.pretrained_ip_adapter_path, map_location="cpu") - - - adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) - adapter_modules.load_state_dict(state_dict["ip_adapter"],strict=True) - - #ip-adapter - image_proj_model = Resampler( - dim=image_encoder.config.hidden_size, - depth=4, - dim_head=64, - heads=20, - num_queries=args.num_tokens, - embedding_dim=image_encoder.config.hidden_size, - output_dim=unet.config.cross_attention_dim, - ff_mult=4, - 
).to(accelerator.device, dtype=torch.float32) - - image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) - image_proj_model.requires_grad_(True) - - unet.encoder_hid_proj = image_proj_model - - conv_new = torch.nn.Conv2d( - in_channels=4+4+1+4, - out_channels=unet.conv_in.out_channels, - kernel_size=3, - padding=1, - ) - torch.nn.init.kaiming_normal_(conv_new.weight) - conv_new.weight.data = conv_new.weight.data * 0. - - conv_new.weight.data[:, :9] = unet.conv_in.weight.data - conv_new.bias.data = unet.conv_in.bias.data - - unet.conv_in = conv_new # replace conv layer in unet - unet.config['in_channels'] = 13 # update config - unet.config.in_channels = 13 # update config - #customize unet end - - - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - vae.to(accelerator.device) - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder_2.to(accelerator.device, dtype=weight_dtype) - image_encoder.to(accelerator.device, dtype=weight_dtype) - unet_encoder.to(accelerator.device, dtype=weight_dtype) - - - vae.requires_grad_(False) - text_encoder.requires_grad_(False) - text_encoder_2.requires_grad_(False) - image_encoder.requires_grad_(False) - unet_encoder.requires_grad_(False) - unet.requires_grad_(True) - - - - - if args.enable_xformers_memory_efficient_attention: - if is_xformers_available(): - import xformers - - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - unet_encoder.enable_gradient_checkpointing() - unet.train() - - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - params_to_opt = itertools.chain(unet.parameters()) - - - optimizer = optimizer_class( - params_to_opt, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - - train_dataset = VitonHDDataset( - dataroot_path=args.data_dir, - phase="train", - order="paired", - size=(args.height, args.width), - ) - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - pin_memory=True, - shuffle=False, - batch_size=args.train_batch_size, - num_workers=16, - ) - test_dataset = VitonHDDataset( - dataroot_path=args.data_dir, - phase="test", - order="paired", - size=(args.height, args.width), - ) - test_dataloader = torch.utils.data.DataLoader( - test_dataset, - shuffle=False, - batch_size=args.test_batch_size, - num_workers=4, - ) - - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - - unet,image_proj_model,unet_encoder,image_encoder,optimizer,train_dataloader,test_dataloader = accelerator.prepare(unet, image_proj_model,unet_encoder,image_encoder,optimizer,train_dataloader,test_dataloader) - initial_global_step = 0 - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # Train! - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=initial_global_step, - desc="Steps", - # Only show the progress bar once on each machine. - disable=not accelerator.is_local_main_process, - ) - global_step = 0 - first_epoch = 0 - train_loss=0.0 - for epoch in range(first_epoch, args.num_train_epochs): - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(unet), accelerator.accumulate(image_proj_model): - if global_step % args.logging_steps == 0: - if accelerator.is_main_process: - with torch.no_grad(): - with torch.cuda.amp.autocast(): - unwrapped_unet= accelerator.unwrap_model(unet) - newpipe = TryonPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=unwrapped_unet, - vae= vae, - scheduler=noise_scheduler, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - image_encoder=image_encoder, - unet_encoder = unet_encoder, - torch_dtype=torch.float16, - add_watermarker=False, - safety_checker=None, - ).to(accelerator.device) - with torch.no_grad(): - for sample in test_dataloader: - img_emb_list = [] - for i in range(sample['cloth'].shape[0]): - img_emb_list.append(sample['cloth'][i]) - - prompt = sample["caption"] - - num_prompts = sample['cloth'].shape[0] - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - image_embeds = torch.cat(img_emb_list,dim=0) - - - with torch.inference_mode(): - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = newpipe.encode_prompt( - prompt, - num_images_per_prompt=1, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt, - ) - - - prompt = sample["caption_cloth"] - negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" - - if not isinstance(prompt, List): - prompt = [prompt] * num_prompts - if not isinstance(negative_prompt, List): - negative_prompt = [negative_prompt] * num_prompts - - - with torch.inference_mode(): - ( - prompt_embeds_c, - _, - _, - _, - ) = newpipe.encode_prompt( - prompt, - num_images_per_prompt=1, - do_classifier_free_guidance=False, - negative_prompt=negative_prompt, - ) - - - - generator = torch.Generator(newpipe.device).manual_seed(args.seed) if args.seed is not None else None - images = newpipe( - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - num_inference_steps=args.num_inference_steps, - generator=generator, - strength = 1.0, - pose_img = sample['pose_img'], - text_embeds_cloth=prompt_embeds_c, - cloth = sample["cloth_pure"].to(accelerator.device), - mask_image=sample['inpaint_mask'], - image=(sample['image']+1.0)/2.0, - height=args.height, - width=args.width, - guidance_scale=args.guidance_scale, - ip_adapter_image = image_embeds, - )[0] - - for i in range(len(images)): - 
images[i].save(os.path.join(args.output_dir,str(global_step)+"_"+str(i)+"_"+"test.jpg")) - break - del unwrapped_unet - del newpipe - torch.cuda.empty_cache() - - - - pixel_values = batch["image"].to(dtype=vae.dtype) - model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = model_input * vae.config.scaling_factor - - masked_latents = vae.encode( - batch["im_mask"].reshape(batch["image"].shape).to(dtype=vae.dtype) - ).latent_dist.sample() - masked_latents = masked_latents * vae.config.scaling_factor - masks = batch["inpaint_mask"] - # resize the mask to latents shape as we concatenate the mask to the latents - mask = torch.stack( - [ - torch.nn.functional.interpolate(masks, size=(args.height // 8, args.width // 8)) - ] - ) - mask = mask.reshape(-1, 1, args.height // 8, args.width // 8) - - pose_map = vae.encode(batch["pose_img"].to(dtype=vae.dtype)).latent_dist.sample() - pose_map = pose_map * vae.config.scaling_factor - - # Sample noise that we'll add to the latents - noise = torch.randn_like(model_input) - - bsz = model_input.shape[0] - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device - ) - # Add noise to the latents according to the noise magnitude at each timestep - noisy_latents = noise_scheduler.add_noise(model_input, noise, timesteps) - latent_model_input = torch.cat([noisy_latents, mask,masked_latents,pose_map], dim=1) - - - text_input_ids = tokenizer( - batch['caption'], - max_length=tokenizer.model_max_length, - padding="max_length", - truncation=True, - return_tensors="pt" - ).input_ids - text_input_ids_2 = tokenizer_2( - batch['caption'], - max_length=tokenizer_2.model_max_length, - padding="max_length", - truncation=True, - return_tensors="pt" - ).input_ids - - encoder_output = text_encoder(text_input_ids.to(accelerator.device), output_hidden_states=True) - text_embeds = encoder_output.hidden_states[-2] - encoder_output_2 = text_encoder_2(text_input_ids_2.to(accelerator.device), output_hidden_states=True) - pooled_text_embeds = encoder_output_2[0] - text_embeds_2 = encoder_output_2.hidden_states[-2] - encoder_hidden_states = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat - - - def compute_time_ids(original_size, crops_coords_top_left = (0,0)): - # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids - target_size = (args.height, args.height) - add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(accelerator.device) - return add_time_ids - - add_time_ids = torch.cat( - [compute_time_ids((args.height, args.height)) for i in range(bsz)] - ) - - img_emb_list = [] - for i in range(bsz): - img_emb_list.append(batch['cloth'][i]) - - image_embeds = torch.cat(img_emb_list,dim=0) - image_embeds = image_encoder(image_embeds, output_hidden_states=True).hidden_states[-2] - ip_tokens =image_proj_model(image_embeds) - - - - # add cond - unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids} - unet_added_cond_kwargs["image_embeds"] = ip_tokens - - cloth_values = batch["cloth_pure"].to(accelerator.device,dtype=vae.dtype) - cloth_values = vae.encode(cloth_values).latent_dist.sample() - cloth_values = cloth_values * vae.config.scaling_factor - - - text_input_ids = tokenizer( - batch['caption_cloth'], - max_length=tokenizer.model_max_length, - padding="max_length", - truncation=True, - return_tensors="pt" - ).input_ids - text_input_ids_2 = tokenizer_2( - 
batch['caption_cloth'], - max_length=tokenizer_2.model_max_length, - padding="max_length", - truncation=True, - return_tensors="pt" - ).input_ids - - - encoder_output = text_encoder(text_input_ids.to(accelerator.device), output_hidden_states=True) - text_embeds_cloth = encoder_output.hidden_states[-2] - encoder_output_2 = text_encoder_2(text_input_ids_2.to(accelerator.device), output_hidden_states=True) - text_embeds_2_cloth = encoder_output_2.hidden_states[-2] - text_embeds_cloth = torch.concat([text_embeds_cloth, text_embeds_2_cloth], dim=-1) # concat - - - down,reference_features = unet_encoder(cloth_values,timesteps, text_embeds_cloth,return_dict=False) - reference_features = list(reference_features) - - noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states,added_cond_kwargs=unet_added_cond_kwargs,garment_features=reference_features).sample - - - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(model_input, noise, timesteps) - elif noise_scheduler.config.prediction_type == "sample": - # We set the target to latents here, but the model_pred will return the noise sample prediction. - target = model_input - # We will have to subtract the noise residual from the prediction to get the target sample. - model_pred = model_pred - noise - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - - - if args.snr_gamma is None: - loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") - else: - # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. - # Since we predict the noise instead of x_0, the original formulation is slightly changed. - # This is discussed in Section 4.2 of the same paper. - snr = compute_snr(noise_scheduler, timesteps) - if noise_scheduler.config.prediction_type == "v_prediction": - # Velocity objective requires that we add one to SNR values before we divide by them. - snr = snr + 1 - mse_loss_weights = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) - - loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights - loss = loss.mean() - - avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() - train_loss += avg_loss.item() / args.gradient_accumulation_steps - - - # Backpropagate - accelerator.backward(loss) - - if accelerator.sync_gradients: - accelerator.clip_grad_norm_(params_to_opt, 1.0) - - optimizer.step() - optimizer.zero_grad() - # Load scheduler, tokenizer and models. 
- progress_bar.update(1) - global_step += 1 - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - accelerator.log({"train_loss": train_loss}, step=global_step) - train_loss = 0.0 - logs = {"step_loss": loss.detach().item()} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if global_step % args.checkpointing_epoch == 0: - if accelerator.is_main_process: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - unwrapped_unet = accelerator.unwrap_model( - unet, keep_fp32_wrapper=True - ) - pipeline = TryonPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=unwrapped_unet, - vae= vae, - scheduler=noise_scheduler, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - image_encoder=image_encoder, - unet_encoder=unet_encoder, - torch_dtype=torch.float16, - add_watermarker=False, - safety_checker=None, - ) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - pipeline.save_pretrained(save_path) - del pipeline - - -if __name__ == "__main__": - main() diff --git a/train_xl.sh b/train_xl.sh deleted file mode 100644 index 84b52ba4e5c6661fc4ded94de359e7657deee0ff..0000000000000000000000000000000000000000 --- a/train_xl.sh +++ /dev/null @@ -1 +0,0 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch train_xl.py --gradient_checkpointing --use_8bit_adam --output_dir=result --train_batch_size=6 --data_dir=/home/omnious/workspace/yisol/Dataset/VITON-HD/zalando \ No newline at end of file diff --git a/gradio_demo/utils_mask.py b/utils_mask.py similarity index 100% rename from gradio_demo/utils_mask.py rename to utils_mask.py diff --git a/vitonhd_test_tagged.json b/vitonhd_test_tagged.json deleted file mode 100644 index e68c99399bdea7911b500c901ca599804a11ebeb..0000000000000000000000000000000000000000 --- a/vitonhd_test_tagged.json +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ff196c077cf27ace24a1710fba53a2719606f60983f583aebcb61d0b3730a1dc -size 1803684 diff --git a/vitonhd_train_tagged.json b/vitonhd_train_tagged.json deleted file mode 100644 index 4d3ee968d6971eeae198f6fd960bfd4b99a3c92d..0000000000000000000000000000000000000000 --- a/vitonhd_train_tagged.json +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a277d878e33fa4247631ce8db5584739c7f8515183f8a476fd6e64a5f80c94f6 -size 24848224