release iChatApp
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- .gitattributes +5 -0
- .gitignore +27 -0
- README.md +161 -8
- README_CN.md +147 -0
- assets/arch1.png +3 -0
- assets/demo2.gif +3 -0
- assets/demo3.gif +3 -0
- assets/demo4.gif +3 -0
- assets/demo5.gif +3 -0
- assets/demo6.jpg +3 -0
- assets/gvlab_logo.png +3 -0
- assets/images/IMG3584.jpeg +3 -0
- assets/images/IMG3585.jpeg +3 -0
- assets/images/IMG3588.jpeg +3 -0
- assets/images/IMG3589.jpeg +3 -0
- assets/images/ultrakun.jpeg +3 -0
- assets/images/workspace.jpeg +3 -0
- assets/videos/iKun.mp4 +3 -0
- configs/big_lama_config.yaml +157 -0
- configs/med_config.json +21 -0
- configs/q2l_config.json +23 -0
- configs/swin/config_swinB_224.json +10 -0
- configs/swin/config_swinB_384.json +10 -0
- configs/swin/config_swinB_480.json +9 -0
- configs/swin/config_swinB_576.json +9 -0
- configs/swin/config_swinB_608.json +9 -0
- configs/tag2text_caption.yaml +33 -0
- iChat/__init__.py +1 -0
- iChat/chatbot/__init__.py +1 -0
- iChat/chatbot/chatbot.py +440 -0
- iChat/models/__init__.py +37 -0
- iChat/models/grit_model.py +46 -0
- iChat/models/grit_src/configs/Base.yaml +77 -0
- iChat/models/grit_src/configs/GRiT_B_DenseCap.yaml +20 -0
- iChat/models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml +23 -0
- iChat/models/grit_src/configs/GRiT_B_ObjectDet.yaml +20 -0
- iChat/models/grit_src/configs/GRiT_H_ObjectDet.yaml +21 -0
- iChat/models/grit_src/configs/GRiT_L_ObjectDet.yaml +20 -0
- iChat/models/grit_src/grit/__init__.py +7 -0
- iChat/models/grit_src/grit/config.py +50 -0
- iChat/models/grit_src/grit/custom_solver.py +88 -0
- iChat/models/grit_src/grit/data/custom_build_augmentation.py +44 -0
- iChat/models/grit_src/grit/data/custom_dataset_dataloader.py +250 -0
- iChat/models/grit_src/grit/data/custom_dataset_mapper.py +149 -0
- iChat/models/grit_src/grit/data/datasets/grit_coco.py +112 -0
- iChat/models/grit_src/grit/data/datasets/object365.py +111 -0
- iChat/models/grit_src/grit/data/datasets/vg.py +98 -0
- iChat/models/grit_src/grit/data/transforms/custom_augmentation_impl.py +52 -0
- iChat/models/grit_src/grit/data/transforms/custom_transform.py +115 -0
- iChat/models/grit_src/grit/evaluation/eval.py +156 -0
.gitattributes
CHANGED
@@ -32,3 +32,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text

.gitignore
ADDED
@@ -0,0 +1,27 @@
# compilation and distribution
__pycache__
_ext
*.pyc
*.pyd
*.so
*.dll
*.egg-info/
build/
dist/
wheels/

# Editor temporaries
*.swn
*.swo
*.swp
*~

# editor settings
.idea
.vscode
_darcs

# custom files
./image/
./tmp_files/

README.md
CHANGED
@@ -1,13 +1,166 @@
title: InternChat
emoji: 🤖💬
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: 3.28.1
app_file: iChatApp.py
pinned: false
license: apache-2.0

[[中文文档]](README_CN.md)

**The project is still under construction; we will continue to update it and welcome contributions/pull requests from the community.**

<p align="center"><img src="./assets/gvlab_logo.png" width="600"></p>

<a src="https://img.shields.io/discord/1099920215724277770?label=Discord&logo=discord" href="https://discord.gg/khWBFnCgAN">
<img src="https://img.shields.io/discord/1099920215724277770?label=Discord&logo=discord"> </a> | <a src="https://img.shields.io/badge/GPU%20Demo-Open-green?logo=alibabacloud" href="https://ichat.opengvlab.com">
<img src="https://img.shields.io/badge/Demo-Open-green?logo=alibabacloud"> </a> | <a src="https://img.shields.io/twitter/follow/opengvlab?style=social" href="https://twitter.com/opengvlab">
<img src="https://img.shields.io/twitter/follow/opengvlab?style=social"> </a>

# InternChat [[paper](https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/papers/ichat.pdf)]

<!-- ## Description -->
**InternChat** (short for **iChat**) is a pointing-language-driven visual interactive system that allows you to interact with ChatGPT by clicking, dragging, and drawing with a pointing device. The name InternChat stands for **inter**action, **n**onverbal, and **chat**bots. Unlike existing interactive systems that rely on pure language, iChat incorporates pointing instructions and thereby significantly improves the efficiency of communication between users and chatbots, as well as the accuracy of chatbots in vision-centric tasks, especially in complicated visual scenarios. Additionally, iChat uses an auxiliary control mechanism to improve the controllability of the LLM, and a large vision-language model named **Husky** is fine-tuned for high-quality multi-modal dialogue (reaching **93.89% GPT-4 quality** in a ChatGPT-3.5-turbo evaluation).

## Online Demo
[**InternChat**](https://ichat.opengvlab.com/) is online. Give it a try!

[**NOTE**] You may have to wait in a lengthy queue. You can clone our repo and run it with your own GPU.

https://github.com/OpenGVLab/InternChat/assets/13723743/3270b05f-0823-4f13-9966-4010fd855643

## Schedule
- [ ] Support Chinese
- [ ] Support MOSS
- [ ] More powerful foundation models based on [InternImage](https://github.com/OpenGVLab/InternImage) and [InternVideo](https://github.com/OpenGVLab/InternVideo)
- [ ] More accurate interactive experience
- [ ] Web page & code generation
- [x] Support voice assistant
- [x] Support click interaction
- [x] Interactive image editing
- [x] Interactive image generation
- [x] Interactive visual question answering
- [x] Segment Anything
- [x] Image inpainting
- [x] Image captioning
- [x] Image matting
- [x] Optical character recognition
- [x] Action recognition
- [x] Video captioning
- [x] Video dense captioning
- [x] Video highlight interpretation

## System Overview
<p align="center"><img src="./assets/arch1.png" alt="Logo"></p>

## 🎁 Major Features
<!-- <p align="center"><img src="./assets/online_demo.gif" alt="Logo"></p> -->
<p align="center">(a) Remove the masked object</p>
<p align="center"><img src="./assets/demo2.gif" width="500"></p>

<p align="center">(b) Interactive image editing</p>
<p align="center"><img src="./assets/demo3.gif" width="500"></p>

<p align="center">(c) Image generation</p>
<p align="center"><img src="./assets/demo4.gif" align='justify' width="500"></p>

<p align="center">(d) Interactive visual question answering</p>
<p align="center"><img src="./assets/demo5.gif" align='justify' width="700"></p>

<p align="center">(e) Interactive image generation</p>
<p align="center"><img width="800" alt="image" src="https://github.com/OpenGVLab/InternChat/assets/8529570/2b0da08e-af86-453d-99e5-1327f93aa917"></p>

<p align="center">(f) Video highlight interpretation</p>
<p align="center"><img src="./assets/demo6.jpg" align='justify' width="500"></p>

<!-- ![alt]("./assets/demo5.gif" "title") -->

## 🛠️ Installation

### Basic requirements

- Linux
- Python 3.8+
- PyTorch 1.12+
- CUDA 11.6+
- GCC & G++ 5.4+
- GPU memory >= 17 GB for loading the basic tools (HuskyVQA, SegmentAnything, ImageOCRRecognition)

### Install Python dependencies
```shell
pip install -r requirements.txt
```

### Model zoo
Coming soon...

## 👨‍🏫 Get Started
Running the following command starts a Gradio service:
```shell
python -u iChatApp.py --load "HuskyVQA_cuda:0,SegmentAnything_cuda:0,ImageOCRRecognition_cuda:0" --port 3456
```

If you want to enable the voice assistant, use `openssl` to generate a certificate:
```shell
openssl req -x509 -newkey rsa:4096 -keyout ./key.pem -out ./cert.pem -sha256 -days 365 -nodes
```

and then run:
```shell
python -u iChatApp.py --load "HuskyVQA_cuda:0,SegmentAnything_cuda:0,ImageOCRRecognition_cuda:0" --port 3456 --https
```

## 🎫 License

This project is released under the [Apache 2.0 license](LICENSE).

## 🖊️ Citation

If you find this project useful in your research, please consider citing:
```BibTeX
@misc{2023internchat,
    title={InternChat: Solving Vision-Centric Tasks by Interacting with Chatbots Beyond Language},
    author={Zhaoyang Liu and Yinan He and Wenhai Wang and Weiyun Wang and Yi Wang and Shoufa Chen and Qinglong Zhang and Yang Yang and Qingyun Li and Jiashuo Yu and Kunchang Li and Zhe Chen and Xue Yang and Xizhou Zhu and Yali Wang and Limin Wang and Ping Luo and Jifeng Dai and Yu Qiao},
    howpublished = {\url{https://arxiv.org/abs/2305.05662}},
    year={2023}
}
```

## 🤝 Acknowledgement
Thanks to the following open-source projects:

[Hugging Face](https://github.com/huggingface)
[LangChain](https://github.com/hwchase17/langchain)
[TaskMatrix](https://github.com/microsoft/TaskMatrix)
[SAM](https://github.com/facebookresearch/segment-anything)
[Stable Diffusion](https://github.com/CompVis/stable-diffusion)
[ControlNet](https://github.com/lllyasviel/ControlNet)
[InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix)
[BLIP](https://github.com/salesforce/BLIP)
[Latent Diffusion Models](https://github.com/CompVis/latent-diffusion)
[EasyOCR](https://github.com/JaidedAI/EasyOCR)

You are welcome to discuss with us and help continuously improve the user experience of InternChat.

WeChat QR code:

<p align="center"><img width="500" alt="image" src="https://github.com/OpenGVLab/InternChat/assets/8529570/881c231d-9049-4920-a22c-680f41f0f7ee"></p>

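For reference, the `--load` argument maps tool class names to devices. The sketch below shows one way such a string could be turned into the `load_dict` that `iChat.chatbot.ConversationBot` expects; the helper name `parse_load_arg` is hypothetical, since the actual argument parsing in `iChatApp.py` is not part of this diff.

```python
# Hypothetical sketch: split "HuskyVQA_cuda:0,SegmentAnything_cuda:0" into
# {"HuskyVQA": "cuda:0", "SegmentAnything": "cuda:0"}.
def parse_load_arg(load: str) -> dict:
    load_dict = {}
    for item in load.split(","):
        # Everything before the first underscore is the class name,
        # everything after it is the device string.
        class_name, _, device = item.strip().partition("_")
        load_dict[class_name] = device
    return load_dict

# The resulting dict would then be passed to ConversationBot(load_dict).
```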
README_CN.md
ADDED
@@ -0,0 +1,147 @@
[[English Document](README.md)]

**[NOTE] The project is still under construction; we will keep updating it and welcome contributions/pull requests from the community.**

<p align="center"><img src="./assets/gvlab_logo.png" width="600"></p>

<a src="https://img.shields.io/discord/1099920215724277770?label=Discord&logo=discord" href="https://discord.gg/khWBFnCgAN">
<img src="https://img.shields.io/discord/1099920215724277770?label=Discord&logo=discord"> </a> | <a src="https://img.shields.io/badge/GPU Demo-Open-green?logo=alibabacloud" href="https://ichat.opengvlab.com">
<img src="https://img.shields.io/badge/Demo-Open-green?logo=alibabacloud"> </a> | <a src="https://img.shields.io/twitter/follow/opengvlab?style=social" href="https://twitter.com/opengvlab">
<img src="https://img.shields.io/twitter/follow/opengvlab?style=social">

# InternChat [[Paper](https://arxiv.org/pdf/2305.05662.pdf)]
<!-- ## Description -->
**InternChat** (short for **iChat**) is a pointing-language-driven visual interactive system that allows you to interact with ChatGPT by clicking, dragging, and drawing with a pointing device. The name InternChat stands for **inter**action, **n**onverbal, and **chat**bots. Unlike existing interactive systems that rely on pure language, iChat incorporates pointing instructions and thereby significantly improves the efficiency of communication between users and chatbots, as well as the accuracy of chatbots in vision-centric tasks, especially in complicated visual scenarios. In addition, iChat uses an auxiliary control mechanism to improve the controllability of the LLM, and a large vision-language model named **Husky** is fine-tuned for high-quality multi-modal dialogue (reaching **93.89% GPT-4 quality** in a ChatGPT-3.5-turbo evaluation).

## Online Demo

[NOTE] You may have to wait in a lengthy queue. You can clone our repo and run it with your own GPU.

[**InternChat** is online. Give it a try!](https://ichat.opengvlab.com)

https://github.com/OpenGVLab/InternChat/assets/13723743/3270b05f-0823-4f13-9966-4010fd855643

## Schedule
- [ ] Support Chinese
- [ ] Support MOSS
- [ ] More powerful foundation models based on InternImage and InternVideo
- [ ] More accurate interactive experience
- [ ] Web page & code generation
- [x] Support voice assistant
- [x] Support click interaction
- [x] Interactive image editing
- [x] Interactive image generation
- [x] Interactive visual question answering
- [x] Segment Anything
- [x] Image inpainting
- [x] Image captioning
- [x] Image matting
- [x] Optical character recognition (OCR)
- [x] Action recognition
- [x] Video captioning
- [x] Video dense captioning
- [x] Video highlight clipping

## System Overview
<p align="center"><img src="./assets/arch1.png" alt="Logo"></p>

## 🎁 Major Features
<!-- <p align="center"><img src="./assets/online_demo.gif" alt="Logo"></p> -->

<p align="center">(a) Remove the masked object</p>
<p align="center"><img src="./assets/demo2.gif" width="500"></p>

<p align="center">(b) Interactive image editing</p>
<p align="center"><img src="./assets/demo3.gif" width="500"></p>

<p align="center">(c) Image generation</p>
<p align="center"><img src="./assets/demo4.gif" align='justify' width="500"></p>

<p align="center">(d) Interactive visual question answering</p>
<p align="center"><img src="./assets/demo5.gif" align='justify' width="700"></p>

<p align="center">(e) Interactive image generation</p>
<p align="center"><img width="800" alt="image" src="https://github.com/OpenGVLab/InternChat/assets/8529570/2b0da08e-af86-453d-99e5-1327f93aa917"></p>

<p align="center">(f) Video highlight interpretation</p>
<p align="center"><img src="./assets/demo6.jpg" align='justify' width="500"></p>

<!-- ![alt]("./assets/demo5.gif" "title") -->

## 🛠️ Installation

### Basic requirements

- Linux
- Python 3.8+
- PyTorch 1.12+
- CUDA 11.6+
- GCC & G++ 5.4+
- GPU memory >= 17 GB for loading the basic tools (HuskyVQA, SegmentAnything, ImageOCRRecognition)

### Install Python dependencies
```shell
pip install -r requirements.txt
```

### Model zoo

Coming soon...

## 👨‍🏫 Get Started

Running the following command starts a Gradio service:

```shell
python -u iChatApp.py --load "HuskyVQA_cuda:0,SegmentAnything_cuda:0,ImageOCRRecognition_cuda:0" --port 3456
```

If you want to enable the voice assistant, use openssl to generate a certificate:

```shell
openssl req -x509 -newkey rsa:4096 -keyout ./key.pem -out ./cert.pem -sha256 -days 365 -nodes
```
Then run:

```shell
python -u iChatApp.py --load "HuskyVQA_cuda:0,SegmentAnything_cuda:0,ImageOCRRecognition_cuda:0" --port 3456 --https
```

## 🎫 License

This project is released under the [Apache 2.0 license](LICENSE).

## 🖊️ Citation

If you find this project useful in your research, please consider citing:
```BibTeX
@misc{2023internchat,
    title={InternChat: Solving Vision-Centric Tasks by Interacting with Chatbots Beyond Language},
    author={Zhaoyang Liu and Yinan He and Wenhai Wang and Weiyun Wang and Yi Wang and Shoufa Chen and Qinglong Zhang and Yang Yang and Qingyun Li and Jiashuo Yu and Kunchang Li and Zhe Chen and Xue Yang and Xizhou Zhu and Yali Wang and Limin Wang and Ping Luo and Jifeng Dai and Yu Qiao},
    howpublished = {\url{https://arxiv.org/abs/2305.05662}},
    year={2023}
}
```

## 🤝 Acknowledgement
Thanks to the following open-source projects:

[Hugging Face](https://github.com/huggingface)
[LangChain](https://github.com/hwchase17/langchain)
[TaskMatrix](https://github.com/microsoft/TaskMatrix)
[SAM](https://github.com/facebookresearch/segment-anything)
[Stable Diffusion](https://github.com/CompVis/stable-diffusion)
[ControlNet](https://github.com/lllyasviel/ControlNet)
[InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix)
[BLIP](https://github.com/salesforce/BLIP)
[Latent Diffusion Models](https://github.com/CompVis/latent-diffusion)
[EasyOCR](https://github.com/JaidedAI/EasyOCR)

If you have any problems trying, running, or deploying the project, or any ideas and suggestions for it, you are welcome to join our WeChat group for discussion!

<p align="center"><img width="500" alt="image" src="https://github.com/OpenGVLab/InternChat/assets/8529570/881c231d-9049-4920-a22c-680f41f0f7ee"></p>

assets/arch1.png
ADDED
Git LFS Details

assets/demo2.gif
ADDED
Git LFS Details

assets/demo3.gif
ADDED
Git LFS Details

assets/demo4.gif
ADDED
Git LFS Details

assets/demo5.gif
ADDED
Git LFS Details

assets/demo6.jpg
ADDED
Git LFS Details

assets/gvlab_logo.png
ADDED
Git LFS Details

assets/images/IMG3584.jpeg
ADDED
Git LFS Details

assets/images/IMG3585.jpeg
ADDED
Git LFS Details

assets/images/IMG3588.jpeg
ADDED
Git LFS Details

assets/images/IMG3589.jpeg
ADDED
Git LFS Details

assets/images/ultrakun.jpeg
ADDED
Git LFS Details

assets/images/workspace.jpeg
ADDED
Git LFS Details

assets/videos/iKun.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52152a27a2c42c4dec12575c50dd6dced8db89ba041f5f90596aeb228b213e0a
size 1055241

configs/big_lama_config.yaml
ADDED
@@ -0,0 +1,157 @@
run_title: b18_ffc075_batch8x15
training_model:
  kind: default
  visualize_each_iters: 1000
  concat_mask: true
  store_discr_outputs_for_vis: true
losses:
  l1:
    weight_missing: 0
    weight_known: 10
  perceptual:
    weight: 0
  adversarial:
    kind: r1
    weight: 10
    gp_coef: 0.001
    mask_as_fake_target: true
    allow_scale_mask: true
  feature_matching:
    weight: 100
  resnet_pl:
    weight: 30
    weights_path: ${env:TORCH_HOME}

optimizers:
  generator:
    kind: adam
    lr: 0.001
  discriminator:
    kind: adam
    lr: 0.0001
visualizer:
  key_order:
  - image
  - predicted_image
  - discr_output_fake
  - discr_output_real
  - inpainted
  rescale_keys:
  - discr_output_fake
  - discr_output_real
  kind: directory
  outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
location:
  data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
  out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
  tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
data:
  batch_size: 15
  val_batch_size: 2
  num_workers: 3
  train:
    indir: ${location.data_root_dir}/train
    out_size: 256
    mask_gen_kwargs:
      irregular_proba: 1
      irregular_kwargs:
        max_angle: 4
        max_len: 200
        max_width: 100
        max_times: 5
        min_times: 1
      box_proba: 1
      box_kwargs:
        margin: 10
        bbox_min_size: 30
        bbox_max_size: 150
        max_times: 3
        min_times: 1
      segm_proba: 0
      segm_kwargs:
        confidence_threshold: 0.5
        max_object_area: 0.5
        min_mask_area: 0.07
        downsample_levels: 6
        num_variants_per_mask: 1
        rigidness_mode: 1
        max_foreground_coverage: 0.3
        max_foreground_intersection: 0.7
        max_mask_intersection: 0.1
        max_hidden_area: 0.1
        max_scale_change: 0.25
        horizontal_flip: true
        max_vertical_shift: 0.2
        position_shuffle: true
    transform_variant: distortions
    dataloader_kwargs:
      batch_size: ${data.batch_size}
      shuffle: true
      num_workers: ${data.num_workers}
  val:
    indir: ${location.data_root_dir}/val
    img_suffix: .png
    dataloader_kwargs:
      batch_size: ${data.val_batch_size}
      shuffle: false
      num_workers: ${data.num_workers}
  visual_test:
    indir: ${location.data_root_dir}/korean_test
    img_suffix: _input.png
    pad_out_to_modulo: 32
    dataloader_kwargs:
      batch_size: 1
      shuffle: false
      num_workers: ${data.num_workers}
generator:
  kind: ffc_resnet
  input_nc: 4
  output_nc: 3
  ngf: 64
  n_downsampling: 3
  n_blocks: 18
  add_out_act: sigmoid
  init_conv_kwargs:
    ratio_gin: 0
    ratio_gout: 0
    enable_lfu: false
  downsample_conv_kwargs:
    ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
    ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
    enable_lfu: false
  resnet_conv_kwargs:
    ratio_gin: 0.75
    ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
    enable_lfu: false
discriminator:
  kind: pix2pixhd_nlayer
  input_nc: 3
  ndf: 64
  n_layers: 4
evaluator:
  kind: default
  inpainted_key: inpainted
  integral_kind: ssim_fid100_f1
trainer:
  kwargs:
    gpus: -1
    accelerator: ddp
    max_epochs: 200
    gradient_clip_val: 1
    log_gpu_memory: None
    limit_train_batches: 25000
    val_check_interval: ${trainer.kwargs.limit_train_batches}
    log_every_n_steps: 1000
    precision: 32
    terminate_on_nan: false
    check_val_every_n_epoch: 1
    num_sanity_val_steps: 8
    limit_val_batches: 1000
    replace_sampler_ddp: false
  checkpoint_kwargs:
    verbose: true
    save_top_k: 5
    save_last: true
    period: 1
    monitor: val_ssim_fid100_f1_total_mean
    mode: max

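The `${...}` entries in `configs/big_lama_config.yaml` above are OmegaConf/Hydra-style interpolations. As a sketch of how the file could be inspected (assuming `omegaconf` is installed; the `${env:TORCH_HOME}` entry additionally needs an environment-variable resolver registered in recent OmegaConf versions), one could do:

```python
from omegaconf import OmegaConf

# Load the LaMa training config; interpolations such as ${data.batch_size}
# are resolved lazily when the corresponding keys are accessed.
cfg = OmegaConf.load("configs/big_lama_config.yaml")
print(cfg.generator.kind)                           # ffc_resnet
print(cfg.data.train.dataloader_kwargs.batch_size)  # resolves to cfg.data.batch_size == 15
```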
configs/med_config.json
ADDED
@@ -0,0 +1,21 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30524,
  "encoder_width": 768,
  "add_cross_attention": true
}

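`configs/med_config.json` above follows the Hugging Face BERT config format with BLIP-style extras (`encoder_width`, `add_cross_attention`). A quick way to inspect it, assuming `transformers` is installed (the model code that actually consumes this file is not part of this diff):

```python
from transformers import BertConfig

# Extra keys such as encoder_width are kept as plain attributes on the config object.
cfg = BertConfig.from_json_file("configs/med_config.json")
print(cfg.hidden_size, cfg.num_hidden_layers, cfg.encoder_width)  # 768 12 768
```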
configs/q2l_config.json
ADDED
@@ -0,0 +1,23 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 4,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "encoder_width": 768,
  "add_cross_attention": true,
  "add_tag_cross_attention": false
}

configs/swin/config_swinB_224.json
ADDED
@@ -0,0 +1,10 @@
{
  "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
  "vision_width": 1024,
  "image_res": 224,
  "window_size": 7,
  "embed_dim": 128,
  "depths": [ 2, 2, 18, 2 ],
  "num_heads": [ 4, 8, 16, 32 ]
}

configs/swin/config_swinB_384.json
ADDED
@@ -0,0 +1,10 @@
{
  "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
  "vision_width": 1024,
  "image_res": 384,
  "window_size": 12,
  "embed_dim": 128,
  "depths": [ 2, 2, 18, 2 ],
  "num_heads": [ 4, 8, 16, 32 ]
}

configs/swin/config_swinB_480.json
ADDED
@@ -0,0 +1,9 @@
{
  "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
  "vision_width": 1024,
  "image_res": 480,
  "window_size": 15,
  "embed_dim": 128,
  "depths": [ 2, 2, 18, 2 ],
  "num_heads": [ 4, 8, 16, 32 ]
}

configs/swin/config_swinB_576.json
ADDED
@@ -0,0 +1,9 @@
{
  "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
  "vision_width": 1024,
  "image_res": 576,
  "window_size": 18,
  "embed_dim": 128,
  "depths": [ 2, 2, 18, 2 ],
  "num_heads": [ 4, 8, 16, 32 ]
}

configs/swin/config_swinB_608.json
ADDED
@@ -0,0 +1,9 @@
{
  "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
  "vision_width": 1024,
  "image_res": 608,
  "window_size": 19,
  "embed_dim": 128,
  "depths": [ 2, 2, 18, 2 ],
  "num_heads": [ 4, 8, 16, 32 ]
}

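The five Swin-B configs above differ only in `image_res` and `window_size`, and the window size simply tracks the resolution (`window_size == image_res // 32`: 7, 12, 15, 18, 19 for 224 through 608). A small sanity check, assuming the files are read from `configs/swin/`:

```python
import glob
import json

# 224 -> 7, 384 -> 12, 480 -> 15, 576 -> 18, 608 -> 19
for path in sorted(glob.glob("configs/swin/config_swinB_*.json")):
    with open(path) as f:
        cfg = json.load(f)
    assert cfg["window_size"] == cfg["image_res"] // 32
    print(path, cfg["image_res"], cfg["window_size"])
```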
configs/tag2text_caption.yaml
ADDED
@@ -0,0 +1,33 @@
image_root: '/home/notebook/data/group/projects/tagging/caption/datasets/public/coco/'

ann_root: 'dataset/caption_dataset'
coco_gt_root: 'dataset/caption_dataset'

pretrained: '/home/notebook/code/personal/S9049611/BLIP/output/pretrain_caption_tagtotext_v2_bert_asl'

# size of vit model; base or large
vit: 'swin_b'
vit_grad_ckpt: False
vit_ckpt_layer: 0

batch_size: 35
init_lr: 5e-6

image_size: 384

# generation configs
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 10

text_pretrain: 'bert'

class_num: 3429
threshold: 0.7

iChat/__init__.py
ADDED
@@ -0,0 +1 @@
from .models import *

iChat/chatbot/__init__.py
ADDED
@@ -0,0 +1 @@
from .chatbot import ConversationBot

iChat/chatbot/chatbot.py
ADDED
@@ -0,0 +1,440 @@
import inspect
import re
import os
import numpy as np
import uuid
import shutil
import whisper
import gradio as gr

from PIL import Image

from langchain.agents.initialize import initialize_agent
from langchain.agents.tools import Tool
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms.openai import OpenAI

from ..models import *
from iGPT.models.utils import gen_new_name

GLOBAL_SEED = 1912


'''
INTERN_CHAT_PREFIX = """InternChat is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. InternChat is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.

InternChat is able to process and understand large amounts of text and images. As a language model, InternChat can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and InternChat can invoke different tools to indirectly understand pictures. When talking about images, InternChat is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, InternChat is also known that the image may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real image. InternChat is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated.

Human may provide new figures to InternChat with a description. The description helps InternChat to understand this image, but InternChat should use tools to finish following tasks, rather than directly imagine from the description.

Overall, InternChat is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.


TOOLS:
------

InternChat has access to the following tools:"""

INTERN_CHAT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:

```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
```

When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:

```
Thought: Do I need to use a tool? No
{ai_prefix}: [your response here]
```
"""

INTERN_CHAT_SUFFIX = """You are very strict to the filename correctness and will never fake a file name if it does not exist.
You will remember to provide the image file name loyally if it's provided in the last tool observation.

Begin!

Previous conversation history:
{chat_history}

New input: {input}
Since InternChat is a text language model, InternChat must use tools to observe images rather than imagination.
The thoughts and observations are only visible for InternChat, InternChat should remember to repeat important information in the final response for Human.
Thought: Do I need to use a tool? {agent_scratchpad} Let's think step by step.
"""

INTERN_CHAT_PREFIX_CN = """InternChat 旨在能够协助完成范围广泛的文本和视觉相关任务,从回答简单的问题到提供对广泛主题的深入解释和讨论。 InternChat 能够根据收到的输入生成类似人类的文本,使其能够进行听起来自然的对话,并提供连贯且与手头主题相关的响应。

InternChat 能够处理和理解大量文本和图像。作为一种语言模型,InternChat 不能直接读取图像,但它有一系列工具来完成不同的视觉任务。每张图片都会有一个文件名,格式为“image/xxx.png”,InternChat可以调用不同的工具来间接理解图片。在谈论图片时,InternChat 对文件名的要求非常严格,绝不会伪造不存在的文件。在使用工具生成新的图像文件时,InternChat也知道图像可能与用户需求不一样,会使用其他视觉问答工具或描述工具来观察真实图像。 InternChat 能够按顺序使用工具,并且忠于工具观察输出,而不是伪造图像内容和图像文件名。如果生成新图像,它将记得提供上次工具观察的文件名。

Human 可能会向 InternChat 提供带有描述的新图形。描述帮助 InternChat 理解这个图像,但 InternChat 应该使用工具来完成以下任务,而不是直接从描述中想象。有些工具将会返回英文描述,但你对用户的聊天应当采用中文。

总的来说,InternChat 是一个强大的可视化对话辅助工具,可以帮助处理范围广泛的任务,并提供关于范围广泛的主题的有价值的见解和信息。

工具列表:
------

InternChat 可以使用这些工具:"""

INTERN_CHAT_FORMAT_INSTRUCTIONS_CN = """用户使用中文和你进行聊天,但是工具的参数应当使用英文。如果要调用工具,你必须遵循如下格式:

```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
```

当你不再需要继续调用工具,而是对观察结果进行总结回复时,你必须使用如下格式:


```
Thought: Do I need to use a tool? No
{ai_prefix}: [your response here]
```
"""

INTERN_CHAT_SUFFIX_CN = """你对文件名的正确性非常严格,而且永远不会伪造不存在的文件。

开始!

因为InternChat是一个文本语言模型,必须使用工具去观察图片而不是依靠想象。
推理想法和观察结果只对InternChat可见,需要记得在最终回复时把重要的信息重复给用户,你只能给用户返回中文句子。我们一步一步思考。在你使用工具时,工具的参数只能是英文。

聊天历史:
{chat_history}

新输入: {input}
Thought: Do I need to use a tool? {agent_scratchpad}
"""
'''


VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.

Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, Visual ChatGPT is also known that the image may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real image. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated.

Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.

Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.


TOOLS:
------

Visual ChatGPT has access to the following tools:"""

VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:

```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
```

When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:

```
Thought: Do I need to use a tool? No
{ai_prefix}: [your response here]
```
"""

VISUAL_CHATGPT_SUFFIX = """You are very strict to the filename correctness and will never fake a file name if it does not exist.
You will remember to provide the image file name loyally if it's provided in the last tool observation.

Begin!

Previous conversation history:
{chat_history}

New input: {input}
Since Visual ChatGPT is a text language model, Visual ChatGPT must use tools to observe images rather than imagination.
The thoughts and observations are only visible for Visual ChatGPT, Visual ChatGPT should remember to repeat important information in the final response for Human.
Thought: Do I need to use a tool? {agent_scratchpad} Let's think step by step.
"""

VISUAL_CHATGPT_PREFIX_CN = """Visual ChatGPT 旨在能够协助完成范围广泛的文本和视觉相关任务,从回答简单的问题到提供对广泛主题的深入解释和讨论。 Visual ChatGPT 能够根据收到的输入生成类似人类的文本,使其能够进行听起来自然的对话,并提供连贯且与手头主题相关的响应。

Visual ChatGPT 能够处理和理解大量文本和图像。作为一种语言模型,Visual ChatGPT 不能直接读取图像,但它有一系列工具来完成不同的视觉任务。每张图片都会有一个文件名,格式为“image/xxx.png”,Visual ChatGPT可以调用不同的工具来间接理解图片。在谈论图片时,Visual ChatGPT 对文件名的要求非常严格,绝不会伪造不存在的文件。在使用工具生成新的图像文件时,Visual ChatGPT也知道图像可能与用户需求不一样,会使用其他视觉问答工具或描述工具来观察真实图像。 Visual ChatGPT 能够按顺序使用工具,并且忠于工具观察输出,而不是伪造图像内容和图像文件名。如果生成新图像,它将记得提供上次工具观察的文件名。

Human 可能会向 Visual ChatGPT 提供带有描述的新图形。描述帮助 Visual ChatGPT 理解这个图像,但 Visual ChatGPT 应该使用工具来完成以下任务,而不是直接从描述中想象。有些工具将会返回英文描述,但你对用户的聊天应当采用中文。

总的来说,Visual ChatGPT 是一个强大的可视化对话辅助工具,可以帮助处理范围广泛的任务,并提供关于范围广泛的主题的有价值的见解和信息。

工具列表:
------

Visual ChatGPT 可以使用这些工具:"""

VISUAL_CHATGPT_FORMAT_INSTRUCTIONS_CN = """用户使用中文和你进行聊天,但是工具的参数应当使用英文。如果要调用工具,你必须遵循如下格式:

```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
```

当你不再需要继续调用工具,而是对观察结果进行总结回复时,你必须使用如下格式:


```
Thought: Do I need to use a tool? No
{ai_prefix}: [your response here]
```
"""

VISUAL_CHATGPT_SUFFIX_CN = """你对文件名的正确性非常严格,而且永远不会伪造不存在的文件。

开始!

因为Visual ChatGPT是一个文本语言模型,必须使用工具去观察图片而不是依靠想象。
推理想法和观察结果只对Visual ChatGPT可见,需要记得在最终回复时把重要的信息重复给用户,你只能给用户返回中文句子。我们一步一步思考。在你使用工具时,工具的参数只能是英文。

聊天历史:
{chat_history}

新输入: {input}
Thought: Do I need to use a tool? {agent_scratchpad}
"""


def cut_dialogue_history(history_memory, keep_last_n_words=500):
    if history_memory is None or len(history_memory) == 0:
        return history_memory
    tokens = history_memory.split()
    n_tokens = len(tokens)
    print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
    if n_tokens < keep_last_n_words:
        return history_memory
    paragraphs = history_memory.split('\n')
    last_n_tokens = n_tokens
    while last_n_tokens >= keep_last_n_words:
        last_n_tokens -= len(paragraphs[0].split(' '))
        paragraphs = paragraphs[1:]
    return '\n' + '\n'.join(paragraphs)


class ConversationBot:
    def __init__(self, load_dict):
        # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
        print(f"Initializing VisualChatGPT, load_dict={load_dict}")
        if 'ImageCaptioning' not in load_dict:
            raise ValueError("You have to load ImageCaptioning as a basic function for i-GPT")
        # if 'SegmentAnything' not in load_dict:
        #     raise ValueError("You have to load SegmentAnything as a basic function for i-GPT")

        self.models = {}
        self.uploaded_image_filename = None
        # self.segmented_image_filename = None
        self.history_mask = None
        self.load_dict = load_dict
        # self.llm = None
        # Load Basic Foundation Models
        for class_name, device in load_dict.items():
            self.models[class_name] = globals()[class_name](device=device)
        # self.models['models'] = self.models

        # Load Template Foundation Models
        for class_name, module in globals().items():
            if getattr(module, 'template_model', False):
                template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k != 'self'}
                loaded_names = set([type(e).__name__ for e in self.models.values()])
                if template_required_names.issubset(loaded_names):
                    self.models[class_name] = globals()[class_name](
                        **{name: self.models[name] for name in template_required_names})
                # elif 'models' in template_required_names:
                #     self.models[class_name] = globals()[class_name](
                #         **{name: self.models[name] for name in template_required_names})

        self.tools = []
        for instance in self.models.values():
            for e in dir(instance):
                if e.startswith('inference'):
                    func = getattr(instance, e)
                    self.tools.append(Tool(name=func.name, description=func.description, func=func))
        self.llm = None
        self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
        # self.first_init=True
        self.audio_model = None

    def init_agent(self):
        self.memory.clear()  # clear previous history
        self.reset()
        self.llm = OpenAI(temperature=0)
        self.agent = initialize_agent(
            self.tools,
            self.llm,
            agent="conversational-react-description",
            verbose=True,
            memory=self.memory,
            return_intermediate_steps=True,
            agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
                          'suffix': VISUAL_CHATGPT_SUFFIX}, )

    def run_text(self, text, state):
        # print(f'text = {text}')
        self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
        try:
            print(f'text = {text}')
            res = self.agent({"input": text.strip()})
            print('ab' * 30)
            print(res['output'])
            print('cd' * 30)
        except Exception as err:
            # Human_prompt = text
            # self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + ' AI: ' + AI_prompt
            state += [(text, 'I can not understand your instruction. Could you provide more information?')]
            print(err)
            return state, state

        res['output'] = res['output'].replace("\\", "/")
        # response = re.sub('(tmp_files/[-\w]*.[png|mp4])', lambda m: f'![](file={m.group(0)})*{m.group(0)}*', res['output'])

        # print("res['output'] = ", res['output'])
        # response = re.sub('(tmp_files/[-\w]*.(png|mp4))', replace_path, res['output'])
        pattern = re.compile('(image/[-\\w]*.(png|mp4))')
        out_filenames = pattern.findall(res['output'])
        response = res['output']
        state = state + [(text, response)]
        for f in out_filenames:
            state = state + [(None, f'{f[0]} is as following: ')]
            state = state + [(None, (f[0], ))]
        # if len(out_filenames) > 1:
        #     state = state + [(None, (out_filenames[-1][0], ))]
        #     print('out_filename[-1][0] = ', out_filenames[-1][0])
        print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return state, state

    def run_audio(self, audio_path, state):
        print(f'audio_path = {audio_path}')
        if self.audio_model is None:
            self.audio_model = whisper.load_model("small").to('cuda:0')
        text = self.audio_model.transcribe(audio_path)["text"]
        res = self.run_text(text, state)
        print(f"\nProcessed run_audio, Input transcribed audio: {text}\nCurrent state: {state}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return res[0], res[1]

    def upload_image(self, image, state, txt):
        self.reset()
        img = image['image']
        image_filename = os.path.join('image/', f"{str(uuid.uuid4())[:6]}.png")
        self.uploaded_image_filename = image_filename
        img = img.convert('RGB')
        img.save(image_filename, "PNG")
        # print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
        # let some foundation models preprocess image
        # NEED_PREPROCESSING_LIST = ["SegmentAnything", "ImageOCRRecognition"]
        # for model_name in NEED_PREPROCESSING_LIST:
        #     if model_name in self.models.keys():
        #         self.models[model_name].preprocess(np.array(img), image_filename)

        description = self.models['ImageCaptioning'].inference(image_filename)
        # description = 'Debug'

        Human_prompt = f'\nHuman: provide a figure named {image_filename}. The description is: {description}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
        AI_prompt = "Received. "
        self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + ' AI: ' + AI_prompt
        state = state + [(f"![](file={image_filename})*{image_filename}*", AI_prompt)]
        print(f"\nProcessed upload_image, Input image: {image_filename}\nCurrent state: {state}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return state, state, f'{txt} {image_filename} ', gr.update(visible=True), gr.update(visible=True)

    def upload_video(self, video_path, state, txt):
        # self.cur_file = video_path
        vid_name = os.path.basename(video_path)
        # vid_name = gen_new_name(vid_name, '', vid_name.split('.')[-1])
        new_video_path = os.path.join('./image/', f"{str(uuid.uuid4())[:6]}.mp4")
        new_video_path = gen_new_name(new_video_path, '', vid_name.split('.')[-1])
        shutil.copy(video_path, new_video_path)

        if "VideoCaption" in self.models.keys():
            description = self.models['VideoCaption'].inference(new_video_path)
        else:
            description = 'A video.'
        Human_prompt = f'\nHuman: provide a video named {new_video_path}. The description is: {description}. This information helps you to understand this video, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
        AI_prompt = f"Received video: {new_video_path} "
        self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
        # state = state + [(f"![](file={new_video_path})*{new_video_path}*", AI_prompt)]
        # state = state + [(f"![](file={video_path})*{new_video_path}*", AI_prompt)]
        state = state + [((new_video_path, ), AI_prompt)]
        # print('exists = ', os.path.exists("./tmp_files/1e7f_f4236666_tmp.mp4"))
        print(f"\nProcessed upload_video, Input video: {new_video_path}\nCurrent state: {state}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return state, state, f'{txt} {new_video_path} '

    def blend_mask(self, img, mask):
        mask = mask.astype(np.uint8)
        transparency_ratio = mask.astype(np.float32) / 3
        transparency_ratio = transparency_ratio[:, :, np.newaxis]
        mask = mask[:, :, np.newaxis] * 255
        mask = mask.repeat(3, axis=2)
        mask[:, :, 0] = 0
        mask[:, :, 2] = 0
        new_img_arr = np.array(img) * (1 - transparency_ratio) + mask * transparency_ratio
        new_img_arr = np.clip(new_img_arr, 0, 255).astype(np.uint8)
        # print(new_img_arr.shape)
        return Image.fromarray(new_img_arr)

    def process_image(self, image, state):
        img = Image.open(self.uploaded_image_filename).convert('RGB')
        # img = image['image'].convert('RGB')
        mask = image['mask'].convert('L')
        mask = np.array(mask, dtype=np.uint8)

        Human_prompt = "Please process this image based on given mask."
        if self.uploaded_image_filename is None:
            AI_prompt = "Please upload an image for processing."
            state += [(Human_prompt, AI_prompt)]
            return state, state, None
        if mask.sum() == 0:
            AI_prompt = "You can click the image in the right and ask me some questions."
            state += [(Human_prompt, AI_prompt)]
            return state, state, image['image']

        if self.history_mask is None:
            self.history_mask = mask
        else:
            self.history_mask = np.logical_or(self.history_mask, mask)

        if 'SegmentAnything' in self.models.keys():
            self.models['SegmentAnything'].clicked_region = self.history_mask
        if 'ImageOCRRecognition' in self.models.keys():
            self.models['ImageOCRRecognition'].clicked_region = mask

        # self.models['SegmentAnything'].mask = self.history_mask
        # history_mask = self.history_mask.astype(np.uint8) * 255
        res_mask = self.models['SegmentAnything'].segment_by_mask(self.history_mask)

        img = self.blend_mask(img, res_mask)

        AI_prompt = f"I have finished processing. Now, you can ask me some questions."
        state = state + [(Human_prompt, AI_prompt)]
        # AI_prompt = f"Received. I found {ocr_text} in this position. The segmented figure is named {seg_filename}."
        self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + ' AI: ' + AI_prompt
        # state = state + [(Human_prompt, f"![](file={seg_filename})*{AI_prompt}*")]
        # print()
        print(f"\nProcessed run_image, Input image: {self.uploaded_image_filename}\nCurrent state: {state}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return state, state, img

    def reset(self, clear_history_memory=False):
        print('reset the model cache.')
        NEED_RESET_LIST = ['SegmentAnything', 'ImageOCRRecognition']
        for model_name in NEED_RESET_LIST:
            if model_name in self.models.keys():
                self.models[model_name].reset()

        self.history_mask = None
        self.uploaded_image_filename = None
        if clear_history_memory:
            self.agent.memory.clear()
        return None

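ConversationBot above discovers tools by scanning each loaded model for methods whose names start with `inference` and reading their `name` and `description` attributes. The sketch below is a minimal, hypothetical tool that would satisfy that contract; the `prompts` decorator mirrors the TaskMatrix-style convention but is an assumption here, since the real tool classes are not included in this truncated view of the diff.

```python
# Hypothetical helper that attaches the metadata ConversationBot expects
# on inference* methods.
def prompts(name, description):
    def decorator(func):
        func.name = name
        func.description = description
        return func
    return decorator

class EchoTool:
    def __init__(self, device):
        # The device argument is accepted for interface compatibility;
        # this toy tool does not use it.
        self.device = device

    @prompts(name="Echo Text",
             description="useful when you want to repeat the given text back to the user")
    def inference(self, text):
        return text
```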
iChat/models/__init__.py
ADDED
@@ -0,0 +1,37 @@
# from .image import (MaskFormer, ImageEditing, InstructPix2Pix, \
#     Text2Image, ImageCaptioning, Image2Canny, CannyText2Image, \
#     Image2Line, LineText2Image, Image2Hed, HedText2Image, Image2Scribble, \
#     ScribbleText2Image, Image2Pose, PoseText2Image, SegText2Image, \
#     Image2Depth, DepthText2Image, Image2Normal, NormalText2Image, \
#     VisualQuestionAnswering, InfinityOutPainting, \
#     SegmentAnything, InpaintMaskedAnything, ExtractMaskedAnything, \
#     ReplaceMaskedAnything, ImageOCRRecognition)

from .husky import HuskyVQA

from .video import (ActionRecognition, DenseCaption, VideoCaption,
                    Summarization, GenerateTikTokVideo)

from .lang import SimpleLanguageModel

from .inpainting import LDMInpainting

# __all__ = [
#     'MaskFormer', 'ImageEditing', 'InstructPix2Pix', \
#     'Text2Image', 'ImageCaptioning', 'Image2Canny', 'CannyText2Image', \
#     'Image2Line', 'LineText2Image', 'Image2Hed', 'HedText2Image', \
#     'Image2Scribble', 'ScribbleText2Image', 'Image2Pose', 'PoseText2Image', \
#     'SegText2Image', 'Image2Depth', 'DepthText2Image', 'Image2Normal', \
#     'NormalText2Image', 'VisualQuestionAnswering', 'InfinityOutPainting', \
#     'SegmentAnything', 'InpaintMaskedAnything', 'ExtractMaskedAnything', \
#     'ReplaceMaskedAnything', 'ImageOCRRecognition', "SimpleLanguageModel", \
#     'ActionRecognition', 'DenseCaption', 'VideoCaption', 'Summarization', \
#     'GenerateTikTokVideo'
# ]

__all__ = [
    'HuskyVQA', "SimpleLanguageModel", 'GenerateTikTokVideo', \
    'LDMInpainting',
    'ActionRecognition', 'DenseCaption', 'VideoCaption', 'Summarization'
]

ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
from .grit_src.image_dense_captions import image_caption_api, init_demo, dense_pred_to_caption, dense_pred_to_caption_only_name
|
5 |
+
from detectron2.data.detection_utils import read_image
|
6 |
+
|
7 |
+
class DenseCaptioning():
|
8 |
+
def __init__(self, device):
|
9 |
+
self.device = device
|
10 |
+
self.demo = None
|
11 |
+
|
12 |
+
|
13 |
+
def initialize_model(self):
|
14 |
+
self.demo = init_demo(self.device)
|
15 |
+
|
16 |
+
def image_dense_caption_debug(self, image_src):
|
17 |
+
dense_caption = """
|
18 |
+
1. the broccoli is green, [0, 0, 333, 325];
|
19 |
+
2. a piece of broccoli, [0, 147, 143, 324];
|
20 |
+
3. silver fork on plate, [4, 547, 252, 612];
|
21 |
+
"""
|
22 |
+
return dense_caption
|
23 |
+
|
24 |
+
def image_dense_caption(self, image_src):
|
25 |
+
dense_caption = image_caption_api(image_src, self.device)
|
26 |
+
print('\033[1;35m' + '*' * 100 + '\033[0m')
|
27 |
+
print("Step2, Dense Caption:\n")
|
28 |
+
print(dense_caption)
|
29 |
+
print('\033[1;35m' + '*' * 100 + '\033[0m')
|
30 |
+
return dense_caption
|
31 |
+
|
32 |
+
def run_caption_api(self,image_src):
|
33 |
+
img = read_image(image_src, format="BGR")
|
34 |
+
print(img.shape)
|
35 |
+
predictions, visualized_output = self.demo.run_on_image(img)
|
36 |
+
new_caption = dense_pred_to_caption_only_name(predictions)
|
37 |
+
return new_caption
|
38 |
+
|
39 |
+
def run_caption_tensor(self,img):
|
40 |
+
# img = read_image(image_src, format="BGR")
|
41 |
+
# print(img.shape)
|
42 |
+
predictions, visualized_output = self.demo.run_on_image(img)
|
43 |
+
new_caption = dense_pred_to_caption_only_name(predictions)
|
44 |
+
return new_caption
|
45 |
+
|
46 |
+
|
iChat/models/grit_src/configs/Base.yaml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
META_ARCHITECTURE: "GRiT"
|
3 |
+
MASK_ON: True
|
4 |
+
PROPOSAL_GENERATOR:
|
5 |
+
NAME: "CenterNet"
|
6 |
+
FPN:
|
7 |
+
IN_FEATURES: ["layer3", "layer4", "layer5"]
|
8 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
9 |
+
PIXEL_STD: [58.395, 57.12, 57.375]
|
10 |
+
ROI_HEADS:
|
11 |
+
NAME: GRiTROIHeadsAndTextDecoder
|
12 |
+
IN_FEATURES: ["p3", "p4", "p5"]
|
13 |
+
IOU_THRESHOLDS: [0.6]
|
14 |
+
NUM_CLASSES: 1
|
15 |
+
SCORE_THRESH_TEST: 0.02
|
16 |
+
NMS_THRESH_TEST: 0.5
|
17 |
+
OBJECT_FEAT_POOLER_RES: 14
|
18 |
+
ROI_BOX_CASCADE_HEAD:
|
19 |
+
IOUS: [0.6, 0.7, 0.8]
|
20 |
+
ROI_BOX_HEAD:
|
21 |
+
NAME: "FastRCNNConvFCHead"
|
22 |
+
NUM_FC: 2
|
23 |
+
POOLER_RESOLUTION: 7
|
24 |
+
CLS_AGNOSTIC_BBOX_REG: True
|
25 |
+
MULT_PROPOSAL_SCORE: True
|
26 |
+
ROI_MASK_HEAD:
|
27 |
+
NAME: "MaskRCNNConvUpsampleHead"
|
28 |
+
NUM_CONV: 4
|
29 |
+
POOLER_RESOLUTION: 14
|
30 |
+
CLS_AGNOSTIC_MASK: True
|
31 |
+
CENTERNET:
|
32 |
+
NUM_CLASSES: 1
|
33 |
+
REG_WEIGHT: 1.
|
34 |
+
NOT_NORM_REG: True
|
35 |
+
ONLY_PROPOSAL: True
|
36 |
+
WITH_AGN_HM: True
|
37 |
+
INFERENCE_TH: 0.0001
|
38 |
+
PRE_NMS_TOPK_TRAIN: 4000
|
39 |
+
POST_NMS_TOPK_TRAIN: 2000
|
40 |
+
PRE_NMS_TOPK_TEST: 1000
|
41 |
+
POST_NMS_TOPK_TEST: 256
|
42 |
+
NMS_TH_TRAIN: 0.9
|
43 |
+
NMS_TH_TEST: 0.9
|
44 |
+
POS_WEIGHT: 0.5
|
45 |
+
NEG_WEIGHT: 0.5
|
46 |
+
IGNORE_HIGH_FP: 0.85
|
47 |
+
DATASETS:
|
48 |
+
TRAIN: ("coco_2017_train",)
|
49 |
+
TEST: ("coco_2017_val",)
|
50 |
+
DATALOADER:
|
51 |
+
SAMPLER_TRAIN: "MultiDatasetSampler"
|
52 |
+
DATASET_RATIO: [1]
|
53 |
+
DATASET_INPUT_SIZE: [1024]
|
54 |
+
DATASET_INPUT_SCALE: [[0.1, 2.0]]
|
55 |
+
FILTER_EMPTY_ANNOTATIONS: False
|
56 |
+
NUM_WORKERS: 8
|
57 |
+
TEST:
|
58 |
+
DETECTIONS_PER_IMAGE: 256
|
59 |
+
SOLVER:
|
60 |
+
LR_SCHEDULER_NAME: "WarmupCosineLR"
|
61 |
+
CHECKPOINT_PERIOD: 10000
|
62 |
+
WARMUP_ITERS: 1000
|
63 |
+
WARMUP_FACTOR: 0.001
|
64 |
+
USE_CUSTOM_SOLVER: True
|
65 |
+
OPTIMIZER: "ADAMW"
|
66 |
+
MAX_ITER: 180000
|
67 |
+
IMS_PER_BATCH: 64
|
68 |
+
BASE_LR: 0.00008
|
69 |
+
VIT_LAYER_DECAY: True
|
70 |
+
CLIP_GRADIENTS:
|
71 |
+
ENABLED: True
|
72 |
+
INPUT:
|
73 |
+
FORMAT: RGB
|
74 |
+
CUSTOM_AUG: EfficientDetResizeCrop
|
75 |
+
TRAIN_SIZE: 640
|
76 |
+
USE_ACT_CHECKPOINT: True
|
77 |
+
VERSION: 2
|
iChat/models/grit_src/configs/GRiT_B_DenseCap.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "Base.yaml"
|
2 |
+
MODEL:
|
3 |
+
TRAIN_TASK: ["DenseCap"]
|
4 |
+
TEST_TASK: "DenseCap"
|
5 |
+
MASK_ON: False
|
6 |
+
ROI_HEADS:
|
7 |
+
SOFT_NMS_ENABLED: False
|
8 |
+
BEAM_SIZE: 1
|
9 |
+
WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
|
10 |
+
BACKBONE:
|
11 |
+
NAME: build_vit_fpn_backbone
|
12 |
+
VIT_LAYERS: 12
|
13 |
+
SOLVER:
|
14 |
+
VIT_LAYER_DECAY_RATE: 0.7
|
15 |
+
DATASETS:
|
16 |
+
TRAIN: ("vg_train",)
|
17 |
+
TEST: ("vg_test",)
|
18 |
+
DATALOADER:
|
19 |
+
DATASET_BS: 2
|
20 |
+
OUTPUT_DIR: "./output/GRiT_B_DenseCap"
|
iChat/models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "Base.yaml"
|
2 |
+
MODEL:
|
3 |
+
TRAIN_TASK: ["ObjectDet", "DenseCap"]
|
4 |
+
TEST_TASK: "DenseCap" # DenseCap or ObjectDet: Choose one for testing
|
5 |
+
MASK_ON: True
|
6 |
+
ROI_HEADS:
|
7 |
+
SOFT_NMS_ENABLED: False
|
8 |
+
BEAM_SIZE: 1
|
9 |
+
WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
|
10 |
+
BACKBONE:
|
11 |
+
NAME: build_vit_fpn_backbone
|
12 |
+
VIT_LAYERS: 12
|
13 |
+
SOLVER:
|
14 |
+
VIT_LAYER_DECAY_RATE: 0.7
|
15 |
+
DATASETS:
|
16 |
+
TRAIN: ("GRiT_coco2017_train", "vg_train")
|
17 |
+
TEST: ("coco_2017_test-dev",)
|
18 |
+
DATALOADER:
|
19 |
+
DATASET_RATIO: [1, 1]
|
20 |
+
DATASET_BS: 2
|
21 |
+
DATASET_INPUT_SIZE: [1024, 1024]
|
22 |
+
DATASET_INPUT_SCALE: [[0.1, 2.0], [0.1, 2.0]]
|
23 |
+
OUTPUT_DIR: "./output/GRiT_B_DenseCap_ObjectDet"
|
iChat/models/grit_src/configs/GRiT_B_ObjectDet.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "Base.yaml"
|
2 |
+
MODEL:
|
3 |
+
TRAIN_TASK: ["ObjectDet"]
|
4 |
+
TEST_TASK: "ObjectDet"
|
5 |
+
MASK_ON: True
|
6 |
+
ROI_HEADS:
|
7 |
+
SOFT_NMS_ENABLED: True
|
8 |
+
BEAM_SIZE: 3
|
9 |
+
WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
|
10 |
+
BACKBONE:
|
11 |
+
NAME: build_vit_fpn_backbone
|
12 |
+
VIT_LAYERS: 12
|
13 |
+
SOLVER:
|
14 |
+
VIT_LAYER_DECAY_RATE: 0.7
|
15 |
+
DATASETS:
|
16 |
+
TRAIN: ("GRiT_coco2017_train",)
|
17 |
+
TEST: ("coco_2017_val",)
|
18 |
+
DATALOADER:
|
19 |
+
DATASET_BS: 2
|
20 |
+
OUTPUT_DIR: "./output/GRiT_B_ObjectDet"
|
iChat/models/grit_src/configs/GRiT_H_ObjectDet.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "Base.yaml"
|
2 |
+
MODEL:
|
3 |
+
TRAIN_TASK: ["ObjectDet"]
|
4 |
+
TEST_TASK: "ObjectDet"
|
5 |
+
MASK_ON: True
|
6 |
+
ROI_HEADS:
|
7 |
+
SOFT_NMS_ENABLED: True
|
8 |
+
BEAM_SIZE: 3
|
9 |
+
WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth"
|
10 |
+
BACKBONE:
|
11 |
+
NAME: build_vit_fpn_backbone_huge
|
12 |
+
VIT_LAYERS: 32
|
13 |
+
SOLVER:
|
14 |
+
MAX_ITER: 135000
|
15 |
+
VIT_LAYER_DECAY_RATE: 0.9
|
16 |
+
DATASETS:
|
17 |
+
TRAIN: ("GRiT_coco2017_train",)
|
18 |
+
TEST: ("coco_2017_val",)
|
19 |
+
DATALOADER:
|
20 |
+
DATASET_BS: 1
|
21 |
+
OUTPUT_DIR: "./output/GRiT_H_ObjectDet"
|
iChat/models/grit_src/configs/GRiT_L_ObjectDet.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: "Base.yaml"
|
2 |
+
MODEL:
|
3 |
+
TRAIN_TASK: ["ObjectDet"]
|
4 |
+
TEST_TASK: "ObjectDet"
|
5 |
+
MASK_ON: True
|
6 |
+
ROI_HEADS:
|
7 |
+
SOFT_NMS_ENABLED: True
|
8 |
+
BEAM_SIZE: 3
|
9 |
+
WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth"
|
10 |
+
BACKBONE:
|
11 |
+
NAME: build_vit_fpn_backbone_large
|
12 |
+
VIT_LAYERS: 24
|
13 |
+
SOLVER:
|
14 |
+
VIT_LAYER_DECAY_RATE: 0.8
|
15 |
+
DATASETS:
|
16 |
+
TRAIN: ("GRiT_coco2017_train",)
|
17 |
+
TEST: ("coco_2017_val",)
|
18 |
+
DATALOADER:
|
19 |
+
DATASET_BS: 1
|
20 |
+
OUTPUT_DIR: "./output/GRiT_L_ObjectDet"
|
iChat/models/grit_src/grit/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .modeling.meta_arch import grit
|
2 |
+
from .modeling.roi_heads import grit_roi_heads
|
3 |
+
from .modeling.backbone import vit
|
4 |
+
|
5 |
+
from .data.datasets import object365
|
6 |
+
from .data.datasets import vg
|
7 |
+
from .data.datasets import grit_coco
|
iChat/models/grit_src/grit/config.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from detectron2.config import CfgNode as CN
|
2 |
+
|
3 |
+
|
4 |
+
def add_grit_config(cfg):
|
5 |
+
_C = cfg
|
6 |
+
|
7 |
+
_C.MODEL.BEAM_SIZE = 1
|
8 |
+
_C.MODEL.TRAIN_TASK = ["ObjectDet", "DenseCap"]
|
9 |
+
_C.MODEL.TEST_TASK = "DenseCap" # This can be varied if the model is jointly trained on multiple tasks
|
10 |
+
|
11 |
+
_C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0 # >= 0: not use
|
12 |
+
_C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False
|
13 |
+
|
14 |
+
_C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0
|
15 |
+
_C.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES = 14
|
16 |
+
_C.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
|
17 |
+
|
18 |
+
# Backbones
|
19 |
+
_C.MODEL.VIT_LAYERS = 12
|
20 |
+
|
21 |
+
# Text Decoder
|
22 |
+
_C.TEXT_DECODER = CN()
|
23 |
+
_C.TEXT_DECODER.VOCAB_SIZE = 30522
|
24 |
+
_C.TEXT_DECODER.HIDDEN_SIZE = 768
|
25 |
+
_C.TEXT_DECODER.NUM_LAYERS = 6
|
26 |
+
_C.TEXT_DECODER.ATTENTION_HEADS = 12
|
27 |
+
_C.TEXT_DECODER.FEEDFORWARD_SIZE = 768 * 4
|
28 |
+
|
29 |
+
# Multi-dataset dataloader
|
30 |
+
_C.DATALOADER.DATASET_RATIO = [1, 1] # sample ratio
|
31 |
+
_C.DATALOADER.DATASET_BS = 1
|
32 |
+
_C.DATALOADER.DATASET_INPUT_SIZE = [1024, 1024]
|
33 |
+
_C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.1, 2.0)]
|
34 |
+
_C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (640, 800)]
|
35 |
+
_C.DATALOADER.DATASET_MAX_SIZES = [1333, 1333]
|
36 |
+
|
37 |
+
_C.SOLVER.USE_CUSTOM_SOLVER = True
|
38 |
+
_C.SOLVER.OPTIMIZER = 'ADAMW'
|
39 |
+
_C.SOLVER.VIT_LAYER_DECAY = True
|
40 |
+
_C.SOLVER.VIT_LAYER_DECAY_RATE = 0.7
|
41 |
+
|
42 |
+
_C.INPUT.CUSTOM_AUG = 'EfficientDetResizeCrop'
|
43 |
+
_C.INPUT.TRAIN_SIZE = 1024
|
44 |
+
_C.INPUT.TEST_SIZE = 1024
|
45 |
+
_C.INPUT.SCALE_RANGE = (0.1, 2.)
|
46 |
+
# 'default' for fixed short / long edge
|
47 |
+
_C.INPUT.TEST_INPUT_TYPE = 'default'
|
48 |
+
|
49 |
+
_C.FIND_UNUSED_PARAM = True
|
50 |
+
_C.USE_ACT_CHECKPOINT = True
|
iChat/models/grit_src/grit/custom_solver.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/custom_solver.py
|
3 |
+
import itertools
|
4 |
+
from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from detectron2.config import CfgNode
|
8 |
+
|
9 |
+
from detectron2.solver.build import maybe_add_gradient_clipping
|
10 |
+
|
11 |
+
|
12 |
+
def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
|
13 |
+
params: List[Dict[str, Any]] = []
|
14 |
+
memo: Set[torch.nn.parameter.Parameter] = set()
|
15 |
+
optimizer_type = cfg.SOLVER.OPTIMIZER
|
16 |
+
|
17 |
+
for key, value in model.named_parameters(recurse=True):
|
18 |
+
if not value.requires_grad:
|
19 |
+
continue
|
20 |
+
# Avoid duplicating parameters
|
21 |
+
if value in memo:
|
22 |
+
continue
|
23 |
+
memo.add(value)
|
24 |
+
lr = cfg.SOLVER.BASE_LR
|
25 |
+
weight_decay = cfg.SOLVER.WEIGHT_DECAY
|
26 |
+
|
27 |
+
if cfg.SOLVER.VIT_LAYER_DECAY:
|
28 |
+
lr = lr * get_vit_lr_decay_rate(key, cfg.SOLVER.VIT_LAYER_DECAY_RATE, cfg.MODEL.VIT_LAYERS)
|
29 |
+
|
30 |
+
param = {"params": [value], "lr": lr}
|
31 |
+
if optimizer_type != 'ADAMW':
|
32 |
+
param['weight_decay'] = weight_decay
|
33 |
+
params += [param]
|
34 |
+
|
35 |
+
def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
|
36 |
+
# detectron2 doesn't have full model gradient clipping now
|
37 |
+
clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
|
38 |
+
enable = (
|
39 |
+
cfg.SOLVER.CLIP_GRADIENTS.ENABLED
|
40 |
+
and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
|
41 |
+
and clip_norm_val > 0.0
|
42 |
+
)
|
43 |
+
|
44 |
+
class FullModelGradientClippingOptimizer(optim):
|
45 |
+
def step(self, closure=None):
|
46 |
+
all_params = itertools.chain(*[x["params"] for x in self.param_groups])
|
47 |
+
torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
|
48 |
+
super().step(closure=closure)
|
49 |
+
|
50 |
+
return FullModelGradientClippingOptimizer if enable else optim
|
51 |
+
|
52 |
+
|
53 |
+
if optimizer_type == 'SGD':
|
54 |
+
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
|
55 |
+
params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
|
56 |
+
nesterov=cfg.SOLVER.NESTEROV
|
57 |
+
)
|
58 |
+
elif optimizer_type == 'ADAMW':
|
59 |
+
optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
|
60 |
+
params, cfg.SOLVER.BASE_LR,
|
61 |
+
weight_decay=cfg.SOLVER.WEIGHT_DECAY
|
62 |
+
)
|
63 |
+
else:
|
64 |
+
raise NotImplementedError(f"no optimizer type {optimizer_type}")
|
65 |
+
if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
|
66 |
+
optimizer = maybe_add_gradient_clipping(cfg, optimizer)
|
67 |
+
return optimizer
|
68 |
+
|
69 |
+
|
70 |
+
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
|
71 |
+
"""
|
72 |
+
Calculate lr decay rate for different ViT blocks.
|
73 |
+
Args:
|
74 |
+
name (string): parameter name.
|
75 |
+
lr_decay_rate (float): base lr decay rate.
|
76 |
+
num_layers (int): number of ViT blocks.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
lr decay rate for the given parameter.
|
80 |
+
"""
|
81 |
+
layer_id = num_layers + 1
|
82 |
+
if name.startswith("backbone"):
|
83 |
+
if ".pos_embed" in name or ".patch_embed" in name:
|
84 |
+
layer_id = 0
|
85 |
+
elif ".blocks." in name and ".residual." not in name:
|
86 |
+
layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
|
87 |
+
|
88 |
+
return lr_decay_rate ** (num_layers + 1 - layer_id)
|
iChat/models/grit_src/grit/data/custom_build_augmentation.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from detectron2.data import transforms as T
|
3 |
+
from .transforms.custom_augmentation_impl import EfficientDetResizeCrop
|
4 |
+
|
5 |
+
|
6 |
+
def build_custom_augmentation(cfg, is_train, scale=None, size=None, \
|
7 |
+
min_size=None, max_size=None):
|
8 |
+
"""
|
9 |
+
Create a list of default :class:`Augmentation` from config.
|
10 |
+
Now it includes resizing and flipping.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
list[Augmentation]
|
14 |
+
"""
|
15 |
+
if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge':
|
16 |
+
if is_train:
|
17 |
+
min_size = cfg.INPUT.MIN_SIZE_TRAIN if min_size is None else min_size
|
18 |
+
max_size = cfg.INPUT.MAX_SIZE_TRAIN if max_size is None else max_size
|
19 |
+
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
|
20 |
+
else:
|
21 |
+
min_size = cfg.INPUT.MIN_SIZE_TEST
|
22 |
+
max_size = cfg.INPUT.MAX_SIZE_TEST
|
23 |
+
sample_style = "choice"
|
24 |
+
augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
|
25 |
+
elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
|
26 |
+
if is_train:
|
27 |
+
scale = cfg.INPUT.SCALE_RANGE if scale is None else scale
|
28 |
+
size = cfg.INPUT.TRAIN_SIZE if size is None else size
|
29 |
+
else:
|
30 |
+
scale = (1, 1)
|
31 |
+
size = cfg.INPUT.TEST_SIZE
|
32 |
+
augmentation = [EfficientDetResizeCrop(size, scale)]
|
33 |
+
else:
|
34 |
+
assert 0, cfg.INPUT.CUSTOM_AUG
|
35 |
+
|
36 |
+
if is_train:
|
37 |
+
augmentation.append(T.RandomFlip())
|
38 |
+
return augmentation
|
39 |
+
|
40 |
+
|
41 |
+
build_custom_transform_gen = build_custom_augmentation
|
42 |
+
"""
|
43 |
+
Alias for backward-compatibility.
|
44 |
+
"""
|
iChat/models/grit_src/grit/data/custom_dataset_dataloader.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/data/custom_dataset_dataloader.py
|
3 |
+
import operator
|
4 |
+
import torch
|
5 |
+
import torch.utils.data
|
6 |
+
from detectron2.utils.comm import get_world_size
|
7 |
+
|
8 |
+
from detectron2.config import configurable
|
9 |
+
from torch.utils.data.sampler import BatchSampler, Sampler
|
10 |
+
from detectron2.data.common import DatasetFromList, MapDataset
|
11 |
+
from detectron2.data.dataset_mapper import DatasetMapper
|
12 |
+
from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
|
13 |
+
from detectron2.data.samplers import TrainingSampler
|
14 |
+
from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
|
15 |
+
from detectron2.data.build import filter_images_with_only_crowd_annotations
|
16 |
+
from detectron2.data.build import filter_images_with_few_keypoints
|
17 |
+
from detectron2.data.build import check_metadata_consistency
|
18 |
+
from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
|
19 |
+
from detectron2.utils import comm
|
20 |
+
import itertools
|
21 |
+
from typing import Optional
|
22 |
+
|
23 |
+
|
24 |
+
def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
|
25 |
+
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
|
26 |
+
if 'MultiDataset' in sampler_name:
|
27 |
+
dataset_dicts = get_detection_dataset_dicts_with_source(
|
28 |
+
cfg.DATASETS.TRAIN,
|
29 |
+
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
|
30 |
+
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
|
31 |
+
if cfg.MODEL.KEYPOINT_ON else 0,
|
32 |
+
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
|
33 |
+
)
|
34 |
+
else:
|
35 |
+
dataset_dicts = get_detection_dataset_dicts(
|
36 |
+
cfg.DATASETS.TRAIN,
|
37 |
+
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
|
38 |
+
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
|
39 |
+
if cfg.MODEL.KEYPOINT_ON else 0,
|
40 |
+
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
|
41 |
+
)
|
42 |
+
|
43 |
+
if mapper is None:
|
44 |
+
mapper = DatasetMapper(cfg, True)
|
45 |
+
|
46 |
+
if sampler is not None:
|
47 |
+
pass
|
48 |
+
elif sampler_name == "TrainingSampler":
|
49 |
+
sampler = TrainingSampler(len(dataset))
|
50 |
+
elif sampler_name == "MultiDatasetSampler":
|
51 |
+
sampler = MultiDatasetSampler(
|
52 |
+
dataset_dicts,
|
53 |
+
dataset_ratio=cfg.DATALOADER.DATASET_RATIO,
|
54 |
+
)
|
55 |
+
else:
|
56 |
+
raise ValueError("Unknown training sampler: {}".format(sampler_name))
|
57 |
+
|
58 |
+
return {
|
59 |
+
"dataset": dataset_dicts,
|
60 |
+
"sampler": sampler,
|
61 |
+
"mapper": mapper,
|
62 |
+
"total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
|
63 |
+
"num_workers": cfg.DATALOADER.NUM_WORKERS,
|
64 |
+
'dataset_bs': cfg.DATALOADER.DATASET_BS,
|
65 |
+
'num_datasets': len(cfg.DATASETS.TRAIN)
|
66 |
+
}
|
67 |
+
|
68 |
+
|
69 |
+
@configurable(from_config=_custom_train_loader_from_config)
|
70 |
+
def build_custom_train_loader(
|
71 |
+
dataset, *, mapper, sampler,
|
72 |
+
total_batch_size=16,
|
73 |
+
num_workers=0,
|
74 |
+
num_datasets=1,
|
75 |
+
dataset_bs=1
|
76 |
+
):
|
77 |
+
|
78 |
+
if isinstance(dataset, list):
|
79 |
+
dataset = DatasetFromList(dataset, copy=False)
|
80 |
+
if mapper is not None:
|
81 |
+
dataset = MapDataset(dataset, mapper)
|
82 |
+
if sampler is None:
|
83 |
+
sampler = TrainingSampler(len(dataset))
|
84 |
+
assert isinstance(sampler, torch.utils.data.sampler.Sampler)
|
85 |
+
|
86 |
+
return build_dataset_batch_data_loader(
|
87 |
+
dataset_bs,
|
88 |
+
dataset,
|
89 |
+
sampler,
|
90 |
+
total_batch_size,
|
91 |
+
num_datasets=num_datasets,
|
92 |
+
num_workers=num_workers,
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
def build_dataset_batch_data_loader(
|
97 |
+
dataset_bs, dataset, sampler, total_batch_size, num_datasets, num_workers=0
|
98 |
+
):
|
99 |
+
|
100 |
+
world_size = get_world_size()
|
101 |
+
assert (
|
102 |
+
total_batch_size > 0 and total_batch_size % world_size == 0
|
103 |
+
), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
|
104 |
+
total_batch_size, world_size
|
105 |
+
)
|
106 |
+
|
107 |
+
data_loader = torch.utils.data.DataLoader(
|
108 |
+
dataset,
|
109 |
+
sampler=sampler,
|
110 |
+
num_workers=num_workers,
|
111 |
+
batch_sampler=None,
|
112 |
+
collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
|
113 |
+
worker_init_fn=worker_init_reset_seed,
|
114 |
+
)
|
115 |
+
|
116 |
+
if num_datasets > 1:
|
117 |
+
return MultiDatasets(data_loader, dataset_bs, num_datasets)
|
118 |
+
else:
|
119 |
+
return SingleDataset(data_loader, dataset_bs)
|
120 |
+
|
121 |
+
|
122 |
+
def get_detection_dataset_dicts_with_source(
|
123 |
+
dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
|
124 |
+
):
|
125 |
+
assert len(dataset_names)
|
126 |
+
dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
|
127 |
+
for dataset_name, dicts in zip(dataset_names, dataset_dicts):
|
128 |
+
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
|
129 |
+
|
130 |
+
for source_id, (dataset_name, dicts) in \
|
131 |
+
enumerate(zip(dataset_names, dataset_dicts)):
|
132 |
+
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
|
133 |
+
for d in dicts:
|
134 |
+
d['dataset_source'] = source_id
|
135 |
+
|
136 |
+
if "annotations" in dicts[0]:
|
137 |
+
try:
|
138 |
+
class_names = MetadataCatalog.get(dataset_name).thing_classes
|
139 |
+
check_metadata_consistency("thing_classes", dataset_name)
|
140 |
+
print_instances_class_histogram(dicts, class_names)
|
141 |
+
except AttributeError: # class names are not available for this dataset
|
142 |
+
pass
|
143 |
+
|
144 |
+
assert proposal_files is None
|
145 |
+
|
146 |
+
dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
|
147 |
+
|
148 |
+
has_instances = "annotations" in dataset_dicts[0]
|
149 |
+
if filter_empty and has_instances:
|
150 |
+
dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
|
151 |
+
if min_keypoints > 0 and has_instances:
|
152 |
+
dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
|
153 |
+
|
154 |
+
return dataset_dicts
|
155 |
+
|
156 |
+
|
157 |
+
class MultiDatasetSampler(Sampler):
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
dataset_dicts,
|
161 |
+
dataset_ratio,
|
162 |
+
seed: Optional[int] = None,
|
163 |
+
):
|
164 |
+
sizes = [0 for _ in range(len(dataset_ratio))]
|
165 |
+
for d in dataset_dicts:
|
166 |
+
sizes[d['dataset_source']] += 1
|
167 |
+
print('dataset sizes', sizes)
|
168 |
+
self.sizes = sizes
|
169 |
+
assert len(dataset_ratio) == len(sizes), \
|
170 |
+
'length of dataset ratio {} should be equal to number if dataset {}'.format(
|
171 |
+
len(dataset_ratio), len(sizes)
|
172 |
+
)
|
173 |
+
if seed is None:
|
174 |
+
seed = comm.shared_random_seed()
|
175 |
+
self._seed = int(seed)
|
176 |
+
self._rank = comm.get_rank()
|
177 |
+
self._world_size = comm.get_world_size()
|
178 |
+
|
179 |
+
self.dataset_ids = torch.tensor(
|
180 |
+
[d['dataset_source'] for d in dataset_dicts], dtype=torch.long)
|
181 |
+
self.dataset_ratio = dataset_ratio
|
182 |
+
|
183 |
+
dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) \
|
184 |
+
for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]
|
185 |
+
dataset_weight = torch.cat(dataset_weight)
|
186 |
+
|
187 |
+
self.weights = dataset_weight
|
188 |
+
self.sample_epoch_size = len(self.weights)
|
189 |
+
|
190 |
+
def __iter__(self):
|
191 |
+
start = self._rank
|
192 |
+
yield from itertools.islice(
|
193 |
+
self._infinite_indices(), start, None, self._world_size)
|
194 |
+
|
195 |
+
def _infinite_indices(self):
|
196 |
+
g = torch.Generator()
|
197 |
+
g.manual_seed(self._seed)
|
198 |
+
while True:
|
199 |
+
if len(self.dataset_ratio) > 1:
|
200 |
+
# multiple datasets
|
201 |
+
ids = torch.multinomial(
|
202 |
+
self.weights, self.sample_epoch_size, generator=g,
|
203 |
+
replacement=True)
|
204 |
+
nums = [(self.dataset_ids[ids] == i).sum().int().item() \
|
205 |
+
for i in range(len(self.sizes))]
|
206 |
+
yield from ids
|
207 |
+
else:
|
208 |
+
# single dataset
|
209 |
+
yield from torch.randperm(self.sizes[0], generator=g).tolist()
|
210 |
+
|
211 |
+
|
212 |
+
class SingleDataset(torch.utils.data.IterableDataset):
|
213 |
+
def __init__(self, dataset, batch_sizes):
|
214 |
+
self.dataset = dataset
|
215 |
+
self.batch_sizes = batch_sizes
|
216 |
+
self._buckets = [[] for _ in range(2)]
|
217 |
+
|
218 |
+
def __iter__(self):
|
219 |
+
for d in self.dataset:
|
220 |
+
w, h = d["width"], d["height"]
|
221 |
+
aspect_ratio_bucket_id = 0 if w > h else 1
|
222 |
+
bucket_id = aspect_ratio_bucket_id
|
223 |
+
bucket = self._buckets[bucket_id]
|
224 |
+
bucket.append(d)
|
225 |
+
if len(bucket) == self.batch_sizes:
|
226 |
+
yield bucket[:]
|
227 |
+
del bucket[:]
|
228 |
+
|
229 |
+
|
230 |
+
class MultiDatasets(torch.utils.data.IterableDataset):
|
231 |
+
def __init__(self, dataset, batch_sizes, num_datasets):
|
232 |
+
self.dataset = dataset
|
233 |
+
self.batch_sizes = batch_sizes
|
234 |
+
self._buckets = [[] for _ in range(2 * num_datasets)]
|
235 |
+
self.iter_idx = 0
|
236 |
+
self.num_datasets = num_datasets
|
237 |
+
|
238 |
+
def __iter__(self):
|
239 |
+
for d in self.dataset:
|
240 |
+
w, h = d["width"], d["height"]
|
241 |
+
aspect_ratio_bucket_id = 0 if w > h else 1
|
242 |
+
bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
|
243 |
+
bucket = self._buckets[bucket_id]
|
244 |
+
if len(bucket) < self.batch_sizes:
|
245 |
+
bucket.append(d)
|
246 |
+
selected_dataset = self.iter_idx % self.num_datasets
|
247 |
+
if len(bucket) == self.batch_sizes and selected_dataset == d['dataset_source']:
|
248 |
+
self.iter_idx += 1
|
249 |
+
yield bucket[:]
|
250 |
+
del bucket[:]
|
iChat/models/grit_src/grit/data/custom_dataset_mapper.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
# Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/data/custom_dataset_mapper.py
|
3 |
+
import copy
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from detectron2.config import configurable
|
8 |
+
|
9 |
+
from detectron2.data import detection_utils as utils
|
10 |
+
from detectron2.data import transforms as T
|
11 |
+
from detectron2.data.dataset_mapper import DatasetMapper
|
12 |
+
from .custom_build_augmentation import build_custom_augmentation
|
13 |
+
from itertools import compress
|
14 |
+
import logging
|
15 |
+
|
16 |
+
__all__ = ["CustomDatasetMapper", "ObjDescription"]
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
class CustomDatasetMapper(DatasetMapper):
|
21 |
+
@configurable
|
22 |
+
def __init__(self, is_train: bool,
|
23 |
+
dataset_augs=[],
|
24 |
+
**kwargs):
|
25 |
+
if is_train:
|
26 |
+
self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs]
|
27 |
+
super().__init__(is_train, **kwargs)
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def from_config(cls, cfg, is_train: bool = True):
|
31 |
+
ret = super().from_config(cfg, is_train)
|
32 |
+
if is_train:
|
33 |
+
if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
|
34 |
+
dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE
|
35 |
+
dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE
|
36 |
+
ret['dataset_augs'] = [
|
37 |
+
build_custom_augmentation(cfg, True, scale, size) \
|
38 |
+
for scale, size in zip(dataset_scales, dataset_sizes)]
|
39 |
+
else:
|
40 |
+
assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge'
|
41 |
+
min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES
|
42 |
+
max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES
|
43 |
+
ret['dataset_augs'] = [
|
44 |
+
build_custom_augmentation(
|
45 |
+
cfg, True, min_size=mi, max_size=ma) \
|
46 |
+
for mi, ma in zip(min_sizes, max_sizes)]
|
47 |
+
else:
|
48 |
+
ret['dataset_augs'] = []
|
49 |
+
|
50 |
+
return ret
|
51 |
+
|
52 |
+
def __call__(self, dataset_dict):
|
53 |
+
dataset_dict_out = self.prepare_data(dataset_dict)
|
54 |
+
|
55 |
+
# When augmented image is too small, do re-augmentation
|
56 |
+
retry = 0
|
57 |
+
while (dataset_dict_out["image"].shape[1] < 32 or dataset_dict_out["image"].shape[2] < 32):
|
58 |
+
retry += 1
|
59 |
+
if retry == 100:
|
60 |
+
logger.info('Retry 100 times for augmentation. Make sure the image size is not too small.')
|
61 |
+
logger.info('Find image information below')
|
62 |
+
logger.info(dataset_dict)
|
63 |
+
dataset_dict_out = self.prepare_data(dataset_dict)
|
64 |
+
|
65 |
+
return dataset_dict_out
|
66 |
+
|
67 |
+
def prepare_data(self, dataset_dict_in):
|
68 |
+
dataset_dict = copy.deepcopy(dataset_dict_in)
|
69 |
+
if 'file_name' in dataset_dict:
|
70 |
+
ori_image = utils.read_image(
|
71 |
+
dataset_dict["file_name"], format=self.image_format)
|
72 |
+
else:
|
73 |
+
ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]]
|
74 |
+
ori_image = utils._apply_exif_orientation(ori_image)
|
75 |
+
ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format)
|
76 |
+
utils.check_image_size(dataset_dict, ori_image)
|
77 |
+
|
78 |
+
aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=None)
|
79 |
+
if self.is_train:
|
80 |
+
transforms = \
|
81 |
+
self.dataset_augs[dataset_dict['dataset_source']](aug_input)
|
82 |
+
else:
|
83 |
+
transforms = self.augmentations(aug_input)
|
84 |
+
image, sem_seg_gt = aug_input.image, aug_input.sem_seg
|
85 |
+
|
86 |
+
image_shape = image.shape[:2]
|
87 |
+
dataset_dict["image"] = torch.as_tensor(
|
88 |
+
np.ascontiguousarray(image.transpose(2, 0, 1)))
|
89 |
+
|
90 |
+
if not self.is_train:
|
91 |
+
# USER: Modify this if you want to keep them for some reason.
|
92 |
+
dataset_dict.pop("annotations", None)
|
93 |
+
return dataset_dict
|
94 |
+
|
95 |
+
if "annotations" in dataset_dict:
|
96 |
+
if len(dataset_dict["annotations"]) > 0:
|
97 |
+
object_descriptions = [an['object_description'] for an in dataset_dict["annotations"]]
|
98 |
+
else:
|
99 |
+
object_descriptions = []
|
100 |
+
# USER: Modify this if you want to keep them for some reason.
|
101 |
+
for anno in dataset_dict["annotations"]:
|
102 |
+
if not self.use_instance_mask:
|
103 |
+
anno.pop("segmentation", None)
|
104 |
+
if not self.use_keypoint:
|
105 |
+
anno.pop("keypoints", None)
|
106 |
+
|
107 |
+
all_annos = [
|
108 |
+
(utils.transform_instance_annotations(
|
109 |
+
obj, transforms, image_shape,
|
110 |
+
keypoint_hflip_indices=self.keypoint_hflip_indices,
|
111 |
+
), obj.get("iscrowd", 0))
|
112 |
+
for obj in dataset_dict.pop("annotations")
|
113 |
+
]
|
114 |
+
annos = [ann[0] for ann in all_annos if ann[1] == 0]
|
115 |
+
instances = utils.annotations_to_instances(
|
116 |
+
annos, image_shape, mask_format=self.instance_mask_format
|
117 |
+
)
|
118 |
+
|
119 |
+
instances.gt_object_descriptions = ObjDescription(object_descriptions)
|
120 |
+
|
121 |
+
del all_annos
|
122 |
+
if self.recompute_boxes:
|
123 |
+
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
|
124 |
+
dataset_dict["instances"] = utils.filter_empty_instances(instances)
|
125 |
+
|
126 |
+
return dataset_dict
|
127 |
+
|
128 |
+
|
129 |
+
class ObjDescription:
|
130 |
+
def __init__(self, object_descriptions):
|
131 |
+
self.data = object_descriptions
|
132 |
+
|
133 |
+
def __getitem__(self, item):
|
134 |
+
assert type(item) == torch.Tensor
|
135 |
+
assert item.dim() == 1
|
136 |
+
if len(item) > 0:
|
137 |
+
assert item.dtype == torch.int64 or item.dtype == torch.bool
|
138 |
+
if item.dtype == torch.int64:
|
139 |
+
return ObjDescription([self.data[x.item()] for x in item])
|
140 |
+
elif item.dtype == torch.bool:
|
141 |
+
return ObjDescription(list(compress(self.data, item)))
|
142 |
+
|
143 |
+
return ObjDescription(list(compress(self.data, item)))
|
144 |
+
|
145 |
+
def __len__(self):
|
146 |
+
return len(self.data)
|
147 |
+
|
148 |
+
def __repr__(self):
|
149 |
+
return "ObjDescription({})".format(self.data)
|
iChat/models/grit_src/grit/data/datasets/grit_coco.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from fvcore.common.timer import Timer
|
4 |
+
from detectron2.structures import BoxMode
|
5 |
+
from fvcore.common.file_io import PathManager
|
6 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
7 |
+
from lvis import LVIS
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
__all__ = ["load_GRiTcoco_json", "register_GRiTcoco_instances"]
|
12 |
+
|
13 |
+
|
14 |
+
def register_GRiTcoco_instances(name, metadata, json_file, image_root):
|
15 |
+
"""
|
16 |
+
"""
|
17 |
+
DatasetCatalog.register(name, lambda: load_GRiTcoco_json(
|
18 |
+
json_file, image_root, name))
|
19 |
+
MetadataCatalog.get(name).set(
|
20 |
+
json_file=json_file, image_root=image_root,
|
21 |
+
evaluator_type="coco", **metadata
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def get_GRiTcoco_meta():
|
26 |
+
categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
|
27 |
+
categories = sorted(categories, key=lambda x: x["id"])
|
28 |
+
thing_classes = [k["name"] for k in categories]
|
29 |
+
meta = {"thing_classes": thing_classes}
|
30 |
+
return meta
|
31 |
+
|
32 |
+
|
33 |
+
def load_GRiTcoco_json(json_file, image_root, dataset_name=None):
|
34 |
+
'''
|
35 |
+
Load COCO class name text for object description for GRiT
|
36 |
+
'''
|
37 |
+
|
38 |
+
json_file = PathManager.get_local_path(json_file)
|
39 |
+
|
40 |
+
timer = Timer()
|
41 |
+
lvis_api = LVIS(json_file)
|
42 |
+
if timer.seconds() > 1:
|
43 |
+
logger.info("Loading {} takes {:.2f} seconds.".format(
|
44 |
+
json_file, timer.seconds()))
|
45 |
+
|
46 |
+
class_names = {}
|
47 |
+
sort_cat = sorted(lvis_api.dataset['categories'], key=lambda x: x['id'])
|
48 |
+
for x in sort_cat:
|
49 |
+
class_names[x['id']] = x['name']
|
50 |
+
|
51 |
+
img_ids = sorted(lvis_api.imgs.keys())
|
52 |
+
imgs = lvis_api.load_imgs(img_ids)
|
53 |
+
anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
|
54 |
+
|
55 |
+
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
|
56 |
+
assert len(set(ann_ids)) == len(ann_ids), \
|
57 |
+
"Annotation ids in '{}' are not unique".format(json_file)
|
58 |
+
|
59 |
+
imgs_anns = list(zip(imgs, anns))
|
60 |
+
logger.info("Loaded {} images in the LVIS v1 format from {}".format(
|
61 |
+
len(imgs_anns), json_file))
|
62 |
+
|
63 |
+
dataset_dicts = []
|
64 |
+
|
65 |
+
for (img_dict, anno_dict_list) in imgs_anns:
|
66 |
+
record = {}
|
67 |
+
if "file_name" in img_dict:
|
68 |
+
file_name = img_dict["file_name"]
|
69 |
+
record["file_name"] = os.path.join(image_root, file_name)
|
70 |
+
|
71 |
+
record["height"] = int(img_dict["height"])
|
72 |
+
record["width"] = int(img_dict["width"])
|
73 |
+
image_id = record["image_id"] = img_dict["id"]
|
74 |
+
|
75 |
+
objs = []
|
76 |
+
for anno in anno_dict_list:
|
77 |
+
assert anno["image_id"] == image_id
|
78 |
+
if anno.get('iscrowd', 0) > 0:
|
79 |
+
continue
|
80 |
+
obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
|
81 |
+
obj["category_id"] = 0
|
82 |
+
obj["object_description"] = class_names[anno['category_id']]
|
83 |
+
if 'segmentation' in anno:
|
84 |
+
segm = anno["segmentation"]
|
85 |
+
valid_segm = [poly for poly in segm \
|
86 |
+
if len(poly) % 2 == 0 and len(poly) >= 6]
|
87 |
+
if not len(segm) == len(valid_segm):
|
88 |
+
print('Annotation contains an invalid polygon with < 3 points')
|
89 |
+
assert len(segm) > 0
|
90 |
+
obj["segmentation"] = segm
|
91 |
+
objs.append(obj)
|
92 |
+
record["annotations"] = objs
|
93 |
+
if len(record["annotations"]) == 0:
|
94 |
+
continue
|
95 |
+
record["task"] = "ObjectDet"
|
96 |
+
dataset_dicts.append(record)
|
97 |
+
|
98 |
+
return dataset_dicts
|
99 |
+
|
100 |
+
|
101 |
+
_CUSTOM_SPLITS_LVIS = {
|
102 |
+
"GRiT_coco2017_train": ("coco/train2017/", "coco/annotations/instances_train2017.json"),
|
103 |
+
}
|
104 |
+
|
105 |
+
|
106 |
+
for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
|
107 |
+
register_GRiTcoco_instances(
|
108 |
+
key,
|
109 |
+
get_GRiTcoco_meta(),
|
110 |
+
os.path.join("datasets", json_file) if "://" not in json_file else json_file,
|
111 |
+
os.path.join("datasets", image_root),
|
112 |
+
)
|
iChat/models/grit_src/grit/data/datasets/object365.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from fvcore.common.timer import Timer
|
4 |
+
from detectron2.structures import BoxMode
|
5 |
+
from fvcore.common.file_io import PathManager
|
6 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
7 |
+
from lvis import LVIS
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
__all__ = ["load_o365_json", "register_o365_instances"]
|
12 |
+
|
13 |
+
|
14 |
+
def register_o365_instances(name, metadata, json_file, image_root):
|
15 |
+
DatasetCatalog.register(name, lambda: load_o365_json(
|
16 |
+
json_file, image_root, name))
|
17 |
+
MetadataCatalog.get(name).set(
|
18 |
+
json_file=json_file, image_root=image_root,
|
19 |
+
evaluator_type="lvis", **metadata
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def get_o365_meta():
|
24 |
+
categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
|
25 |
+
o365_categories = sorted(categories, key=lambda x: x["id"])
|
26 |
+
thing_classes = [k["name"] for k in o365_categories]
|
27 |
+
meta = {"thing_classes": thing_classes}
|
28 |
+
return meta
|
29 |
+
|
30 |
+
|
31 |
+
def load_o365_json(json_file, image_root, dataset_name=None):
|
32 |
+
'''
|
33 |
+
Load Object365 class name text for object description for GRiT
|
34 |
+
'''
|
35 |
+
|
36 |
+
json_file = PathManager.get_local_path(json_file)
|
37 |
+
|
38 |
+
timer = Timer()
|
39 |
+
lvis_api = LVIS(json_file)
|
40 |
+
if timer.seconds() > 1:
|
41 |
+
logger.info("Loading {} takes {:.2f} seconds.".format(
|
42 |
+
json_file, timer.seconds()))
|
43 |
+
|
44 |
+
class_names = {}
|
45 |
+
sort_cat = sorted(lvis_api.dataset['categories'], key=lambda x: x['id'])
|
46 |
+
for x in sort_cat:
|
47 |
+
if '/' in x['name']:
|
48 |
+
text = ''
|
49 |
+
for xx in x['name'].split('/'):
|
50 |
+
text += xx
|
51 |
+
text += ' '
|
52 |
+
text = text[:-1]
|
53 |
+
else:
|
54 |
+
text = x['name']
|
55 |
+
class_names[x['id']] = text
|
56 |
+
|
57 |
+
img_ids = sorted(lvis_api.imgs.keys())
|
58 |
+
imgs = lvis_api.load_imgs(img_ids)
|
59 |
+
anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
|
60 |
+
|
61 |
+
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
|
62 |
+
assert len(set(ann_ids)) == len(ann_ids), \
|
63 |
+
"Annotation ids in '{}' are not unique".format(json_file)
|
64 |
+
|
65 |
+
imgs_anns = list(zip(imgs, anns))
|
66 |
+
logger.info("Loaded {} images in the LVIS v1 format from {}".format(
|
67 |
+
len(imgs_anns), json_file))
|
68 |
+
|
69 |
+
dataset_dicts = []
|
70 |
+
|
71 |
+
for (img_dict, anno_dict_list) in imgs_anns:
|
72 |
+
record = {}
|
73 |
+
if "file_name" in img_dict:
|
74 |
+
file_name = img_dict["file_name"]
|
75 |
+
record["file_name"] = os.path.join(image_root, file_name)
|
76 |
+
|
77 |
+
record["height"] = int(img_dict["height"])
|
78 |
+
record["width"] = int(img_dict["width"])
|
79 |
+
image_id = record["image_id"] = img_dict["id"]
|
80 |
+
|
81 |
+
objs = []
|
82 |
+
for anno in anno_dict_list:
|
83 |
+
assert anno["image_id"] == image_id
|
84 |
+
if anno.get('iscrowd', 0) > 0:
|
85 |
+
continue
|
86 |
+
obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
|
87 |
+
obj["category_id"] = 0
|
88 |
+
obj["object_description"] = class_names[anno['category_id']]
|
89 |
+
|
90 |
+
objs.append(obj)
|
91 |
+
record["annotations"] = objs
|
92 |
+
if len(record["annotations"]) == 0:
|
93 |
+
continue
|
94 |
+
record["task"] = "ObjectDet"
|
95 |
+
dataset_dicts.append(record)
|
96 |
+
|
97 |
+
return dataset_dicts
|
98 |
+
|
99 |
+
|
100 |
+
_CUSTOM_SPLITS_LVIS = {
|
101 |
+
"object365_train": ("object365/images/train/", "object365/annotations/train_v1.json"),
|
102 |
+
}
|
103 |
+
|
104 |
+
|
105 |
+
for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
|
106 |
+
register_o365_instances(
|
107 |
+
key,
|
108 |
+
get_o365_meta(),
|
109 |
+
os.path.join("datasets", json_file) if "://" not in json_file else json_file,
|
110 |
+
os.path.join("datasets", image_root),
|
111 |
+
)
|
iChat/models/grit_src/grit/data/datasets/vg.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from fvcore.common.timer import Timer
|
4 |
+
from detectron2.structures import BoxMode
|
5 |
+
from fvcore.common.file_io import PathManager
|
6 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
7 |
+
from lvis import LVIS
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
__all__ = ["load_vg_json", "register_vg_instances"]
|
12 |
+
|
13 |
+
|
14 |
+
def register_vg_instances(name, metadata, json_file, image_root):
|
15 |
+
"""
|
16 |
+
"""
|
17 |
+
DatasetCatalog.register(name, lambda: load_vg_json(
|
18 |
+
json_file, image_root, name))
|
19 |
+
MetadataCatalog.get(name).set(
|
20 |
+
json_file=json_file, image_root=image_root,
|
21 |
+
evaluator_type="vg", **metadata
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def get_vg_meta():
|
26 |
+
categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
|
27 |
+
vg_categories = sorted(categories, key=lambda x: x["id"])
|
28 |
+
thing_classes = [k["name"] for k in vg_categories]
|
29 |
+
meta = {"thing_classes": thing_classes}
|
30 |
+
return meta
|
31 |
+
|
32 |
+
|
33 |
+
def load_vg_json(json_file, image_root, dataset_name=None):
|
34 |
+
|
35 |
+
json_file = PathManager.get_local_path(json_file)
|
36 |
+
|
37 |
+
timer = Timer()
|
38 |
+
lvis_api = LVIS(json_file)
|
39 |
+
if timer.seconds() > 1:
|
40 |
+
logger.info("Loading {} takes {:.2f} seconds.".format(
|
41 |
+
json_file, timer.seconds()))
|
42 |
+
|
43 |
+
img_ids = sorted(lvis_api.imgs.keys())
|
44 |
+
imgs = lvis_api.load_imgs(img_ids)
|
45 |
+
anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
|
46 |
+
|
47 |
+
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
|
48 |
+
assert len(set(ann_ids)) == len(ann_ids), \
|
49 |
+
"Annotation ids in '{}' are not unique".format(json_file)
|
50 |
+
|
51 |
+
imgs_anns = list(zip(imgs, anns))
|
52 |
+
logger.info("Loaded {} images in the LVIS v1 format from {}".format(
|
53 |
+
len(imgs_anns), json_file))
|
54 |
+
|
55 |
+
dataset_dicts = []
|
56 |
+
|
57 |
+
for (img_dict, anno_dict_list) in imgs_anns:
|
58 |
+
record = {}
|
59 |
+
if "file_name" in img_dict:
|
60 |
+
file_name = img_dict["file_name"]
|
61 |
+
record["file_name"] = os.path.join(image_root, file_name)
|
62 |
+
|
63 |
+
record["height"] = int(img_dict["height"])
|
64 |
+
record["width"] = int(img_dict["width"])
|
65 |
+
image_id = record["image_id"] = img_dict["id"]
|
66 |
+
|
67 |
+
objs = []
|
68 |
+
for anno in anno_dict_list:
|
69 |
+
assert anno["image_id"] == image_id
|
70 |
+
if anno.get('iscrowd', 0) > 0:
|
71 |
+
continue
|
72 |
+
obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
|
73 |
+
obj["category_id"] = 0
|
74 |
+
obj["object_description"] = anno["caption"]
|
75 |
+
|
76 |
+
objs.append(obj)
|
77 |
+
record["annotations"] = objs
|
78 |
+
if len(record["annotations"]) == 0:
|
79 |
+
continue
|
80 |
+
record["task"] = "DenseCap"
|
81 |
+
dataset_dicts.append(record)
|
82 |
+
|
83 |
+
return dataset_dicts
|
84 |
+
|
85 |
+
|
86 |
+
_CUSTOM_SPLITS_LVIS = {
|
87 |
+
"vg_train": ("vg/images", "vg/annotations/train.json"),
|
88 |
+
"vg_test": ("vg/images", "vg/annotations/test.json"),
|
89 |
+
}
|
90 |
+
|
91 |
+
|
92 |
+
for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
|
93 |
+
register_vg_instances(
|
94 |
+
key,
|
95 |
+
get_vg_meta(),
|
96 |
+
os.path.join("datasets", json_file) if "://" not in json_file else json_file,
|
97 |
+
os.path.join("datasets", image_root),
|
98 |
+
)
|
iChat/models/grit_src/grit/data/transforms/custom_augmentation_impl.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
3 |
+
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
|
4 |
+
# Modified by Xingyi Zhou
|
5 |
+
# The original code is under Apache-2.0 License
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from detectron2.data.transforms.augmentation import Augmentation
|
10 |
+
from .custom_transform import EfficientDetResizeCropTransform
|
11 |
+
|
12 |
+
__all__ = [
|
13 |
+
"EfficientDetResizeCrop",
|
14 |
+
]
|
15 |
+
|
16 |
+
|
17 |
+
class EfficientDetResizeCrop(Augmentation):
|
18 |
+
"""
|
19 |
+
Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
|
20 |
+
If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self, size, scale, interp=Image.BILINEAR
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
"""
|
28 |
+
super().__init__()
|
29 |
+
self.target_size = (size, size)
|
30 |
+
self.scale = scale
|
31 |
+
self.interp = interp
|
32 |
+
|
33 |
+
def get_transform(self, img):
|
34 |
+
# Select a random scale factor.
|
35 |
+
scale_factor = np.random.uniform(*self.scale)
|
36 |
+
scaled_target_height = scale_factor * self.target_size[0]
|
37 |
+
scaled_target_width = scale_factor * self.target_size[1]
|
38 |
+
# Recompute the accurate scale_factor using rounded scaled image size.
|
39 |
+
width, height = img.shape[1], img.shape[0]
|
40 |
+
img_scale_y = scaled_target_height / height
|
41 |
+
img_scale_x = scaled_target_width / width
|
42 |
+
img_scale = min(img_scale_y, img_scale_x)
|
43 |
+
|
44 |
+
# Select non-zero random offset (x, y) if scaled image is larger than target size
|
45 |
+
scaled_h = int(height * img_scale)
|
46 |
+
scaled_w = int(width * img_scale)
|
47 |
+
offset_y = scaled_h - self.target_size[0]
|
48 |
+
offset_x = scaled_w - self.target_size[1]
|
49 |
+
offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1))
|
50 |
+
offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1))
|
51 |
+
return EfficientDetResizeCropTransform(
|
52 |
+
scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp)
|
iChat/models/grit_src/grit/data/transforms/custom_transform.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
3 |
+
# Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
|
4 |
+
# Modified by Xingyi Zhou
|
5 |
+
# The original code is under Apache-2.0 License
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from fvcore.transforms.transform import (
|
10 |
+
CropTransform,
|
11 |
+
HFlipTransform,
|
12 |
+
NoOpTransform,
|
13 |
+
Transform,
|
14 |
+
TransformList,
|
15 |
+
)
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
try:
|
19 |
+
import cv2 # noqa
|
20 |
+
except ImportError:
|
21 |
+
# OpenCV is an optional dependency at the moment
|
22 |
+
pass
|
23 |
+
|
24 |
+
__all__ = [
|
25 |
+
"EfficientDetResizeCropTransform",
|
26 |
+
]
|
27 |
+
|
28 |
+
|
29 |
+
class EfficientDetResizeCropTransform(Transform):
|
30 |
+
"""
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, \
|
34 |
+
target_size, interp=None):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
h, w (int): original image size
|
38 |
+
new_h, new_w (int): new image size
|
39 |
+
interp: PIL interpolation methods, defaults to bilinear.
|
40 |
+
"""
|
41 |
+
# TODO decide on PIL vs opencv
|
42 |
+
super().__init__()
|
43 |
+
if interp is None:
|
44 |
+
interp = Image.BILINEAR
|
45 |
+
self._set_attributes(locals())
|
46 |
+
|
47 |
+
def apply_image(self, img, interp=None):
|
48 |
+
assert len(img.shape) <= 4
|
49 |
+
|
50 |
+
if img.dtype == np.uint8:
|
51 |
+
pil_image = Image.fromarray(img)
|
52 |
+
interp_method = interp if interp is not None else self.interp
|
53 |
+
pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
|
54 |
+
ret = np.asarray(pil_image)
|
55 |
+
right = min(self.scaled_w, self.offset_x + self.target_size[1])
|
56 |
+
lower = min(self.scaled_h, self.offset_y + self.target_size[0])
|
57 |
+
if len(ret.shape) <= 3:
|
58 |
+
ret = ret[self.offset_y: lower, self.offset_x: right]
|
59 |
+
else:
|
60 |
+
ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
|
61 |
+
else:
|
62 |
+
# PIL only supports uint8
|
63 |
+
img = torch.from_numpy(img)
|
64 |
+
shape = list(img.shape)
|
65 |
+
shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
|
66 |
+
img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw
|
67 |
+
_PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"}
|
68 |
+
mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp]
|
69 |
+
img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False)
|
70 |
+
shape[:2] = (self.scaled_h, self.scaled_w)
|
71 |
+
ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c)
|
72 |
+
right = min(self.scaled_w, self.offset_x + self.target_size[1])
|
73 |
+
lower = min(self.scaled_h, self.offset_y + self.target_size[0])
|
74 |
+
if len(ret.shape) <= 3:
|
75 |
+
ret = ret[self.offset_y: lower, self.offset_x: right]
|
76 |
+
else:
|
77 |
+
ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
|
78 |
+
return ret
|
79 |
+
|
80 |
+
|
81 |
+
def apply_coords(self, coords):
|
82 |
+
coords[:, 0] = coords[:, 0] * self.img_scale
|
83 |
+
coords[:, 1] = coords[:, 1] * self.img_scale
|
84 |
+
coords[:, 0] -= self.offset_x
|
85 |
+
coords[:, 1] -= self.offset_y
|
86 |
+
return coords
|
87 |
+
|
88 |
+
|
89 |
+
def apply_segmentation(self, segmentation):
|
90 |
+
segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
|
91 |
+
return segmentation
|
92 |
+
|
93 |
+
|
94 |
+
def inverse(self):
|
95 |
+
raise NotImplementedError
|
96 |
+
|
97 |
+
|
98 |
+
def inverse_apply_coords(self, coords):
|
99 |
+
coords[:, 0] += self.offset_x
|
100 |
+
coords[:, 1] += self.offset_y
|
101 |
+
coords[:, 0] = coords[:, 0] / self.img_scale
|
102 |
+
coords[:, 1] = coords[:, 1] / self.img_scale
|
103 |
+
return coords
|
104 |
+
|
105 |
+
|
106 |
+
def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
|
107 |
+
"""
|
108 |
+
"""
|
109 |
+
idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
|
110 |
+
coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
|
111 |
+
coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
|
112 |
+
minxy = coords.min(axis=1)
|
113 |
+
maxxy = coords.max(axis=1)
|
114 |
+
trans_boxes = np.concatenate((minxy, maxxy), axis=1)
|
115 |
+
return trans_boxes
|
iChat/models/grit_src/grit/evaluation/eval.py
ADDED
@@ -0,0 +1,156 @@
import itertools
import json
import os
from detectron2.structures import Boxes, BoxMode, pairwise_iou
from detectron2.utils.file_io import PathManager
import numpy as np
import pycocotools.mask as mask_util
from detectron2.evaluation.coco_evaluation import COCOEvaluator
from detectron2.evaluation.coco_evaluation import _evaluate_predictions_on_coco


class GRiTCOCOEvaluator(COCOEvaluator):
    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            prediction = {"image_id": input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)
                prediction["instances"] = instances_to_coco_json(instances, input["image_id"])

            if len(prediction) > 1:
                self._predictions.append(prediction)

    def _eval_predictions(self, predictions, img_ids=None):
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
        tasks = self._tasks or self._tasks_from_predictions(coco_results)

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info(
            "Evaluating predictions with {} COCO API...".format(
                "unofficial" if self._use_fast_impl else "official"
            )
        )

        coco_results = self.convert_classname_to_id(coco_results)

        for task in sorted(tasks):
            assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
            coco_eval = (
                _evaluate_predictions_on_coco(
                    self._coco_api,
                    coco_results,
                    task,
                    kpt_oks_sigmas=self._kpt_oks_sigmas,
                    use_fast_impl=self._use_fast_impl,
                    img_ids=img_ids,
                    max_dets_per_image=self._max_dets_per_image,
                )
                if len(coco_results) > 0
                else None  # cocoapi does not handle empty results very well
            )

            res = self._derive_coco_results(
                coco_eval, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res

    def convert_classname_to_id(self, results):
        outputs = []
        class_name_to_id = {}
        categories = sorted(self._coco_api.dataset['categories'], key=lambda x: x['id'])

        for cat in categories:
            class_name_to_id[cat['name']] = cat['id']

        for pred in results:
            if pred['object_descriptions'] in class_name_to_id:
                pred['category_id'] = class_name_to_id[pred['object_descriptions']]
                del pred['object_descriptions']
                outputs.append(pred)

        return outputs


class GRiTVGEvaluator(COCOEvaluator):
    def process(self, inputs, outputs):
        for input, output in zip(inputs, outputs):
            assert input["image_id"] == int(input['file_name'].split('/')[-1].split('.')[0])
            prediction = {"image_id": input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)
                prediction["instances"] = instances_to_coco_json(instances, input["image_id"], output_logits=True)
                h = input['height']
                w = input['width']
                scale = 720.0 / max(h, w)
                scaled_inst = []
                for inst in prediction["instances"]:
                    inst['bbox'][0] = inst['bbox'][0] * scale
                    inst['bbox'][1] = inst['bbox'][1] * scale
                    inst['bbox'][2] = inst['bbox'][2] * scale
                    inst['bbox'][3] = inst['bbox'][3] * scale
                    scaled_inst.append(inst)
                if len(scaled_inst) > 0:
                    prediction["instances"] = scaled_inst

            if len(prediction) > 1:
                self._predictions.append(prediction)

    def _eval_predictions(self, predictions, img_ids=None):
        '''
        This is only for saving the results to json file
        '''
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "vg_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()


def instances_to_coco_json(instances, img_id, output_logits=False):
    """
    Add object_descriptions and logit (if applicable) to
    detectron2's instances_to_coco_json
    """
    num_instance = len(instances)
    if num_instance == 0:
        return []

    boxes = instances.pred_boxes.tensor.numpy()
    boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    boxes = boxes.tolist()
    scores = instances.scores.tolist()
    classes = instances.pred_classes.tolist()
    object_descriptions = instances.pred_object_descriptions.data
    if output_logits:
        logits = instances.logits.tolist()

    results = []
    for k in range(num_instance):
        result = {
            "image_id": img_id,
            "category_id": classes[k],
            "bbox": boxes[k],
            "score": scores[k],
            'object_descriptions': object_descriptions[k],
        }
        if output_logits:
            result["logit"] = logits[k]

        results.append(result)
    return results
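
For reference, instances_to_coco_json above only assumes the instances carry pred_boxes, scores, pred_classes and a pred_object_descriptions field exposing .data; the two evaluators then pass its output through detectron2's standard inference_on_dataset loop. Below is a minimal sketch of calling it directly (not part of the release); DummyDescriptions is a hypothetical stand-in for GRiT's per-instance description structure, which likewise supports len().

import torch
from detectron2.structures import Boxes, Instances

class DummyDescriptions:
    # hypothetical stand-in for GRiT's object-description field
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)

inst = Instances((480, 640))
inst.pred_boxes = Boxes(torch.tensor([[10.0, 20.0, 110.0, 220.0]]))
inst.scores = torch.tensor([0.9])
inst.pred_classes = torch.tensor([0])
inst.pred_object_descriptions = DummyDescriptions(["a red bicycle"])

print(instances_to_coco_json(inst, img_id=1))
# -> [{'image_id': 1, 'category_id': 0, 'bbox': [10.0, 20.0, 100.0, 200.0],
#      'score': ~0.9, 'object_descriptions': 'a red bicycle'}]
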