shikunl commited on
Commit
b734d92
•
1 Parent(s): 59567a9

Reset again!

Browse files
Files changed (47) hide show
  1. .gitattributes +0 -34
  2. .gitignore +0 -163
  3. .gitmodules +0 -3
  4. .idea/.gitignore +8 -0
  5. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  6. .idea/misc.xml +4 -0
  7. .idea/modules.xml +8 -0
  8. .idea/prismer_demo.iml +8 -0
  9. .idea/vcs.xml +7 -0
  10. .pre-commit-config.yaml +0 -36
  11. .style.yapf +0 -5
  12. app.py +9 -0
  13. app_caption.py +13 -3
  14. patch +82 -0
  15. prismer/.gitignore +10 -0
  16. prismer/LICENSE +97 -0
  17. prismer/README.md +156 -0
  18. prismer/dataset/__init__.py +12 -1
  19. prismer/dataset/ade_features.pt +0 -0
  20. prismer/dataset/background_features.pt +0 -0
  21. prismer/dataset/caption_dataset.py +1 -1
  22. prismer/dataset/classification_dataset.py +72 -0
  23. prismer/dataset/clip_pca.pkl +0 -0
  24. prismer/dataset/coco_features.pt +0 -0
  25. prismer/dataset/detection_features.pt +0 -0
  26. prismer/dataset/pretrain_dataset.py +73 -0
  27. prismer/dataset/utils.py +4 -8
  28. prismer/dataset/vqa_dataset.py +0 -2
  29. prismer/experts/generate_depth.py +1 -1
  30. prismer/experts/generate_edge.py +1 -1
  31. prismer/experts/generate_normal.py +1 -1
  32. prismer/experts/generate_objdet.py +1 -1
  33. prismer/experts/generate_ocrdet.py +1 -1
  34. prismer/experts/generate_segmentation.py +1 -1
  35. prismer/{images → helpers/images}/COCO_test2015_000000000014.jpg +0 -0
  36. prismer/{images → helpers/images}/COCO_test2015_000000000016.jpg +0 -0
  37. prismer/{images → helpers/images}/COCO_test2015_000000000019.jpg +0 -0
  38. prismer/{images → helpers/images}/COCO_test2015_000000000128.jpg +0 -0
  39. prismer/{images → helpers/images}/COCO_test2015_000000000155.jpg +0 -0
  40. prismer/helpers/intro.png +0 -0
  41. prismer/model/prismer.py +1 -4
  42. prismer/requirements.txt +19 -0
  43. prismer/train_caption.py +208 -0
  44. prismer/train_classification.py +164 -0
  45. prismer/train_pretrain.py +140 -0
  46. prismer/train_vqa.py +180 -0
  47. prismer_model.py +82 -26
.gitattributes DELETED
@@ -1,34 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tflite filter=lfs diff=lfs merge=lfs -text
29
- *.tgz filter=lfs diff=lfs merge=lfs -text
30
- *.wasm filter=lfs diff=lfs merge=lfs -text
31
- *.xz filter=lfs diff=lfs merge=lfs -text
32
- *.zip filter=lfs diff=lfs merge=lfs -text
33
- *.zst filter=lfs diff=lfs merge=lfs -text
34
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore DELETED
@@ -1,163 +0,0 @@
1
- cache/
2
- .idea
3
-
4
- # Byte-compiled / optimized / DLL files
5
- __pycache__/
6
- *.py[cod]
7
- *$py.class
8
-
9
- # C extensions
10
- *.so
11
-
12
- # Distribution / packaging
13
- .Python
14
- build/
15
- develop-eggs/
16
- dist/
17
- downloads/
18
- eggs/
19
- .eggs/
20
- lib/
21
- lib64/
22
- parts/
23
- sdist/
24
- var/
25
- wheels/
26
- share/python-wheels/
27
- *.egg-info/
28
- .installed.cfg
29
- *.egg
30
- MANIFEST
31
-
32
- # PyInstaller
33
- # Usually these files are written by a python script from a template
34
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
- *.manifest
36
- *.spec
37
-
38
- # Installer logs
39
- pip-log.txt
40
- pip-delete-this-directory.txt
41
-
42
- # Unit test / coverage reports
43
- htmlcov/
44
- .tox/
45
- .nox/
46
- .coverage
47
- .coverage.*
48
- .cache
49
- nosetests.xml
50
- coverage.xml
51
- *.cover
52
- *.py,cover
53
- .hypothesis/
54
- .pytest_cache/
55
- cover/
56
-
57
- # Translations
58
- *.mo
59
- *.pot
60
-
61
- # Django stuff:
62
- *.log
63
- local_settings.py
64
- db.sqlite3
65
- db.sqlite3-journal
66
-
67
- # Flask stuff:
68
- instance/
69
- .webassets-cache
70
-
71
- # Scrapy stuff:
72
- .scrapy
73
-
74
- # Sphinx documentation
75
- docs/_build/
76
-
77
- # PyBuilder
78
- .pybuilder/
79
- target/
80
-
81
- # Jupyter Notebook
82
- .ipynb_checkpoints
83
-
84
- # IPython
85
- profile_default/
86
- ipython_config.py
87
-
88
- # pyenv
89
- # For a library or package, you might want to ignore these files since the code is
90
- # intended to run in multiple environments; otherwise, check them in:
91
- # .python-version
92
-
93
- # pipenv
94
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
- # install all needed dependencies.
98
- #Pipfile.lock
99
-
100
- # poetry
101
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
- # This is especially recommended for binary packages to ensure reproducibility, and is more
103
- # commonly ignored for libraries.
104
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
- #poetry.lock
106
-
107
- # pdm
108
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
- #pdm.lock
110
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
- # in version control.
112
- # https://pdm.fming.dev/#use-with-ide
113
- .pdm.toml
114
-
115
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
- __pypackages__/
117
-
118
- # Celery stuff
119
- celerybeat-schedule
120
- celerybeat.pid
121
-
122
- # SageMath parsed files
123
- *.sage.py
124
-
125
- # Environments
126
- .env
127
- .venv
128
- env/
129
- venv/
130
- ENV/
131
- env.bak/
132
- venv.bak/
133
-
134
- # Spyder project settings
135
- .spyderproject
136
- .spyproject
137
-
138
- # Rope project settings
139
- .ropeproject
140
-
141
- # mkdocs documentation
142
- /site
143
-
144
- # mypy
145
- .mypy_cache/
146
- .dmypy.json
147
- dmypy.json
148
-
149
- # Pyre type checker
150
- .pyre/
151
-
152
- # pytype static type analyzer
153
- .pytype/
154
-
155
- # Cython debug symbols
156
- cython_debug/
157
-
158
- # PyCharm
159
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
- # and can be added to the global gitignore or merged into this file. For a more nuclear
162
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
- #.idea/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitmodules DELETED
@@ -1,3 +0,0 @@
1
- [submodule "prismer"]
2
- path = prismer
3
- url = https://github.com/nvlabs/prismer
 
 
 
 
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/prismer_demo.iml" filepath="$PROJECT_DIR$/.idea/prismer_demo.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/prismer_demo.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ <mapping directory="$PROJECT_DIR$/prismer" vcs="Git" />
6
+ </component>
7
+ </project>
.pre-commit-config.yaml DELETED
@@ -1,36 +0,0 @@
1
- repos:
2
- - repo: https://github.com/pre-commit/pre-commit-hooks
3
- rev: v4.2.0
4
- hooks:
5
- - id: check-executables-have-shebangs
6
- - id: check-json
7
- - id: check-merge-conflict
8
- - id: check-shebang-scripts-are-executable
9
- - id: check-toml
10
- - id: check-yaml
11
- - id: double-quote-string-fixer
12
- - id: end-of-file-fixer
13
- - id: mixed-line-ending
14
- args: ['--fix=lf']
15
- - id: requirements-txt-fixer
16
- - id: trailing-whitespace
17
- - repo: https://github.com/myint/docformatter
18
- rev: v1.4
19
- hooks:
20
- - id: docformatter
21
- args: ['--in-place']
22
- - repo: https://github.com/pycqa/isort
23
- rev: 5.12.0
24
- hooks:
25
- - id: isort
26
- - repo: https://github.com/pre-commit/mirrors-mypy
27
- rev: v0.991
28
- hooks:
29
- - id: mypy
30
- args: ['--ignore-missing-imports']
31
- additional_dependencies: ['types-python-slugify']
32
- - repo: https://github.com/google/yapf
33
- rev: v0.32.0
34
- hooks:
35
- - id: yapf
36
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
app.py CHANGED
@@ -5,11 +5,20 @@ from __future__ import annotations
5
  import os
6
  import shutil
7
  import subprocess
 
8
  import gradio as gr
9
 
 
 
 
 
 
 
 
10
  from app_caption import create_demo as create_demo_caption
11
  from prismer_model import build_deformable_conv, download_models
12
 
 
13
  # Prepare model checkpoints
14
  download_models()
15
  build_deformable_conv()
 
5
  import os
6
  import shutil
7
  import subprocess
8
+
9
  import gradio as gr
10
 
11
+ if os.getenv('SYSTEM') == 'spaces':
12
+ with open('patch') as f:
13
+ subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
14
+ shutil.copytree('prismer/helpers/images',
15
+ 'prismer/images',
16
+ dirs_exist_ok=True)
17
+
18
  from app_caption import create_demo as create_demo_caption
19
  from prismer_model import build_deformable_conv, download_models
20
 
21
+
22
  # Prepare model checkpoints
23
  download_models()
24
  build_deformable_conv()
app_caption.py CHANGED
@@ -15,8 +15,10 @@ def create_demo():
15
 
16
  with gr.Row():
17
  with gr.Column():
18
- image = gr.Image(label='Input Image', type='filepath')
19
- model_name = gr.Dropdown(label='Model Size', choices=['prismer_base'], value='prismer_base')
 
 
20
  run_button = gr.Button('Run')
21
  with gr.Column(scale=1.5):
22
  caption = gr.Text(label='Caption')
@@ -30,7 +32,15 @@ def create_demo():
30
  ocr = gr.Image(label='OCR Detection')
31
 
32
  inputs = [image, model_name]
33
- outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
 
 
 
 
 
 
 
 
34
 
35
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
36
  examples = [[path.as_posix(), 'prismer_base'] for path in paths]
 
15
 
16
  with gr.Row():
17
  with gr.Column():
18
+ image = gr.Image(label='Input', type='filepath')
19
+ model_name = gr.Dropdown(label='Model',
20
+ choices=['prismer_base'],
21
+ value='prismer_base')
22
  run_button = gr.Button('Run')
23
  with gr.Column(scale=1.5):
24
  caption = gr.Text(label='Caption')
 
32
  ocr = gr.Image(label='OCR Detection')
33
 
34
  inputs = [image, model_name]
35
+ outputs = [
36
+ caption,
37
+ depth,
38
+ edge,
39
+ normals,
40
+ segmentation,
41
+ object_detection,
42
+ ocr,
43
+ ]
44
 
45
  paths = sorted(pathlib.Path('prismer/images').glob('*'))
46
  examples = [[path.as_posix(), 'prismer_base'] for path in paths]
patch ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/dataset/caption_dataset.py b/dataset/caption_dataset.py
2
+ index 266fdda..0cc5d3f 100644
3
+ --- a/dataset/caption_dataset.py
4
+ +++ b/dataset/caption_dataset.py
5
+ @@ -50,7 +50,7 @@ class Caption(Dataset):
6
+ elif self.dataset == 'demo':
7
+ img_path_split = self.data_list[index]['image'].split('/')
8
+ img_name = img_path_split[-2] + '/' + img_path_split[-1]
9
+ - image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
10
+ + image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
11
+
12
+ experts = self.transform(image, labels)
13
+ experts = post_label_process(experts, labels_info)
14
+ diff --git a/dataset/utils.py b/dataset/utils.py
15
+ index b368aac..418358c 100644
16
+ --- a/dataset/utils.py
17
+ +++ b/dataset/utils.py
18
+ @@ -5,6 +5,7 @@
19
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
20
+
21
+ import os
22
+ +import pathlib
23
+ import re
24
+ import json
25
+ import torch
26
+ @@ -14,10 +15,12 @@ import torchvision.transforms as transforms
27
+ import torchvision.transforms.functional as transforms_f
28
+ from dataset.randaugment import RandAugment
29
+
30
+ -COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
31
+ -ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
32
+ -DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
33
+ -BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
34
+ +cur_dir = pathlib.Path(__file__).parent
35
+ +
36
+ +COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
37
+ +ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
38
+ +DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
39
+ +BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
40
+
41
+
42
+ class Transform:
43
+ diff --git a/model/prismer.py b/model/prismer.py
44
+ index 080253a..02362a4 100644
45
+ --- a/model/prismer.py
46
+ +++ b/model/prismer.py
47
+ @@ -5,6 +5,7 @@
48
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
49
+
50
+ import json
51
+ +import pathlib
52
+ import torch.nn as nn
53
+
54
+ from model.modules.vit import load_encoder
55
+ @@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder
56
+ from transformers import RobertaTokenizer, RobertaConfig
57
+
58
+
59
+ +cur_dir = pathlib.Path(__file__).parent
60
+ +
61
+ +
62
+ class Prismer(nn.Module):
63
+ def __init__(self, config):
64
+ super().__init__()
65
+ @@ -26,7 +30,7 @@ class Prismer(nn.Module):
66
+ elif exp in ['obj_detection', 'ocr_detection']:
67
+ self.experts[exp] = 64
68
+
69
+ - prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
70
+ + prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
71
+ roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
72
+
73
+ self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
74
+ @@ -35,7 +39,7 @@ class Prismer(nn.Module):
75
+
76
+ self.prepare_to_train(config['freeze'])
77
+ self.ignored_modules = self.get_ignored_modules(config['freeze'])
78
+ -
79
+ +
80
+ def prepare_to_train(self, mode='none'):
81
+ for name, params in self.named_parameters():
82
+ if mode == 'freeze_lang':
prismer/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .idea
2
+ cache
3
+ .DS_Store
4
+ **/__pycache__/*
5
+ helpers/data/*
6
+ helpers/images2/*
7
+ helpers/labels/*
8
+ experts/expert_weights
9
+ logging/*
10
+ flagged/*
prismer/LICENSE ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2023, NVIDIA Corporation & affiliates. All rights reserved.
2
+
3
+
4
+ NVIDIA Source Code License for Prismer
5
+
6
+
7
+ =======================================================================
8
+
9
+ 1. Definitions
10
+
11
+ "Licensor" means any person or entity that distributes its Work.
12
+
13
+ "Software" means the original work of authorship made available under
14
+ this License.
15
+
16
+ "Work" means the Software and any additions to or derivative works of
17
+ the Software that are made available under this License.
18
+
19
+ The terms "reproduce," "reproduction," "derivative works," and
20
+ "distribution" have the meaning as provided under U.S. copyright law;
21
+ provided, however, that for the purposes of this License, derivative
22
+ works shall not include works that remain separable from, or merely
23
+ link (or bind by name) to the interfaces of, the Work.
24
+
25
+ Works, including the Software, are "made available" under this License
26
+ by including in or with the Work either (a) a copyright notice
27
+ referencing the applicability of this License to the Work, or (b) a
28
+ copy of this License.
29
+
30
+ 2. License Grants
31
+
32
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
33
+ License, each Licensor grants to you a perpetual, worldwide,
34
+ non-exclusive, royalty-free, copyright license to reproduce,
35
+ prepare derivative works of, publicly display, publicly perform,
36
+ sublicense and distribute its Work and any resulting derivative
37
+ works in any form.
38
+
39
+ 3. Limitations
40
+
41
+ 3.1 Redistribution. You may reproduce or distribute the Work only
42
+ if (a) you do so under this License, (b) you include a complete
43
+ copy of this License with your distribution, and (c) you retain
44
+ without modification any copyright, patent, trademark, or
45
+ attribution notices that are present in the Work.
46
+
47
+ 3.2 Derivative Works. You may specify that additional or different
48
+ terms apply to the use, reproduction, and distribution of your
49
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
50
+ provide that the use limitation in Section 3.3 applies to your
51
+ derivative works, and (b) you identify the specific derivative
52
+ works that are subject to Your Terms. Notwithstanding Your Terms,
53
+ this License (including the redistribution requirements in Section
54
+ 3.1) will continue to apply to the Work itself.
55
+
56
+ 3.3 Use Limitation. The Work and any derivative works thereof only
57
+ may be used or intended for use non-commercially. Notwithstanding
58
+ the foregoing, NVIDIA and its affiliates may use the Work and any
59
+ derivative works commercially. As used herein, "non-commercially"
60
+ means for research or evaluation purposes only.
61
+
62
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63
+ against any Licensor (including any claim, cross-claim or
64
+ counterclaim in a lawsuit) to enforce any patents that you allege
65
+ are infringed by any Work, then your rights under this License from
66
+ such Licensor (including the grant in Section 2.1) will terminate
67
+ immediately.
68
+
69
+ 3.5 Trademarks. This License does not grant any rights to use any
70
+ Licensor’s or its affiliates’ names, logos, or trademarks, except
71
+ as necessary to reproduce the notices described in this License.
72
+
73
+ 3.6 Termination. If you violate any term of this License, then your
74
+ rights under this License (including the grant in Section 2.1) will
75
+ terminate immediately.
76
+
77
+ 4. Disclaimer of Warranty.
78
+
79
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83
+ THIS LICENSE.
84
+
85
+ 5. Limitation of Liability.
86
+
87
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95
+ THE POSSIBILITY OF SUCH DAMAGES.
96
+
97
+ =======================================================================
prismer/README.md ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prismer
2
+
3
+ This repository contains the source code of **Prismer** and **PrismerZ** from the paper, [Prismer: A Vision-Language Model with An Ensemble of Experts](https://arxiv.org/abs/2303.02506).
4
+
5
+ <img src="helpers/intro.png" width="100%"/>
6
+
7
+ ## Get Started
8
+ The implementation is based on `PyTorch 1.13`, and highly integrated with Huggingface [`accelerate`](https://github.com/huggingface/accelerate) toolkit for readable and optimised multi-node multi-gpu training.
9
+
10
+ First, let's install all package dependencies by running
11
+ ```bash
12
+ pip install -r requirements.txt
13
+ ```
14
+
15
+ ### Prepare Accelerator Config
16
+ Then we generate the corresponding `accelerate` config based on your training server configuration. For both single-node multi-gpu and multi-node multi-gpu training, simply run
17
+ ```bash
18
+ # to get your machine rank 0 IP address
19
+ hostname -i
20
+
21
+ # and for each machine, run the following command, set --num_machines 1 in a single-node setting
22
+ python generate_config.py --main_ip {MAIN_IP} --rank {MACHINE_RANK} --num_machines {TOTAL_MACHINES}
23
+ ```
24
+
25
+ ## Datasets
26
+
27
+ ### Pre-training
28
+ We pre-train Prismer/PrismerZ with a combination of five widely used image-alt/text datasets, with pre-organised data lists provided below.
29
+ - [COCO 2014](https://www.dropbox.com/s/6btr8hz5n1e1q4d/coco_karpathy_train.json?dl=0): the Karpathy training split (which will also be used for fine-tuning).
30
+ - [Visual Genome](https://www.dropbox.com/s/kailbaay0sqraxc/vg_caption.json?dl=0): the official Visual Genome captioning dataset.
31
+ - [CC3M + SBU](https://www.dropbox.com/s/xp2nuhc88f1czxm/filtered_cc3m_sbu.json?dl=0): filtered and re-captioned by BLIP-Large.
32
+ - [CC12M](https://www.dropbox.com/s/th358bb6wqkpwbz/filtered_cc12m.json?dl=0): filtered and re-captioned by BLIP-Large.
33
+
34
+ The web datasets (CC3M, SBU, CC12M) are composed of image URLs. It is highly recommended to use [img2dataset](https://github.com/rom1504/img2dataset), a highly optimised toolkit for large-scale web scraping, to download these images. An example bash script using `img2dataset` to download the `cc12m` dataset is provided below.
35
+ ```bash
36
+ img2dataset --url_list filtered_cc12m.json --input_format "json" --url_col "url" --caption_col "caption" --output_folder cc12m --processes_count 16 --thread_count 64 --image_size 256
37
+ ```
38
+
39
+ *Note: It is expected that the number of downloaded images is less than the number of images in the json file, because some urls might not be valid or require long loading time.*
40
+
41
+ ### Image Captioning / VQA
42
+ We evaluate image captioning performance on two datasets, COCO 2014 and NoCaps; and VQA performance on VQAv2 dataset. In VQA tasks, we additionally augment the training data with Visual Genome QA, following BLIP. Again, we have prepared and organised the training and evaluation data lists provided below.
43
+
44
+ - [Image Captioning](https://www.dropbox.com/sh/quu6v5hzdetjcdz/AACze0_h6BO8LJmSsEq4MM8-a?dl=0): including COCO (Karpathy Split) and NoCaps.
45
+ - [VQAv2](https://www.dropbox.com/sh/hqtxl1k8gkbhhoi/AACiax5qi7no3pJgO1E57Xefa?dl=0): including VQAv2 and VG QA.
46
+
47
+ ## Generating Expert Labels
48
+ Before starting any experiments with Prismer, we need to first pre-generate the modality expert labels, so we may construct a multi-label dataset. In the `experts` folder, we have included all 6 experts we introduced in our paper. We have organised each expert's codebase with a shared and simple API.
49
+
50
+ *Note: Specifically for segmentation experts, please first install deformable convolution operations by `cd experts/segmentation/mask2former/modeling/pixel_decoder/ops` and run `sh make.sh`.*
51
+
52
+ To download pre-trained modality experts, run
53
+ ```bash
54
+ python download_checkpoints.py --download_experts=True
55
+ ```
56
+
57
+ To generate the expert labels, simply edit the `configs/experts.yaml` with the corresponding data paths, and run
58
+ ```bash
59
+ export PYTHONPATH=.
60
+ accelerate launch experts/generate_{EXPERT_NAME}.py
61
+ ```
62
+ *Note: Expert label generation is only required for Prismer models, not for PrismerZ models.*
63
+
64
+ ## Experiments
65
+ We have provided pre-trained checkpoints (for zero-shot image captioning) as well as fine-tuned checkpoints on the VQAv2 and COCO datasets, for both Prismer and PrismerZ. With these checkpoints, you should be able to reproduce the exact performance listed below.
66
+
67
+ | Model | Pre-trained [Zero-shot] | COCO [Fine-tuned] | VQAv2 [Fine-tuned] |
68
+ |----------------|-------------------------|---------------------|-------------------|
69
+ | PrismerZ-BASE | COCO CIDEr [109.6] | COCO CIDEr [133.7] | test-dev [76.58] |
70
+ | Prismer-BASE | COCO CIDEr [122.6] | COCO CIDEr [135.1] | test-dev [76.84] |
71
+ | PrismerZ-LARGE | COCO CIDEr [124.8] | COCO CIDEr [135.7] | test-dev [77.49] |
72
+ | Prismer-LARGE | COCO CIDEr [129.7] | COCO CIDEr [136.5] | test-dev [78.42] |
73
+
74
+ To download pre-trained/fine-tuned checkpoints, run
75
+ ```bash
76
+ # to download all model checkpoints (12 models in total)
77
+ python download_checkpoints.py --download_models=True
78
+
79
+ # to download specific checkpoints (Prismer-Base for fine-tuned VQA) in this example
80
+ python download_checkpoints.py --download_models="vqa_prismer_base"
81
+ ```
82
+
83
+
84
+ *Note: Remember to install java via `sudo apt-get install default-jre` which is required to run the official COCO caption evaluation scripts.*
85
+
86
+
87
+ ### Evaluation
88
+ To evaluate the model checkpoints, please run
89
+ ```bash
90
+ # zero-shot image captioning (remember to remove caption prefix in the config files)
91
+ python train_caption.py --exp_name {MODEL_NAME} --evaluate
92
+
93
+ # fine-tuned image captioning
94
+ python train_caption.py --exp_name {MODEL_NAME} --from_checkpoint --evaluate
95
+
96
+ # fine-tuned VQA
97
+ python train_vqa.py --exp_name {MODEL_NAME} --from_checkpoint --evaluate
98
+ ```
99
+
100
+ ### Training / Fine-tuning
101
+ To pre-train or fine-tune any model with or without checkpoints, please run
102
+ ```bash
103
+ # to train/fine-tune from scratch
104
+ python train_{TASK}.py --exp_name {MODEL_NAME}
105
+
106
+ # to train/fine-tune from the latest checkpoints (saved every epoch)
107
+ python train_{TASK}.py --exp_name {MODEL_NAME} --from_checkpoint
108
+ ```
109
+
110
+ We have also included model sharding in the current training script via PyTorch's official [FSDP plugin](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html). With the same training commands, additionally add `--shard_grad_op` for ZeRO-2 Sharding (Gradients + Optimiser States), or `--full_shard` for ZeRO-3 Sharding (ZeRO-2 + Network Parameters).
111
+
112
+ *Note: You should expect the error range for VQAv2 Acc. to be less than 0.1; for COCO/NoCAPs CIDEr score to be less than 1.0.*
113
+
114
+ ## Demo
115
+ Finally, we have offered a minimalist example to perform image captioning on a single GPU with our fine-tuned Prismer/PrismerZ checkpoint. Simply put your images under `helpers/images` (`.jpg` images), and run
116
+ ```bash
117
+ python demo.py --exp_name {MODEL_NAME}
118
+ ```
119
+
120
+ You then can see all generated modality expert labels in the `helpers/labels` folder and the generated captions in the `helpers/images` folder.
121
+
122
+ Particularly for the Prismer models, we have also offered a simple script to prettify the generated expert labels. To prettify and visualise the expert labels as well as its predicted captions, run
123
+ ```bash
124
+ python demo_vis.py
125
+ ```
126
+
127
+ *Note: Remember to set up the corresponding config in the `configs/caption.yaml` demo section. The default demo model config is for Prismer-Base.*
128
+
129
+ ## Citation
130
+
131
+ If you found this code/work to be useful in your own research, please consider citing the following:
132
+
133
+
134
+ ```bibtex
135
+ @article{liu2023prismer,
136
+ title={Prismer: A Vision-Language Model with An Ensemble of Experts},
137
+ author={Liu, Shikun and Fan, Linxi and Johns, Edward and Yu, Zhiding and Xiao, Chaowei and Anandkumar, Anima},
138
+ journal={arXiv preprint arXiv:2303.02506},
139
+ year={2023}
140
+ }
141
+ ```
142
+
143
+ ## License
144
+ Copyright © 2023, NVIDIA Corporation. All rights reserved.
145
+
146
+ This work is made available under the Nvidia Source Code License-NC.
147
+
148
+ The model checkpoints are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.
149
+
150
+ For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
151
+
152
+ ## Acknowledgement
153
+ We would like to thank all the researchers who open source their works to make this project possible. We also thank [@bjoernpl](https://github.com/bjoernpl) for contributing an automated checkpoint download script.
154
+
155
+ ## Contact
156
+ If you have any questions, please contact `sk.lorenmt@gmail.com`.
prismer/dataset/__init__.py CHANGED
@@ -6,12 +6,18 @@
6
 
7
  from torch.utils.data import DataLoader
8
 
 
9
  from dataset.vqa_dataset import VQA
10
  from dataset.caption_dataset import Caption
 
11
 
12
 
13
  def create_dataset(dataset, config):
14
- if dataset == 'vqa':
 
 
 
 
15
  train_dataset = VQA(config, train=True)
16
  test_dataset = VQA(config, train=False)
17
  return train_dataset, test_dataset
@@ -20,6 +26,11 @@ def create_dataset(dataset, config):
20
  train_dataset = Caption(config, train=True)
21
  test_dataset = Caption(config, train=False)
22
  return train_dataset, test_dataset
 
 
 
 
 
23
 
24
 
25
  def create_loader(dataset, batch_size, num_workers, train, collate_fn=None):
 
6
 
7
  from torch.utils.data import DataLoader
8
 
9
+ from dataset.pretrain_dataset import Pretrain
10
  from dataset.vqa_dataset import VQA
11
  from dataset.caption_dataset import Caption
12
+ from dataset.classification_dataset import Classification
13
 
14
 
15
  def create_dataset(dataset, config):
16
+ if dataset == 'pretrain':
17
+ dataset = Pretrain(config)
18
+ return dataset
19
+
20
+ elif dataset == 'vqa':
21
  train_dataset = VQA(config, train=True)
22
  test_dataset = VQA(config, train=False)
23
  return train_dataset, test_dataset
 
26
  train_dataset = Caption(config, train=True)
27
  test_dataset = Caption(config, train=False)
28
  return train_dataset, test_dataset
29
+
30
+ elif dataset == 'classification':
31
+ train_dataset = Classification(config, train=True)
32
+ test_dataset = Classification(config, train=False)
33
+ return train_dataset, test_dataset
34
 
35
 
36
  def create_loader(dataset, batch_size, num_workers, train, collate_fn=None):
prismer/dataset/ade_features.pt CHANGED
Binary files a/prismer/dataset/ade_features.pt and b/prismer/dataset/ade_features.pt differ
 
prismer/dataset/background_features.pt CHANGED
Binary files a/prismer/dataset/background_features.pt and b/prismer/dataset/background_features.pt differ
 
prismer/dataset/caption_dataset.py CHANGED
@@ -50,7 +50,7 @@ class Caption(Dataset):
50
  elif self.dataset == 'demo':
51
  img_path_split = self.data_list[index]['image'].split('/')
52
  img_name = img_path_split[-2] + '/' + img_path_split[-1]
53
- image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
54
 
55
  experts = self.transform(image, labels)
56
  experts = post_label_process(experts, labels_info)
 
50
  elif self.dataset == 'demo':
51
  img_path_split = self.data_list[index]['image'].split('/')
52
  img_name = img_path_split[-2] + '/' + img_path_split[-1]
53
+ image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
54
 
55
  experts = self.transform(image, labels)
56
  experts = post_label_process(experts, labels_info)
prismer/dataset/classification_dataset.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import glob
8
+ from torch.utils.data import Dataset
9
+ from dataset.utils import *
10
+
11
+
12
+ class Classification(Dataset):
13
+ def __init__(self, config, train):
14
+ self.data_path = config['data_path']
15
+ self.label_path = config['label_path']
16
+ self.experts = config['experts']
17
+ self.dataset = config['dataset']
18
+ self.shots = config['shots']
19
+ self.prefix = config['prefix']
20
+
21
+ self.train = train
22
+ self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=True)
23
+
24
+ if train:
25
+ data_folders = glob.glob(f'{self.data_path}/imagenet_train/*/')
26
+ self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.JPEG')[:self.shots]]
27
+ self.answer_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_answer.json'))
28
+ self.class_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_class.json'))
29
+ else:
30
+ data_folders = glob.glob(f'{self.data_path}/imagenet/*/')
31
+ self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.JPEG')]
32
+ self.answer_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_answer.json'))
33
+ self.class_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_class.json'))
34
+
35
+ def __len__(self):
36
+ return len(self.data_list)
37
+
38
+ def __getitem__(self, index):
39
+ img_path = self.data_list[index]['image']
40
+ if self.train:
41
+ img_path_split = img_path.split('/')
42
+ img_name = img_path_split[-2] + '/' + img_path_split[-1]
43
+ class_name = img_path_split[-2]
44
+ image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, img_name, 'imagenet_train', self.experts)
45
+ else:
46
+ img_path_split = img_path.split('/')
47
+ img_name = img_path_split[-2] + '/' + img_path_split[-1]
48
+ class_name = img_path_split[-2]
49
+ image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, img_name, 'imagenet', self.experts)
50
+
51
+ experts = self.transform(image, labels)
52
+ experts = post_label_process(experts, labels_info)
53
+
54
+ if self.train:
55
+ caption = self.prefix + ' ' + self.answer_list[int(self.class_list[class_name])].lower()
56
+ return experts, caption
57
+ else:
58
+ return experts, self.class_list[class_name]
59
+
60
+
61
+
62
+
63
+
64
+ # import os
65
+ # import glob
66
+ #
67
+ # data_path = '/Users/shikunliu/Documents/dataset/mscoco/mscoco'
68
+ #
69
+ # data_folders = glob.glob(f'{data_path}/*/')
70
+ # data_list = [data for f in data_folders for data in glob.glob(f + '*.jpg')]
71
+
72
+
prismer/dataset/clip_pca.pkl CHANGED
Binary files a/prismer/dataset/clip_pca.pkl and b/prismer/dataset/clip_pca.pkl differ
 
prismer/dataset/coco_features.pt CHANGED
Binary files a/prismer/dataset/coco_features.pt and b/prismer/dataset/coco_features.pt differ
 
prismer/dataset/detection_features.pt CHANGED
Binary files a/prismer/dataset/detection_features.pt and b/prismer/dataset/detection_features.pt differ
 
prismer/dataset/pretrain_dataset.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import glob
8
+
9
+ from torch.utils.data import Dataset
10
+ from dataset.utils import *
11
+
12
+
13
+ class Pretrain(Dataset):
14
+ def __init__(self, config):
15
+ self.cc12m_data_path = config['cc12m_data_path']
16
+ self.cc3m_data_path = config['cc3m_data_path']
17
+ self.coco_data_path = config['coco_data_path']
18
+ self.vg_data_path = config['vg_data_path']
19
+ self.label_path = config['label_path']
20
+ self.experts = config['experts']
21
+
22
+ self.data_list = []
23
+ if 'cc12m' in config['datasets']:
24
+ data_folders = glob.glob(f'{self.cc12m_data_path}/cc12m/*/')
25
+ self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
26
+ if 'cc3m_sgu' in config['datasets']:
27
+ data_folders = glob.glob(f'{self.cc3m_data_path}/cc3m_sgu/*/')
28
+ self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
29
+ if 'coco' in config['datasets']:
30
+ self.data_list += json.load(open(os.path.join(self.coco_data_path, 'coco_karpathy_train.json'), 'r'))
31
+ if 'vg' in config['datasets']:
32
+ self.data_list += json.load(open(os.path.join(self.vg_data_path, 'vg_caption.json'), 'r'))
33
+
34
+ self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.5], train=True)
35
+
36
+ def __len__(self):
37
+ return len(self.data_list)
38
+
39
+ def __getitem__(self, index):
40
+ img_path = self.data_list[index]['image']
41
+
42
+ if 'cc12m' in img_path:
43
+ img_path_split = img_path.split('/')
44
+ img_name = img_path_split[-2] + '/' + img_path_split[-1]
45
+ image, labels, labels_info = get_expert_labels(self.cc12m_data_path, self.label_path, img_name, 'cc12m', self.experts)
46
+
47
+ caption_path = img_path.replace('.jpg', '.txt')
48
+ with open(caption_path) as f:
49
+ caption = f.readlines()[0]
50
+
51
+ elif 'cc3m_sgu' in img_path:
52
+ img_path_split = img_path.split('/')
53
+ img_name = img_path_split[-2] + '/' + img_path_split[-1]
54
+ image, labels, labels_info = get_expert_labels(self.cc3m_data_path, self.label_path, img_name, 'cc3m_sgu', self.experts)
55
+
56
+ caption_path = img_path.replace('.jpg', '.txt')
57
+ with open(caption_path) as f:
58
+ caption = f.readlines()[0]
59
+
60
+ elif 'train2014' in img_path or 'val2014' in img_path:
61
+ image, labels, labels_info = get_expert_labels(self.coco_data_path, self.label_path, img_path, 'vqav2', self.experts)
62
+ caption = self.data_list[index]['caption']
63
+
64
+ elif 'visual-genome' in img_path:
65
+ img_path_split = img_path.split('/')
66
+ img_name = img_path_split[-2] + '/' + img_path_split[-1]
67
+ image, labels, labels_info = get_expert_labels(self.vg_data_path, self.label_path, img_name, 'vg', self.experts)
68
+ caption = self.data_list[index]['caption']
69
+
70
+ experts = self.transform(image, labels)
71
+ experts = post_label_process(experts, labels_info)
72
+ caption = pre_caption(caption, max_words=30)
73
+ return experts, caption
prismer/dataset/utils.py CHANGED
@@ -12,16 +12,12 @@ import PIL.Image as Image
12
  import numpy as np
13
  import torchvision.transforms as transforms
14
  import torchvision.transforms.functional as transforms_f
15
- import pathlib
16
  from dataset.randaugment import RandAugment
17
 
18
-
19
- cur_dir = pathlib.Path(__file__).parent
20
-
21
- COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
22
- ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
23
- DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
24
- BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
25
 
26
 
27
  class Transform:
 
12
  import numpy as np
13
  import torchvision.transforms as transforms
14
  import torchvision.transforms.functional as transforms_f
 
15
  from dataset.randaugment import RandAugment
16
 
17
+ COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
18
+ ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
19
+ DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
20
+ BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
 
 
 
21
 
22
 
23
  class Transform:
prismer/dataset/vqa_dataset.py CHANGED
@@ -6,8 +6,6 @@
6
 
7
  from torch.utils.data import Dataset
8
  from dataset.utils import *
9
- from PIL import ImageFile
10
- ImageFile.LOAD_TRUNCATED_IMAGES = True
11
 
12
 
13
  class VQA(Dataset):
 
6
 
7
  from torch.utils.data import Dataset
8
  from dataset.utils import *
 
 
9
 
10
 
11
  class VQA(Dataset):
prismer/experts/generate_depth.py CHANGED
@@ -20,7 +20,7 @@ from tqdm import tqdm
20
  model, transform = load_expert_model(task='depth')
21
  accelerator = Accelerator(mixed_precision='fp16')
22
 
23
- config = yaml.load(open('prismer/configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
  data_path = config['data_path']
25
  save_path = os.path.join(config['save_path'], 'depth')
26
 
 
20
  model, transform = load_expert_model(task='depth')
21
  accelerator = Accelerator(mixed_precision='fp16')
22
 
23
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
  data_path = config['data_path']
25
  save_path = os.path.join(config['save_path'], 'depth')
26
 
prismer/experts/generate_edge.py CHANGED
@@ -22,7 +22,7 @@ from tqdm import tqdm
22
  model, transform = load_expert_model(task='edge')
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
- config = yaml.load(open('prismer/configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
  data_path = config['data_path']
27
  save_path = os.path.join(config['save_path'], 'edge')
28
 
 
22
  model, transform = load_expert_model(task='edge')
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
  data_path = config['data_path']
27
  save_path = os.path.join(config['save_path'], 'edge')
28
 
prismer/experts/generate_normal.py CHANGED
@@ -22,7 +22,7 @@ import numpy as np
22
  model, transform = load_expert_model(task='normal')
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
- config = yaml.load(open('prismer/configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
  data_path = config['data_path']
27
  save_path = os.path.join(config['save_path'], 'normal')
28
 
 
22
  model, transform = load_expert_model(task='normal')
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
  data_path = config['data_path']
27
  save_path = os.path.join(config['save_path'], 'normal')
28
 
prismer/experts/generate_objdet.py CHANGED
@@ -22,7 +22,7 @@ from tqdm import tqdm
22
  model, transform = load_expert_model(task='obj_detection')
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
- config = yaml.load(open('prismer/configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
  data_path = config['data_path']
27
  save_path = config['save_path']
28
 
 
22
  model, transform = load_expert_model(task='obj_detection')
23
  accelerator = Accelerator(mixed_precision='fp16')
24
 
25
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
26
  data_path = config['data_path']
27
  save_path = config['save_path']
28
 
prismer/experts/generate_ocrdet.py CHANGED
@@ -26,7 +26,7 @@ model, transform = load_expert_model(task='ocr_detection')
26
  accelerator = Accelerator(mixed_precision='fp16')
27
  pca_clip = pk.load(open('dataset/clip_pca.pkl', 'rb'))
28
 
29
- config = yaml.load(open('prismer/configs/experts.yaml', 'r'), Loader=yaml.Loader)
30
  data_path = config['data_path']
31
  save_path = os.path.join(config['save_path'], 'ocr_detection')
32
 
 
26
  accelerator = Accelerator(mixed_precision='fp16')
27
  pca_clip = pk.load(open('dataset/clip_pca.pkl', 'rb'))
28
 
29
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
30
  data_path = config['data_path']
31
  save_path = os.path.join(config['save_path'], 'ocr_detection')
32
 
prismer/experts/generate_segmentation.py CHANGED
@@ -20,7 +20,7 @@ from tqdm import tqdm
20
  model, transform = load_expert_model(task='seg_coco')
21
  accelerator = Accelerator(mixed_precision='fp16')
22
 
23
- config = yaml.load(open('prismer/configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
  data_path = config['data_path']
25
  save_path = os.path.join(config['save_path'], 'seg_coco')
26
 
 
20
  model, transform = load_expert_model(task='seg_coco')
21
  accelerator = Accelerator(mixed_precision='fp16')
22
 
23
+ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
24
  data_path = config['data_path']
25
  save_path = os.path.join(config['save_path'], 'seg_coco')
26
 
prismer/{images → helpers/images}/COCO_test2015_000000000014.jpg RENAMED
File without changes
prismer/{images → helpers/images}/COCO_test2015_000000000016.jpg RENAMED
File without changes
prismer/{images → helpers/images}/COCO_test2015_000000000019.jpg RENAMED
File without changes
prismer/{images → helpers/images}/COCO_test2015_000000000128.jpg RENAMED
File without changes
prismer/{images → helpers/images}/COCO_test2015_000000000155.jpg RENAMED
File without changes
prismer/helpers/intro.png ADDED
prismer/model/prismer.py CHANGED
@@ -5,15 +5,12 @@
5
  # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
 
7
  import json
8
- import pathlib
9
  import torch.nn as nn
10
 
11
  from model.modules.vit import load_encoder
12
  from model.modules.roberta import load_decoder
13
  from transformers import RobertaTokenizer, RobertaConfig
14
 
15
- cur_dir = pathlib.Path(__file__).parent
16
-
17
 
18
  class Prismer(nn.Module):
19
  def __init__(self, config):
@@ -29,7 +26,7 @@ class Prismer(nn.Module):
29
  elif exp in ['obj_detection', 'ocr_detection']:
30
  self.experts[exp] = 64
31
 
32
- prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
33
  roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
34
 
35
  self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
 
5
  # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
 
7
  import json
 
8
  import torch.nn as nn
9
 
10
  from model.modules.vit import load_encoder
11
  from model.modules.roberta import load_decoder
12
  from transformers import RobertaTokenizer, RobertaConfig
13
 
 
 
14
 
15
  class Prismer(nn.Module):
16
  def __init__(self, config):
 
26
  elif exp in ['obj_detection', 'ocr_detection']:
27
  self.experts[exp] = 64
28
 
29
+ prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
30
  roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
31
 
32
  self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
prismer/requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/openai/CLIP.git
2
+ git+https://github.com/facebookresearch/detectron2.git@5aeb252b194b93dc2879b4ac34bc51a31b5aee13
3
+ accelerate
4
+ fairscale
5
+ timm
6
+ transformers
7
+ einops
8
+ scikit-learn==0.24.2
9
+ pycocoevalcap
10
+ editdistance
11
+ shapely
12
+ pyclipper
13
+ yacs
14
+ pycocotools
15
+ geffnet
16
+ fire
17
+ huggingface_hub
18
+ rich
19
+ ruamel.yaml
prismer/train_caption.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import argparse
8
+ import numpy as np
9
+ import random
10
+ import time
11
+ import functools
12
+ import json
13
+ import torch
14
+ import os
15
+ try:
16
+ import ruamel_yaml as yaml
17
+ except ModuleNotFoundError:
18
+ import ruamel.yaml as yaml
19
+
20
+ from accelerate import Accelerator, FullyShardedDataParallelPlugin
21
+ from model.prismer_caption import PrismerCaption
22
+ from model.modules.utils import interpolate_pos_embed
23
+ from dataset import create_dataset, create_loader
24
+ from utils import *
25
+ from tqdm import tqdm
26
+
27
+
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument('--mode', default='')
30
+ parser.add_argument('--port', default='')
31
+
32
+ parser.add_argument('--config', default='configs/caption.yaml')
33
+ parser.add_argument('--from_checkpoint', action='store_true')
34
+ parser.add_argument('--evaluate', action='store_true')
35
+ parser.add_argument('--target_dataset', default='coco', type=str)
36
+ parser.add_argument('--shard_grad_op', action='store_true')
37
+ parser.add_argument('--full_shard', action='store_true')
38
+ parser.add_argument('--exp_name', default='', type=str)
39
+ parser.add_argument('--mixed_precision', default='fp16', type=str)
40
+ parser.add_argument('--seed', default=42, type=int)
41
+ args = parser.parse_args()
42
+
43
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)[args.target_dataset]
44
+ torch.manual_seed(args.seed)
45
+ np.random.seed(args.seed)
46
+ random.seed(args.seed)
47
+
48
+ train_dataset, test_dataset = create_dataset('caption', config)
49
+ train_loader = create_loader(train_dataset, batch_size=config['batch_size_train'], num_workers=8, train=True)
50
+ test_loader = create_loader(test_dataset, batch_size=config['batch_size_test'], num_workers=8, train=False)
51
+
52
+
53
+ model = PrismerCaption(config)
54
+ tokenizer = model.tokenizer
55
+
56
+ if args.shard_grad_op: # Model Sharding: ZeRO 2
57
+ from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType
58
+ fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
59
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
60
+ mixed_precision_policy=MixedPrecision(param_dtype=torch.float16,
61
+ reduce_dtype=torch.float16,
62
+ buffer_dtype=torch.float16),
63
+ state_dict_type=StateDictType.FULL_STATE_DICT,
64
+ ignored_modules=model.ignored_modules)
65
+ accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
66
+ model = accelerator.prepare(model)
67
+
68
+ elif args.full_shard: # Model Sharding: ZeRO 3
69
+ from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType
70
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
71
+ from model.modules.vit import ResidualAttentionBlock
72
+ from model.modules.resampler import PerceiverAttentionBlock
73
+ from model.modules.roberta import RobertaLayer
74
+ auto_wrap_policy = functools.partial(
75
+ transformer_auto_wrap_policy,
76
+ transformer_layer_cls={
77
+ ResidualAttentionBlock,
78
+ PerceiverAttentionBlock,
79
+ RobertaLayer
80
+ },
81
+ )
82
+ fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD,
83
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
84
+ mixed_precision_policy=MixedPrecision(param_dtype=torch.float16,
85
+ reduce_dtype=torch.float16,
86
+ buffer_dtype=torch.float16),
87
+ state_dict_type=StateDictType.FULL_STATE_DICT,
88
+ auto_wrap_policy=auto_wrap_policy,
89
+ ignored_modules=model.ignored_modules)
90
+ accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
91
+ model = accelerator.prepare(model)
92
+ else:
93
+ accelerator = Accelerator(mixed_precision=args.mixed_precision)
94
+
95
+ # Reload saved states
96
+ if not args.from_checkpoint:
97
+ state_dict = torch.load(f'logging/pretrain_{args.exp_name}/pytorch_model.bin', map_location='cpu')
98
+ state_dict['expert_encoder.positional_embedding'] = interpolate_pos_embed(state_dict['expert_encoder.positional_embedding'],
99
+ len(model.expert_encoder.positional_embedding))
100
+ model.load_state_dict(state_dict)
101
+ start_epoch = 0
102
+ else:
103
+ state_dict = torch.load(f'logging/caption_{args.exp_name}/pytorch_model.bin', map_location='cpu')
104
+ if os.path.exists(f'logging/caption_{args.exp_name}/epoch.pt'):
105
+ start_epoch = torch.load(f'logging/caption_{args.exp_name}/epoch.pt')[0] + 1
106
+ else:
107
+ start_epoch = 0
108
+ model.load_state_dict(state_dict)
109
+ accelerator.print(f'Start re-training from checkpoint with Epoch {start_epoch}')
110
+
111
+ optimizer = torch.optim.AdamW(params=filter(lambda p: p.requires_grad, model.parameters()),
112
+ lr=config['init_lr'], weight_decay=config['weight_decay'])
113
+
114
+ if args.shard_grad_op or args.full_shard:
115
+ optimizer, train_loader, test_loader = accelerator.prepare(optimizer, train_loader, test_loader)
116
+ else:
117
+ model, optimizer, train_loader, test_loader = accelerator.prepare(model, optimizer, train_loader, test_loader)
118
+
119
+ best = 0
120
+ start_time = time.time()
121
+ if not args.evaluate:
122
+ for epoch in range(start_epoch, config['max_epoch']):
123
+ train_loss = 0
124
+ num_train_elems = 0
125
+ model.train()
126
+ for i, (experts, caption) in enumerate(tqdm(train_loader)):
127
+ cosine_lr_schedule(optimizer, epoch * len(train_loader) + i, config['max_epoch'] * len(train_loader), config['init_lr'], config['min_lr'])
128
+
129
+ loss = model(experts, caption, prefix=config['prefix'])
130
+
131
+ optimizer.zero_grad()
132
+ accelerator.backward(loss)
133
+ optimizer.step()
134
+
135
+ train_loss += loss.item()
136
+ num_train_elems += 1
137
+
138
+ model.eval()
139
+ result = []
140
+ with torch.no_grad():
141
+ for step, (experts, data_ids) in enumerate(tqdm(test_loader)):
142
+ captions = model(experts, train=False, prefix=config['prefix'])
143
+
144
+ if accelerator.use_distributed:
145
+ captions = tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
146
+ captions = captions.to(experts['rgb'].device)
147
+ data_ids, captions = accelerator.gather_for_metrics((data_ids, captions))
148
+
149
+ for data_id, caption in zip(data_ids, captions):
150
+ caption = tokenizer.decode(caption, skip_special_tokens=True)
151
+ if args.target_dataset == 'coco':
152
+ image_id = int(test_loader.dataset.data_list[data_id]['image'].split('/')[-1].strip('.jpg').split('_')[-1])
153
+ result.append({"image_id": image_id, "caption": caption.capitalize() + '.'})
154
+ elif args.target_dataset == 'nocaps':
155
+ result.append({"image_id": test_loader.dataset.data_list[data_id]['img_id'],
156
+ "caption": caption.capitalize() + '.'})
157
+
158
+ accelerator.wait_for_everyone()
159
+ if accelerator.is_main_process:
160
+ json.dump(result, open(f'/results/caption_results_{args.exp_name}_{args.target_dataset}.json', 'w'))
161
+ if args.target_dataset == 'coco':
162
+ coco_eval = coco_caption_eval(f'{config["data_path"]}/coco_karpathy_test_gt.json', result)
163
+ torch.save([coco_eval.eval['CIDEr']], f'logging/caption_{args.exp_name}/temp_cider.pt')
164
+ if not os.path.isfile(f'logging/caption_{args.exp_name}/cider.pt'):
165
+ torch.save([coco_eval.eval['CIDEr']], f'logging/caption_{args.exp_name}/cider.pt')
166
+
167
+ accelerator.wait_for_everyone()
168
+ cider = torch.load(f'logging/caption_{args.exp_name}/cider.pt')[0]
169
+ curr_cider = torch.load(f'logging/caption_{args.exp_name}/temp_cider.pt')[0]
170
+
171
+ if cider < curr_cider:
172
+ train_loss /= num_train_elems
173
+ accelerator.print(f"Epoch {epoch:03d} | loss: {train_loss:.4f} || Time: {(time.time() - start_time):.4f}")
174
+ accelerator.save_state(f'logging/caption_{args.exp_name}')
175
+ accelerator.save([epoch], f'logging/caption_{args.exp_name}/epoch.pt')
176
+ accelerator.save([curr_cider], f'logging/caption_{args.exp_name}/cider.pt')
177
+
178
+
179
+ model.eval()
180
+ if accelerator.is_main_process:
181
+ result = []
182
+
183
+ with torch.no_grad():
184
+ for step, (experts, data_ids) in enumerate(tqdm(test_loader)):
185
+ captions = model(experts, train=False, prefix=config['prefix'])
186
+
187
+ if accelerator.use_distributed:
188
+ captions = tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
189
+ captions = captions.to(experts['rgb'].device)
190
+ data_ids, captions = accelerator.gather_for_metrics((data_ids, captions))
191
+
192
+ if accelerator.is_main_process:
193
+ for data_id, caption in zip(data_ids, captions):
194
+ caption = tokenizer.decode(caption, skip_special_tokens=True)
195
+ if args.target_dataset == 'coco':
196
+ image_id = int(test_loader.dataset.data_list[data_id]['image'].split('/')[-1].strip('.jpg').split('_')[-1])
197
+ result.append({"image_id": image_id, "caption": caption.capitalize() + '.'})
198
+ elif args.target_dataset == 'nocaps':
199
+ result.append({"image_id": test_loader.dataset.data_list[data_id]['img_id'],
200
+ "caption": caption.capitalize() + '.'})
201
+
202
+ accelerator.wait_for_everyone()
203
+ if accelerator.is_main_process:
204
+ json.dump(result, open(f'/results/caption_results_{args.exp_name}_{args.target_dataset}.json', 'w'))
205
+ if args.target_dataset == 'coco':
206
+ coco_caption_eval(f'{config["data_path"]}/coco_karpathy_test_gt.json', result)
207
+
208
+
prismer/train_classification.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, visit
5
+ # https://github.com/NVlabs/prismer/blob/main/LICENSE
6
+
7
+ import argparse
8
+ import numpy as np
9
+ import random
10
+ import time
11
+ import functools
12
+ import torch
13
+ try:
14
+ import ruamel_yaml as yaml
15
+ except ModuleNotFoundError:
16
+ import ruamel.yaml as yaml
17
+
18
+ from accelerate import Accelerator, FullyShardedDataParallelPlugin
19
+ from model.prismer_caption import PrismerCaption
20
+ from model.modules.utils import interpolate_pos_embed
21
+ from dataset import create_dataset, create_loader
22
+ from tqdm import tqdm
23
+ from utils import *
24
+
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument('--mode', default='')
27
+ parser.add_argument('--port', default='')
28
+
29
+ parser.add_argument('--config', default='configs/classification.yaml')
30
+ parser.add_argument('--from_checkpoint', action='store_true')
31
+ parser.add_argument('--evaluate', action='store_true')
32
+ parser.add_argument('--exp_name', default='', type=str)
33
+ parser.add_argument('--shard_grad_op', action='store_true')
34
+ parser.add_argument('--full_shard', action='store_true')
35
+ parser.add_argument('--mixed_precision', default='fp16', type=str)
36
+ parser.add_argument('--seed', default=42, type=int)
37
+ args = parser.parse_args()
38
+
39
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
40
+ torch.manual_seed(args.seed)
41
+ np.random.seed(args.seed)
42
+ random.seed(args.seed)
43
+
44
+ train_dataset, test_dataset = create_dataset('classification', config)
45
+ train_loader = create_loader(train_dataset, batch_size=config['batch_size_train'], num_workers=8, train=True)
46
+ test_loader = create_loader(test_dataset, batch_size=config['batch_size_test'], num_workers=8, train=False)
47
+ model = PrismerCaption(config)
48
+
49
+ if args.shard_grad_op: # Model Sharding: ZeRO 2
50
+ from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType
51
+ fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
52
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
53
+ mixed_precision_policy=MixedPrecision(param_dtype=torch.float16,
54
+ reduce_dtype=torch.float16,
55
+ buffer_dtype=torch.float16),
56
+ state_dict_type=StateDictType.FULL_STATE_DICT,
57
+ ignored_modules=model.ignored_modules)
58
+ accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
59
+ model = accelerator.prepare(model)
60
+
61
+ elif args.full_shard: # Model Sharding: ZeRO 3
62
+ from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType
63
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
64
+ from model.modules.vit import ResidualAttentionBlock
65
+ from model.modules.resampler import PerceiverAttentionBlock
66
+ from model.modules.roberta import RobertaLayer
67
+ auto_wrap_policy = functools.partial(
68
+ transformer_auto_wrap_policy,
69
+ transformer_layer_cls={
70
+ ResidualAttentionBlock,
71
+ PerceiverAttentionBlock,
72
+ RobertaLayer
73
+ },
74
+ )
75
+ fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD,
76
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
77
+ mixed_precision_policy=MixedPrecision(param_dtype=torch.float16,
78
+ reduce_dtype=torch.float16,
79
+ buffer_dtype=torch.float16),
80
+ state_dict_type=StateDictType.FULL_STATE_DICT,
81
+ auto_wrap_policy=auto_wrap_policy,
82
+ ignored_modules=model.ignored_modules)
83
+ accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
84
+ model = accelerator.prepare(model)
85
+ else:
86
+ accelerator = Accelerator(mixed_precision=args.mixed_precision)
87
+
88
+ # Reload saved states
89
+ if not args.from_checkpoint:
90
+ state_dict = torch.load(f'logging/pretrain_{args.exp_name}/pytorch_model.bin', map_location='cpu')
91
+ state_dict['expert_encoder.positional_embedding'] = interpolate_pos_embed(state_dict['expert_encoder.positional_embedding'],
92
+ len(model.expert_encoder.positional_embedding))
93
+ model.load_state_dict(state_dict)
94
+ start_epoch = 0
95
+ else:
96
+ state_dict = torch.load(f'logging/classification_{args.exp_name}/pytorch_model.bin', map_location='cpu')
97
+ if os.path.exists(f'logging/classification_{args.exp_name}/epoch.pt'):
98
+ start_epoch = torch.load(f'logging/classification_{args.exp_name}/epoch.pt')[0] + 1
99
+ else:
100
+ start_epoch = 0
101
+ model.load_state_dict(state_dict)
102
+ accelerator.print(f'Start re-training from checkpoint with Epoch {start_epoch}')
103
+
104
+ optimizer = torch.optim.AdamW(params=filter(lambda p: p.requires_grad, model.parameters()),
105
+ lr=config['init_lr'], weight_decay=config['weight_decay'])
106
+
107
+ if args.shard_grad_op or args.full_shard:
108
+ optimizer, train_loader, test_loader = accelerator.prepare(optimizer, train_loader, test_loader)
109
+ else:
110
+ model, optimizer, train_loader, test_loader = accelerator.prepare(model, optimizer, train_loader, test_loader)
111
+
112
+ start_time = time.time()
113
+ best = 0
114
+ for epoch in range(start_epoch, config['max_epoch']):
115
+ train_loss = 0
116
+ num_train_elems = 0
117
+ model.train()
118
+ for i, (experts, caption) in enumerate(tqdm(train_loader)):
119
+ cosine_lr_schedule(optimizer, epoch * len(train_loader) + i, config['max_epoch'] * len(train_loader), config['init_lr'], config['min_lr'])
120
+ loss = model(experts, caption, prefix=config['prefix'])
121
+
122
+ optimizer.zero_grad()
123
+ accelerator.backward(loss)
124
+ optimizer.step()
125
+
126
+ train_loss += loss.item()
127
+ num_train_elems += 1
128
+
129
+ train_loss /= num_train_elems
130
+ accelerator.print(f"Epoch {epoch:03d} | loss: {train_loss:.4f} || Time: {(time.time() - start_time):.4f}")
131
+
132
+ if (epoch + 1) % 5 == 0:
133
+ model.eval()
134
+ num_test_elems = 0
135
+ accurate = 0
136
+ with torch.no_grad():
137
+ answer_list = test_loader.dataset.answer_list
138
+ for step, (experts, gt) in enumerate(tqdm(test_loader)):
139
+ predictions = model(experts, answer=answer_list, train=False, prefix=config['prefix'], k_test=config['k_test'], inference='rank')
140
+
141
+ if accelerator.use_distributed:
142
+ predictions, gt = accelerator.gather_for_metrics((predictions, gt))
143
+
144
+ accurate_preds = predictions == gt
145
+ num_test_elems += accurate_preds.shape[0]
146
+ accurate += accurate_preds.long().sum()
147
+ eval_metric = accurate.item() / num_test_elems
148
+
149
+ accelerator.wait_for_everyone()
150
+ accelerator.print(f'{config["shots"]}-Shot Acc: {eval_metric}')
151
+
152
+ if eval_metric > best:
153
+ best = eval_metric
154
+ accelerator.save_state(f'logging/classification_{args.exp_name}')
155
+ accelerator.save([epoch], f'logging/classification_{args.exp_name}/epoch.pt')
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+
prismer/train_pretrain.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

"""Pre-train PrismerCaption with Hugging Face Accelerate.

Flow: parse the command-line flags, load the YAML config, build the
'pretrain' dataset and loader, construct the model (optionally wrapped for
FSDP sharding), optionally reload weights from ``logging/pretrain_<exp_name>``,
then train for ``config['max_epoch']`` epochs and checkpoint after each one.
"""

import argparse
import numpy as np
import os
import random
import time
import datetime
import functools
import torch
try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from model.prismer_caption import PrismerCaption
from dataset import create_dataset, create_loader
from utils import *
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='')
parser.add_argument('--port', default='')

parser.add_argument('--config', default='configs/pretrain.yaml')
parser.add_argument('--from_checkpoint', action='store_true')
parser.add_argument('--shard_grad_op', action='store_true')
parser.add_argument('--full_shard', action='store_true')
parser.add_argument('--exp_name', default='', type=str)
parser.add_argument('--mixed_precision', default='fp16', type=str)
parser.add_argument('--seed', default=42, type=int)
args = parser.parse_args()

# NOTE(review): yaml.Loader can construct arbitrary Python objects, so
# --config must only ever point at a trusted file.
with open(args.config, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

# Seed every RNG the script imports so runs are repeatable.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

train_dataset = create_dataset('pretrain', config)
train_loader = create_loader(train_dataset, batch_size=config['batch_size_train'], num_workers=8, train=True)

# All checkpoints for this experiment are read from / written to this folder.
checkpoint_dir = f'logging/pretrain_{args.exp_name}'

model = PrismerCaption(config)
if args.shard_grad_op:  # Model Sharding: ZeRO 2
    from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType, CPUOffload
    fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
                                                 backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                                                 mixed_precision_policy=MixedPrecision(param_dtype=torch.float16,
                                                                                       reduce_dtype=torch.float16,
                                                                                       buffer_dtype=torch.float16),
                                                 state_dict_type=StateDictType.FULL_STATE_DICT,
                                                 cpu_offload=CPUOffload(offload_params=False),
                                                 ignored_modules=model.ignored_modules)
    accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
    # In the sharded paths the model is prepared on its own, before the
    # optimizer is created from its parameters further down.
    model = accelerator.prepare(model)

elif args.full_shard:  # Model Sharding: ZeRO 3
    from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from model.modules.vit import ResidualAttentionBlock
    from model.modules.resampler import PerceiverAttentionBlock
    from model.modules.roberta import RobertaLayer
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            ResidualAttentionBlock,
            PerceiverAttentionBlock,
            RobertaLayer
        },
    )
    fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD,
                                                 backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                                                 mixed_precision_policy=MixedPrecision(param_dtype=torch.float16,
                                                                                       reduce_dtype=torch.float16,
                                                                                       buffer_dtype=torch.float16),
                                                 state_dict_type=StateDictType.FULL_STATE_DICT,
                                                 auto_wrap_policy=auto_wrap_policy,
                                                 ignored_modules=model.ignored_modules)
    accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
    model = accelerator.prepare(model)
else:
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

# Reload saved states
if args.from_checkpoint:
    state_dict = torch.load(f'{checkpoint_dir}/pytorch_model.bin', map_location='cpu')
    epoch_file = f'{checkpoint_dir}/epoch.pt'
    # epoch.pt stores [last_finished_epoch]; resume with the following one.
    if os.path.exists(epoch_file):
        start_epoch = torch.load(epoch_file)[0] + 1
    else:
        start_epoch = 0
    model.load_state_dict(state_dict)
    accelerator.print(f'Start re-training from checkpoint with Epoch {start_epoch}')
else:
    start_epoch = 0

optimizer = torch.optim.AdamW(params=filter(lambda p: p.requires_grad, model.parameters()),
                              lr=config['init_lr'], weight_decay=config['weight_decay'])

if args.shard_grad_op or args.full_shard:
    optimizer, train_loader = accelerator.prepare(optimizer, train_loader)
else:
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)


start_time = time.time()
# Count the steps already taken before a resume so that the learning-rate
# warm-up is not replayed in the middle of training.
warmup_step = start_epoch * len(train_loader)
for epoch in range(start_epoch, config['max_epoch']):
    cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

    train_loss = 0
    num_train_elems = 0
    model.train()
    for experts, caption in tqdm(train_loader):
        # While warming up, the per-step schedule overrides the per-epoch
        # cosine value set above.
        if warmup_step < config['warmup_steps']:
            warmup_lr_schedule(optimizer, warmup_step, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
            warmup_step += 1

        loss = model(experts, caption)

        optimizer.zero_grad()
        accelerator.backward(loss)
        optimizer.step()

        train_loss += loss.item()
        num_train_elems += 1

    # Guard the average against an epoch that produced no batches.
    train_loss /= max(num_train_elems, 1)
    accelerator.print(f"Epoch {epoch:03d} | loss: {train_loss:.4f} || Time: {(time.time() - start_time):.4f}")
    accelerator.save_state(checkpoint_dir)
    accelerator.save([epoch], f'{checkpoint_dir}/epoch.pt')

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
accelerator.print('Training time {}'.format(total_time_str))
prismer/train_vqa.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

"""Fine-tune and/or evaluate PrismerVQA with Hugging Face Accelerate.

Flow: parse the command-line flags, load the YAML config, build the 'vqa'
train/test datasets and loaders, construct the model (optionally wrapped for
FSDP sharding), load either the pre-trained weights or a saved VQA checkpoint,
train unless ``--evaluate`` is given, then run inference over the test loader
and dump the collected answers to a JSON file from the main process.
"""

import argparse
import numpy as np
import os
import random
import time
import datetime
import functools
import torch
try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from model.prismer_vqa import PrismerVQA
from model.modules.utils import interpolate_pos_embed
from dataset import create_dataset, create_loader
from utils import *
from tqdm import tqdm
import json

parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='')
parser.add_argument('--port', default='')

parser.add_argument('--config', default='configs/vqa.yaml')
parser.add_argument('--from_checkpoint', action='store_true')
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--exp_name', default='', type=str)
parser.add_argument('--shard_grad_op', action='store_true')
parser.add_argument('--full_shard', action='store_true')
parser.add_argument('--mixed_precision', default='fp16', type=str)
parser.add_argument('--seed', default=42, type=int)
args = parser.parse_args()

# NOTE(review): yaml.Loader can construct arbitrary Python objects, so
# --config must only ever point at a trusted file.
with open(args.config, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

# Seed every RNG the script imports so runs are repeatable.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

train_dataset, test_dataset = create_dataset('vqa', config)
train_loader = create_loader(train_dataset, batch_size=config['batch_size_train'], num_workers=8, train=True)
test_loader = create_loader(test_dataset, batch_size=config['batch_size_test'], num_workers=8, train=False)

# All VQA checkpoints for this experiment are read from / written to this folder.
checkpoint_dir = f'logging/vqa_{args.exp_name}'

model = PrismerVQA(config)
tokenizer = model.tokenizer

if args.shard_grad_op:  # Model Sharding: ZeRO 2
    from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType
    fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
                                                 backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                                                 mixed_precision_policy=MixedPrecision(param_dtype=torch.float16,
                                                                                       reduce_dtype=torch.float16,
                                                                                       buffer_dtype=torch.float16),
                                                 state_dict_type=StateDictType.FULL_STATE_DICT,
                                                 ignored_modules=model.ignored_modules)
    accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
    # In the sharded paths the model is prepared on its own, before the
    # optimizer is created from its parameters further down.
    model = accelerator.prepare(model)

elif args.full_shard:  # Model Sharding: ZeRO 3
    from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy, StateDictType
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from model.modules.vit import ResidualAttentionBlock
    from model.modules.resampler import PerceiverAttentionBlock
    from model.modules.roberta import RobertaLayer
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            ResidualAttentionBlock,
            PerceiverAttentionBlock,
            RobertaLayer
        },
    )
    fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD,
                                                 backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                                                 mixed_precision_policy=MixedPrecision(param_dtype=torch.float16,
                                                                                       reduce_dtype=torch.float16,
                                                                                       buffer_dtype=torch.float16),
                                                 state_dict_type=StateDictType.FULL_STATE_DICT,
                                                 auto_wrap_policy=auto_wrap_policy,
                                                 ignored_modules=model.ignored_modules)
    accelerator = Accelerator(mixed_precision=args.mixed_precision, fsdp_plugin=fsdp_plugin)
    model = accelerator.prepare(model)
else:
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

# Reload saved states
if not args.from_checkpoint:
    # Fresh fine-tuning run: start from the pre-training checkpoint and resize
    # its expert positional embedding to the length this model expects.
    state_dict = torch.load(f'logging/pretrain_{args.exp_name}/pytorch_model.bin', map_location='cpu')
    state_dict['expert_encoder.positional_embedding'] = interpolate_pos_embed(state_dict['expert_encoder.positional_embedding'],
                                                                              len(model.expert_encoder.positional_embedding))
    model.load_state_dict(state_dict)
    start_epoch = 0
else:
    state_dict = torch.load(f'{checkpoint_dir}/pytorch_model.bin', map_location='cpu')
    epoch_file = f'{checkpoint_dir}/epoch.pt'
    # epoch.pt stores [last_finished_epoch]; resume with the following one.
    if os.path.exists(epoch_file):
        start_epoch = torch.load(epoch_file)[0] + 1
    else:
        start_epoch = 0
    model.load_state_dict(state_dict)
    accelerator.print(f'Start re-training from checkpoint with Epoch {start_epoch}')

optimizer = torch.optim.AdamW(params=filter(lambda p: p.requires_grad, model.parameters()),
                              lr=config['init_lr'], weight_decay=config['weight_decay'])

if args.shard_grad_op or args.full_shard:
    optimizer, train_loader, test_loader = accelerator.prepare(optimizer, train_loader, test_loader)
else:
    model, optimizer, train_loader, test_loader = accelerator.prepare(model, optimizer, train_loader, test_loader)

start_time = time.time()
if not args.evaluate:
    for epoch in range(start_epoch, config['max_epoch']):
        cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

        train_loss = 0
        num_train_elems = 0
        model.train()
        for experts, question, answer, weights in tqdm(train_loader):
            loss = model(experts, question, answer, train=True, weights=weights)
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

            train_loss += loss.item()
            num_train_elems += 1

        # Guard the average against an epoch that produced no batches.
        train_loss /= max(num_train_elems, 1)
        accelerator.print(f"Epoch {epoch:03d} | loss: {train_loss:.4f} || Time: {(time.time() - start_time):.4f}")
        accelerator.save_state(checkpoint_dir)
        accelerator.save([epoch], f'{checkpoint_dir}/epoch.pt')

model.eval()
# Answers are only collected (and later written) on the main process.
if accelerator.is_main_process:
    result = []

with torch.no_grad():
    if config['inference'] == 'rank':
        answer_list = test_loader.dataset.answer_list

    for experts, data_ids, question, question_id in tqdm(test_loader):
        if config['inference'] == 'generate':
            answers = model(experts, question, train=False, inference='generate')

            # Tokenize unconditionally so that tokenizer.decode below always
            # receives token ids, whether or not a gather took place.
            answers = tokenizer(answers, max_length=15, padding='max_length', return_tensors='pt').input_ids
            answers = answers.to(experts['rgb'].device)
            if accelerator.use_distributed:
                data_ids, answers, question_id = accelerator.gather_for_metrics((data_ids, answers, question_id))

            if accelerator.is_main_process:
                for answer, ques_id in zip(answers, question_id):
                    answer = tokenizer.decode(answer, skip_special_tokens=True)
                    result.append({"question_id": int(ques_id.item()), "answer": answer})

        elif config['inference'] == 'rank':
            answer_ids = model(experts, question, answer_list, train=False, inference='rank', k_test=config['k_test'])

            if accelerator.use_distributed:
                answer_ids, question_id = accelerator.gather_for_metrics((answer_ids, question_id))

            if accelerator.is_main_process:
                for ques_id, answer_id in zip(question_id, answer_ids):
                    result.append({"question_id": int(ques_id.item()), "answer": answer_list[answer_id]})


accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # NOTE(review): the output path is absolute ('/results/...'); confirm that
    # this directory exists in the deployment environment.
    with open(f'/results/vqa_results_{args.exp_name}.json', 'w') as results_file:
        json.dump(result, results_file)


total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
accelerator.print('Training time {}'.format(total_time_str))
prismer_model.py CHANGED
@@ -20,22 +20,32 @@ from model.prismer_caption import PrismerCaption
20
 
21
  def download_models() -> None:
22
  if not pathlib.Path('prismer/experts/expert_weights/').exists():
23
- subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer')
24
-
 
25
  model_names = [
26
- # 'vqa_prismer_base',
27
- # 'vqa_prismer_large',
 
 
 
 
28
  'caption_prismer_base',
29
  'caption_prismer_large',
30
  ]
31
  for model_name in model_names:
32
  if pathlib.Path(f'prismer/logging/{model_name}').exists():
33
  continue
34
- subprocess.run(shlex.split(f'python download_checkpoints.py --download_models={model_name}'), cwd='prismer')
 
 
35
 
36
 
37
  def build_deformable_conv() -> None:
38
- subprocess.run(shlex.split('sh make.sh'), cwd='prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')
 
 
 
39
 
40
 
41
  def run_experts(image_path: str) -> tuple[str | None, ...]:
@@ -46,18 +56,40 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
46
  out_path = image_dir / 'image.jpg'
47
  cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
48
 
49
- expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
 
 
 
 
 
 
 
50
  for expert_name in expert_names:
51
  env = os.environ.copy()
52
  if 'PYTHONPATH' in env:
53
  env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
54
  else:
55
  env['PYTHONPATH'] = submodule_dir.as_posix()
56
- subprocess.run(shlex.split(f'python experts/generate_{expert_name}.py'), cwd='prismer', env=env, check=True)
57
-
58
- keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
59
- results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
60
- return tuple(path.as_posix() if path.exists() else None for path in results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  class Model:
@@ -71,14 +103,28 @@ class Model:
71
  if exp_name == self.exp_name:
72
  return
73
  config = {
74
- 'dataset': 'demo',
75
- 'data_path': 'prismer/helpers',
76
- 'label_path': 'prismer/helpers/labels',
77
- 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
78
- 'image_resolution': 480,
79
- 'prismer_model': 'prismer_base',
80
- 'freeze': 'freeze_vision',
81
- 'prefix': 'A picture of',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  }
83
  model = PrismerCaption(config)
84
  state_dict = torch.load(
@@ -96,17 +142,27 @@ class Model:
96
  @torch.inference_mode()
97
  def run_caption_model(self, exp_name: str) -> str:
98
  self.set_model(exp_name)
 
99
  _, test_dataset = create_dataset('caption', self.config)
100
- test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
 
 
 
101
  experts, _ = next(iter(test_loader))
102
- captions = self.model(experts, train=False, prefix=self.config['prefix'])
103
- captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
 
 
 
 
 
104
  caption = captions.to(experts['rgb'].device)[0]
105
  caption = self.tokenizer.decode(caption, skip_special_tokens=True)
106
  caption = caption.capitalize() + '.'
107
  return caption
108
 
109
- def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
 
110
  out_paths = run_experts(image_path)
111
- # caption = self.run_caption_model(model_name)
112
- return None, *out_paths
 
20
 
21
  def download_models() -> None:
22
  if not pathlib.Path('prismer/experts/expert_weights/').exists():
23
+ subprocess.run(shlex.split(
24
+ 'python download_checkpoints.py --download_experts=True'),
25
+ cwd='prismer')
26
  model_names = [
27
+ 'vqa_prismer_base',
28
+ 'vqa_prismer_large',
29
+ 'vqa_prismerz_base',
30
+ 'vqa_prismerz_large',
31
+ 'caption_prismerz_base',
32
+ 'caption_prismerz_large',
33
  'caption_prismer_base',
34
  'caption_prismer_large',
35
  ]
36
  for model_name in model_names:
37
  if pathlib.Path(f'prismer/logging/{model_name}').exists():
38
  continue
39
+ subprocess.run(shlex.split(
40
+ f'python download_checkpoints.py --download_models={model_name}'),
41
+ cwd='prismer')
42
 
43
 
44
  def build_deformable_conv() -> None:
45
+ subprocess.run(
46
+ shlex.split('sh make.sh'),
47
+ cwd=
48
+ 'prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')
49
 
50
 
51
  def run_experts(image_path: str) -> tuple[str | None, ...]:
 
56
  out_path = image_dir / 'image.jpg'
57
  cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
58
 
59
+ expert_names = [
60
+ 'depth',
61
+ 'edge',
62
+ 'normal',
63
+ 'objdet',
64
+ 'ocrdet',
65
+ 'segmentation',
66
+ ]
67
  for expert_name in expert_names:
68
  env = os.environ.copy()
69
  if 'PYTHONPATH' in env:
70
  env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
71
  else:
72
  env['PYTHONPATH'] = submodule_dir.as_posix()
73
+ subprocess.run(
74
+ shlex.split(f'python experts/generate_{expert_name}.py'),
75
+ cwd='prismer',
76
+ env=env,
77
+ check=True)
78
+
79
+ keys = [
80
+ 'depth',
81
+ 'edge',
82
+ 'normal',
83
+ 'seg_coco',
84
+ 'obj_detection',
85
+ 'ocr_detection',
86
+ ]
87
+ results = [
88
+ pathlib.Path('prismer/helpers/labels') / key /
89
+ 'helpers/images/image.png' for key in keys
90
+ ]
91
+ return tuple(path.as_posix() if path.exists() else None
92
+ for path in results)
93
 
94
 
95
  class Model:
 
103
  if exp_name == self.exp_name:
104
  return
105
  config = {
106
+ 'dataset':
107
+ 'demo',
108
+ 'data_path':
109
+ 'prismer/helpers',
110
+ 'label_path':
111
+ 'prismer/helpers/labels',
112
+ 'experts': [
113
+ 'depth',
114
+ 'normal',
115
+ 'seg_coco',
116
+ 'edge',
117
+ 'obj_detection',
118
+ 'ocr_detection',
119
+ ],
120
+ 'image_resolution':
121
+ 480,
122
+ 'prismer_model':
123
+ 'prismer_base',
124
+ 'freeze':
125
+ 'freeze_vision',
126
+ 'prefix':
127
+ 'A picture of',
128
  }
129
  model = PrismerCaption(config)
130
  state_dict = torch.load(
 
142
  @torch.inference_mode()
143
  def run_caption_model(self, exp_name: str) -> str:
144
  self.set_model(exp_name)
145
+
146
  _, test_dataset = create_dataset('caption', self.config)
147
+ test_loader = create_loader(test_dataset,
148
+ batch_size=1,
149
+ num_workers=4,
150
+ train=False)
151
  experts, _ = next(iter(test_loader))
152
+ captions = self.model(experts,
153
+ train=False,
154
+ prefix=self.config['prefix'])
155
+ captions = self.tokenizer(captions,
156
+ max_length=30,
157
+ padding='max_length',
158
+ return_tensors='pt').input_ids
159
  caption = captions.to(experts['rgb'].device)[0]
160
  caption = self.tokenizer.decode(caption, skip_special_tokens=True)
161
  caption = caption.capitalize() + '.'
162
  return caption
163
 
164
def run_caption(self, image_path: str,
                model_name: str) -> tuple[str | None, ...]:
    """Run the experts on an image, then caption it.

    Args:
        image_path: Path of the input image, forwarded to ``run_experts``.
        model_name: Name forwarded to ``run_caption_model``.

    Returns:
        A tuple whose first item is the caption, followed by every value
        returned by ``run_experts``.
    """
    # The experts run first, before the caption model is invoked.
    expert_outputs = run_experts(image_path)
    generated_caption = self.run_caption_model(model_name)
    return (generated_caption, *expert_outputs)