1+ import numpy as np
2+ import matplotlib .pyplot as plt
3+ from tensorflow import keras
4+ from utils .config import Config
5+
6+ config = Config ()
7+
8+ def generate_anchors (sizes = None , ratios = None ):
9+ if sizes is None :
10+ sizes = config .anchor_box_scales
11+
12+ if ratios is None :
13+ ratios = config .anchor_box_ratios
14+
15+ num_anchors = len (sizes ) * len (ratios )
16+
17+ anchors = np .zeros ((num_anchors , 4 ))
18+ # print(anchors)
19+ anchors [:, 2 :] = np .tile (sizes , (2 , len (ratios ))).T
20+
21+ for i in range (len (ratios )):
22+ anchors [3 * i :3 * i + 3 , 2 ] = anchors [3 * i :3 * i + 3 , 2 ]* ratios [i ][0 ]
23+ anchors [3 * i :3 * i + 3 , 3 ] = anchors [3 * i :3 * i + 3 , 3 ]* ratios [i ][1 ]
24+
25+
26+ anchors [:, 0 ::2 ] -= np .tile (anchors [:, 2 ] * 0.5 , (2 , 1 )).T
27+ anchors [:, 1 ::2 ] -= np .tile (anchors [:, 3 ] * 0.5 , (2 , 1 )).T
28+ # print(anchors)
29+ return anchors
30+
31+ def shift (shape , anchors , stride = config .rpn_stride ):
32+ # [0,1,2,3,4,5……37]
33+ # [0.5,1.5,2.5……37.5]
34+ # [8,24,……]
35+ shift_x = (np .arange (0 , shape [0 ], dtype = keras .backend .floatx ()) + 0.5 ) * stride
36+ shift_y = (np .arange (0 , shape [1 ], dtype = keras .backend .floatx ()) + 0.5 ) * stride
37+
38+ shift_x , shift_y = np .meshgrid (shift_x , shift_y )
39+
40+ shift_x = np .reshape (shift_x , [- 1 ])
41+ shift_y = np .reshape (shift_y , [- 1 ])
42+ # print(shift_x,shift_y)
43+ shifts = np .stack ([
44+ shift_x ,
45+ shift_y ,
46+ shift_x ,
47+ shift_y
48+ ], axis = 0 )
49+
50+ shifts = np .transpose (shifts )
51+ number_of_anchors = np .shape (anchors )[0 ]
52+
53+ k = np .shape (shifts )[0 ]
54+
55+ shifted_anchors = np .reshape (anchors , [1 , number_of_anchors , 4 ]) + np .array (np .reshape (shifts , [k , 1 , 4 ]), keras .backend .floatx ())
56+ shifted_anchors = np .reshape (shifted_anchors , [k * number_of_anchors , 4 ])
57+
58+
59+ fig = plt .figure ()
60+ ax = fig .add_subplot (111 )
61+ plt .ylim (- 300 ,900 )
62+ plt .xlim (- 300 ,900 )
63+ # plt.ylim(0,600)
64+ # plt.xlim(0,600)
65+ plt .scatter (shift_x ,shift_y )
66+ box_widths = shifted_anchors [:,2 ]- shifted_anchors [:,0 ]
67+ box_heights = shifted_anchors [:,3 ]- shifted_anchors [:,1 ]
68+
69+ initial = 0
70+ for i in [initial + 0 ,initial + 1 ,initial + 2 ,initial + 3 ,initial + 4 ,initial + 5 ,initial + 6 ,initial + 7 ,initial + 8 ]:
71+ rect = plt .Rectangle ([shifted_anchors [i , 0 ],shifted_anchors [i , 1 ]],box_widths [i ],box_heights [i ],color = "r" ,fill = False )
72+ ax .add_patch (rect )
73+ plt .show ()
74+
75+ return shifted_anchors
76+
77+ def get_anchors (shape ,width ,height ):
78+ anchors = generate_anchors ()
79+ network_anchors = shift (shape ,anchors )
80+ network_anchors [:,0 ] = network_anchors [:,0 ]/ width
81+ network_anchors [:,1 ] = network_anchors [:,1 ]/ height
82+ network_anchors [:,2 ] = network_anchors [:,2 ]/ width
83+ network_anchors [:,3 ] = network_anchors [:,3 ]/ height
84+ network_anchors = np .clip (network_anchors ,0 ,1 )
85+ print (network_anchors )
86+ return network_anchors
87+
88+ if __name__ == "__main__" :
89+ get_anchors ([38 ,38 ],600 ,600 )
0 commit comments