Simonlob commited on
Commit
94e78da
1 Parent(s): e53bad3

Rename matcha/utils/monotonic_align/core.pyx to matcha/utils/monotonic_align/core.py

Browse files
matcha/utils/monotonic_align/core.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ def maximum_path_each(path, value, t_x, t_y, max_neg_val):
4
+ for y in range(t_y):
5
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
6
+ if x == y:
7
+ v_cur = max_neg_val
8
+ else:
9
+ v_cur = value[x, y - 1]
10
+
11
+ if x == 0:
12
+ if y == 0:
13
+ v_prev = 0.
14
+ else:
15
+ v_prev = max_neg_val
16
+ else:
17
+ v_prev = value[x - 1, y - 1]
18
+
19
+ value[x, y] = max(v_cur, v_prev) + value[x, y]
20
+
21
+ index = t_x - 1
22
+ for y in range(t_y - 1, -1, -1):
23
+ path[index, y] = 1
24
+ if index != 0 and (index == y or value[index, y - 1] < value[index - 1, y - 1]):
25
+ index -= 1
26
+
27
+ def maximum_path_c(paths, values, t_xs, t_ys, max_neg_val=-1e9):
28
+ b = values.shape[0]
29
+ for i in range(b):
30
+ maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
matcha/utils/monotonic_align/core.pyx DELETED
@@ -1,47 +0,0 @@
1
- import numpy as np
2
-
3
- cimport cython
4
- cimport numpy as np
5
-
6
- from cython.parallel import prange
7
-
8
-
9
- @cython.boundscheck(False)
10
- @cython.wraparound(False)
11
- cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
12
- cdef int x
13
- cdef int y
14
- cdef float v_prev
15
- cdef float v_cur
16
- cdef float tmp
17
- cdef int index = t_x - 1
18
-
19
- for y in range(t_y):
20
- for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
21
- if x == y:
22
- v_cur = max_neg_val
23
- else:
24
- v_cur = value[x, y-1]
25
- if x == 0:
26
- if y == 0:
27
- v_prev = 0.
28
- else:
29
- v_prev = max_neg_val
30
- else:
31
- v_prev = value[x-1, y-1]
32
- value[x, y] = max(v_cur, v_prev) + value[x, y]
33
-
34
- for y in range(t_y - 1, -1, -1):
35
- path[index, y] = 1
36
- if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
37
- index = index - 1
38
-
39
-
40
- @cython.boundscheck(False)
41
- @cython.wraparound(False)
42
- cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
43
- cdef int b = values.shape[0]
44
-
45
- cdef int i
46
- for i in prange(b, nogil=True):
47
- maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)