zhang-ziang commited on
Commit
ec4886b
·
1 Parent(s): 43a369c

requirements

Browse files
Files changed (2) hide show
  1. demo.py +0 -61
  2. requirements.txt +8 -0
demo.py DELETED
@@ -1,61 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import argparse
5
- import os
6
- import glob
7
-
8
- from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
9
- import pandas as pd
10
- from paths import *
11
- import numpy as np
12
- from vision_tower import DINOv2_MLP
13
- from PIL import Image
14
- save_path = './'
15
- device = 'cpu'
16
- dino = DINOv2_MLP(
17
- dino_mode = 'large',
18
- in_dim = 1024,
19
- out_dim = 360+180+60+2,
20
- evaluate = True,
21
- mask_dino = False,
22
- frozen_back = False
23
- ).to(device)
24
-
25
- dino.eval()
26
-
27
- dino.load_state_dict(torch.load(os.path.join(save_path, 'dino_weight.pt'), map_location='cpu'))
28
- val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
29
-
30
-
31
- def get_3angle(image_path):
32
-
33
- image = Image.open(image_path).convert('RGB')
34
- image_inputs = val_preprocess(images = image)
35
- image_inputs['pixel_values'] = torch.from_numpy(np.array(image_inputs['pixel_values'])).to(device)
36
- with torch.no_grad():
37
- dino_pred = dino(image_inputs)
38
-
39
- gaus_ax_pred = torch.argmax(dino_pred[:, 0:360], dim=-1)
40
- gaus_pl_pred = torch.argmax(dino_pred[:, 360:360+180], dim=-1)
41
- gaus_ro_pred = torch.argmax(dino_pred[:, 360+180:360+180+60], dim=-1)
42
- angles = torch.zeros(3)
43
- angles[0] = gaus_ax_pred
44
- angles[1] = gaus_pl_pred - 90
45
- angles[2] = gaus_ro_pred - 30
46
-
47
- return angles
48
-
49
- with torch.no_grad():
50
- obj_angles = []
51
- img_paths = glob.glob(os.path.join('/home/aiops/wangzh/wangjialei/data_preprocess/meta/sa_10099.jpg'))
52
- img_paths.sort()
53
- for image_path in img_paths:
54
- # image_path = f'/home/aiops/wangzh/zza/Objaverse_render_extract/coco/demo_image/3D/{i}.png'
55
- image_name = image_path.split('/')[-1]
56
- print(image_name)
57
- angles = get_3angle(image_path)
58
- obj_angles.append(angles)
59
- # print(f'cat/{i}.png', angles)
60
- obj_angles = torch.stack(obj_angles, dim=0)
61
- print('wild', obj_angles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch==2.2.1
2
+ transformer==4.38.2
3
+ matplotlib
4
+ pillow==10.2.0
5
+ huggingface-hub==0.26.5
6
+ gradio==5.9.0
7
+ numpy==1.26.4
8
+