ghostsInTheMachine
commited on
Commit
•
73b0806
1
Parent(s):
d5d8098
Update infer.py
Browse files
infer.py
CHANGED
@@ -47,14 +47,16 @@ def infer_pipe(pipe, images_batch, task_name, seed, device):
|
|
47 |
with torch.no_grad():
|
48 |
with autocast_ctx:
|
49 |
# Convert list of images to tensor
|
50 |
-
images = [np.array(img.convert('RGB')).astype(np.
|
51 |
test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
|
52 |
test_images = test_images / 127.5 - 1.0
|
53 |
-
test_images = test_images.to(device)
|
54 |
|
55 |
-
task_emb
|
|
|
|
|
56 |
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
|
57 |
-
task_emb = task_emb.repeat(
|
58 |
|
59 |
# Run inference
|
60 |
preds = pipe(
|
|
|
47 |
with torch.no_grad():
|
48 |
with autocast_ctx:
|
49 |
# Convert list of images to tensor
|
50 |
+
images = [np.array(img.convert('RGB')).astype(np.float32) for img in images_batch]
|
51 |
test_images = torch.stack([torch.tensor(img).permute(2, 0, 1) for img in images])
|
52 |
test_images = test_images / 127.5 - 1.0
|
53 |
+
test_images = test_images.to(device).type(torch.float16)
|
54 |
|
55 |
+
# Ensure task_emb matches expected dimensions
|
56 |
+
batch_size = test_images.shape[0]
|
57 |
+
task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
|
58 |
task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)
|
59 |
+
task_emb = task_emb.repeat(batch_size, 1)
|
60 |
|
61 |
# Run inference
|
62 |
preds = pipe(
|