@@ -6,86 +6,86 @@ import (
66 "github.com/aws/aws-sdk-go-v2/aws"
77 "github.com/aws/aws-sdk-go-v2/config"
88 "github.com/aws/aws-sdk-go-v2/service/ecs"
9+ "github.com/aws/aws-sdk-go-v2/service/servicediscovery"
910 core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
1011 endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
1112 discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
1213 endpointservice "github.com/envoyproxy/go-control-plane/envoy/service/endpoint/v3"
1314 "github.com/envoyproxy/go-control-plane/pkg/cache/types"
1415 "github.com/envoyproxy/go-control-plane/pkg/cache/v3"
15- "github.com/envoyproxy/go-control-plane/pkg/resource/v3"
1616 gocache "github.com/patrickmn/go-cache"
1717 log "github.com/sirupsen/logrus"
1818 "google.golang.org/grpc"
1919 "google.golang.org/grpc/reflection"
2020 "net"
2121 "os"
22+ "os/signal"
2223 "strconv"
24+ "strings"
25+ "syscall"
2326 "time"
2427)
2528
29+ var srv * server
30+
2631type server struct {
27- ecs * ecs.Client
28- cache * gocache.Cache
32+ ecs * ecs.Client
33+ servicediscovery * servicediscovery.Client
34+ cache * gocache.Cache
2935}
3036
3137func init () {
32-
33- // Log as JSON instead of the default ASCII formatter.
38+ cfg , _ := config . LoadDefaultConfig ( context . Background ())
39+ srv = & server { ecs : ecs . NewFromConfig ( cfg ), servicediscovery : servicediscovery . NewFromConfig ( cfg ), cache : gocache . New ( time . Second * 30 , time . Second * 30 )}
3440 log .SetFormatter (& log.TextFormatter {})
35-
36- // Output to stdout instead of the default stderr
37- // Can be any io.Writer, see below for File example
3841 log .SetOutput (os .Stdout )
39-
40- // Only log the warning severity or above.
4142 log .SetLevel (log .InfoLevel )
4243}
4344
4445func (* server ) receive (stream endpointservice.EndpointDiscoveryService_StreamEndpointsServer , reqChannel chan * discovery.DiscoveryRequest ) {
4546 for {
4647 req , err := stream .Recv ()
4748 if err != nil {
48- log .Error ( "Error while receiving message from stream" , err )
49+ log .Debug ( "error while receiving message from stream: " , err )
4950 return
5051 }
5152
5253 select {
5354 case reqChannel <- req :
5455 case <- stream .Context ().Done ():
55- log .Error ("Stream closed" )
56+ log .Debug ("Stream closed" )
5657 return
5758 }
5859 }
5960}
6061
6162func (s * server ) StreamEndpoints (stream endpointservice.EndpointDiscoveryService_StreamEndpointsServer ) error {
62- stop := make ( chan struct {})
63+
6364 reqChannel := make (chan * discovery.DiscoveryRequest , 1 )
6465 go s .receive (stream , reqChannel )
6566
6667 for {
6768 select {
6869 case req , ok := <- reqChannel :
6970 if ! ok {
70- log .Error ("Error receiving request" )
71- return errors .New ("Error receiving request" )
71+ log .Error ("error receiving request" )
72+ return errors .New ("error receiving request" )
7273 }
73- eds , cacheOk := s .cache .Get (req .ResourceNames [0 ])
74+ cacheResp , cacheOk := s .cache .Get (req .ResourceNames [0 ])
7475 if ! cacheOk {
75- eds = s .generateEDS (req .ResourceNames [0 ])
76- s .cache .Set (req .ResourceNames [0 ], eds , time .Minute * 1 )
76+ eds := s .generateEDS (req .ResourceNames [0 ])
77+ response := cache.RawResponse {Version : strconv .FormatInt (time .Now ().Unix (), 10 ),
78+ Resources : []types.ResourceWithTtl {{Resource : eds }},
79+ Request : req }
80+ cacheResp , _ = response .GetDiscoveryResponse ()
81+
82+ s .cache .Set (req .ResourceNames [0 ], cacheResp , time .Second * 30 )
7783 }
78- response := cache.RawResponse {Version : req .VersionInfo ,
79- Resources : []types.ResourceWithTtl {{Resource : eds .(* endpoint.ClusterLoadAssignment )}},
80- Request : & discovery.DiscoveryRequest {TypeUrl : resource .EndpointType }}
81- cacheResp , err := response .GetDiscoveryResponse ()
82- err = stream .Send (cacheResp )
84+ err := stream .Send (cacheResp .(* discovery.DiscoveryResponse ))
8385 if err != nil {
84- log .Error ("Error StreamingEndpoint " , err )
86+ log .Error ("StreamingEndpoint-Send " , err )
8587 return err
8688 }
87- case <- stop :
88- return nil
8989 }
9090 }
9191}
@@ -95,15 +95,37 @@ func (s *server) DeltaEndpoints(stream endpointservice.EndpointDiscoveryService_
9595 return nil
9696}
9797
98- func (* server ) FetchEndpoints (ctx context.Context , req * discovery.DiscoveryRequest ) (* discovery.DiscoveryResponse , error ) {
99- log .Info ("FetchEndpoints service not implemented" )
100- return nil , nil
98+ func (s * server ) FetchEndpoints (ctx context.Context , req * discovery.DiscoveryRequest ) (* discovery.DiscoveryResponse , error ) {
99+ var err error
100+ cacheResp , cacheOk := s .cache .Get (req .ResourceNames [0 ])
101+ if ! cacheOk {
102+ eds := s .generateEDS (req .ResourceNames [0 ])
103+ s .cache .Set (req .ResourceNames [0 ], eds , time .Second * 30 )
104+ response := cache.RawResponse {Version : strconv .FormatInt (time .Now ().Unix (), 10 ),
105+ Resources : []types.ResourceWithTtl {{Resource : eds }},
106+ Request : req }
107+ cacheResp , err = response .GetDiscoveryResponse ()
108+ s .cache .Set (req .ResourceNames [0 ], cacheResp , time .Minute * 1 )
109+ }
110+ return cacheResp .(* discovery.DiscoveryResponse ), err
101111}
102112
103113func (s * server ) generateEDS (cluster string ) * endpoint.ClusterLoadAssignment {
114+
104115 var lbEndpoints = make ([]* endpoint.LbEndpoint , 0 )
116+ var endpointsChan = make (chan * endpoint.LbEndpoint , 1 )
117+
118+ if strings .Contains (cluster , "srv-" ) {
119+ log .Info ("Generating new EDS values - Cloudmap" )
120+ go s .getServiceDiscoveryIps (endpointsChan , cluster )
121+ } else {
122+ log .Info ("Generating new EDS values - ECS" )
123+ go s .getTaskIps (endpointsChan , cluster )
124+ }
105125
106- s .getTaskIps (& lbEndpoints , cluster , nil )
126+ for i := range endpointsChan {
127+ lbEndpoints = append (lbEndpoints , i )
128+ }
107129
108130 ret := & endpoint.ClusterLoadAssignment {
109131 ClusterName : cluster ,
@@ -117,68 +139,123 @@ func (s *server) generateEDS(cluster string) *endpoint.ClusterLoadAssignment {
117139 return ret
118140}
119141
120- func (s * server ) getTaskIps (lbEndpoints * []* endpoint.LbEndpoint , cluster string , nextToken * string ) {
121- taskArns , err := s .ecs .ListTasks (context .Background (), & ecs.ListTasksInput {Cluster : aws .String (cluster ), NextToken : nextToken })
122- if err != nil {
123- log .Error ("Error listing AWS tasks " , err )
124- return
125- }
126- tasks , err := s .ecs .DescribeTasks (context .Background (), & ecs.DescribeTasksInput {
127- Tasks : taskArns .TaskArns , Cluster : aws .String (cluster ),
128- })
129- if err != nil {
130- log .Error ("Error Describing AWS tasks " , err )
131- return
132- }
133- port , err := strconv .Atoi (os .Getenv (cluster + "_port" ))
134- if err != nil {
135- port = 80
136- }
137- for _ , task := range tasks .Tasks {
138- for _ , attachment := range task .Attachments {
139- for _ , details := range attachment .Details {
140- if aws .ToString (details .Name ) == "privateIPv4Address" {
141- * lbEndpoints = append (* lbEndpoints , & endpoint.LbEndpoint {HostIdentifier : & endpoint.LbEndpoint_Endpoint {
142- Endpoint : & endpoint.Endpoint {
143- Address : & core.Address {
144- Address : & core.Address_SocketAddress {
145- SocketAddress : & core.SocketAddress {
146- Address : aws .ToString (details .Value ),
147- PortSpecifier : & core.SocketAddress_PortValue {
148- PortValue : uint32 (port ),
142+ func (s * server ) getTaskIps (lbEndpoints chan * endpoint.LbEndpoint , cluster string ) {
143+ listTasks := ecs .NewListTasksPaginator (s .ecs , & ecs.ListTasksInput {Cluster : aws .String (cluster )})
144+ for listTasks .HasMorePages () {
145+ taskArns , err := listTasks .NextPage (context .TODO ())
146+ if err != nil {
147+ log .Error ("Error listing AWS tasks " , err )
148+ return
149+ }
150+ tasks , err := s .ecs .DescribeTasks (context .Background (), & ecs.DescribeTasksInput {
151+ Tasks : taskArns .TaskArns , Cluster : aws .String (cluster ),
152+ })
153+ if err != nil {
154+ log .Error ("Error Describing AWS tasks " , err )
155+ return
156+ }
157+ port , err := strconv .Atoi (os .Getenv (cluster + "_port" ))
158+ if err != nil {
159+ port = 80
160+ }
161+ for _ , task := range tasks .Tasks {
162+ for _ , attachment := range task .Attachments {
163+ for _ , details := range attachment .Details {
164+ if aws .ToString (details .Name ) == "privateIPv4Address" {
165+ lbEndpoints <- & endpoint.LbEndpoint {HostIdentifier : & endpoint.LbEndpoint_Endpoint {
166+ Endpoint : & endpoint.Endpoint {
167+ Address : & core.Address {
168+ Address : & core.Address_SocketAddress {
169+ SocketAddress : & core.SocketAddress {
170+ Protocol : core .SocketAddress_TCP ,
171+ Address : aws .ToString (details .Value ),
172+ PortSpecifier : & core.SocketAddress_PortValue {
173+ PortValue : uint32 (port ),
174+ },
149175 },
150176 },
151177 },
152178 },
153179 },
154- },
155- })
180+ }
181+ }
156182 }
157183 }
158184 }
159185 }
160- if taskArns .NextToken != nil {
161- s .getTaskIps (lbEndpoints , cluster , taskArns .NextToken )
186+ close (lbEndpoints )
187+ }
188+
189+ func (s * server ) getServiceDiscoveryIps (lbEndpoints chan * endpoint.LbEndpoint , serviceId string ) {
190+ listInstances := servicediscovery .NewListInstancesPaginator (s .servicediscovery , & servicediscovery.ListInstancesInput {ServiceId : aws .String (serviceId )})
191+ for listInstances .HasMorePages () {
192+ instances , err := listInstances .NextPage (context .TODO ())
193+ if err != nil {
194+ log .Error (err )
195+ }
196+ for _ , instance := range instances .Instances {
197+ port , err2 := strconv .Atoi (os .Getenv (serviceId + "_port" ))
198+ if err2 != nil {
199+ port , err2 = strconv .Atoi (instance .Attributes ["AWS_INSTANCE_PORT" ])
200+ if err2 != nil {
201+ port = 80
202+ }
203+ }
204+ lbEndpoints <- & endpoint.LbEndpoint {HostIdentifier : & endpoint.LbEndpoint_Endpoint {
205+ Endpoint : & endpoint.Endpoint {
206+ Address : & core.Address {
207+ Address : & core.Address_SocketAddress {
208+ SocketAddress : & core.SocketAddress {
209+ Protocol : core .SocketAddress_TCP ,
210+ Address : instance .Attributes ["AWS_INSTANCE_IPV4" ],
211+ PortSpecifier : & core.SocketAddress_PortValue {
212+ PortValue : uint32 (port ),
213+ },
214+ },
215+ },
216+ },
217+ },
218+ },
219+ }
220+
221+ }
162222 }
223+ close (lbEndpoints )
163224}
225+
164226func main () {
227+ sigs := make (chan os.Signal , 1 )
228+ signal .Notify (sigs , syscall .SIGKILL , syscall .SIGINT , syscall .SIGTERM )
229+
165230 grpcServer := grpc .NewServer ()
231+
166232 edsListen := os .Getenv ("EDS_LISTEN" )
167233 if edsListen == "" {
168234 edsListen = "0.0.0.0:5678"
169235 }
236+
170237 lis , err := net .Listen ("tcp" , edsListen )
171238 if err != nil {
172239 log .Error (err )
240+ os .Exit (- 2 )
173241 }
174242
175- cfg , _ := config . LoadDefaultConfig ( context . Background ())
176- endpointservice .RegisterEndpointDiscoveryServiceServer (grpcServer , & server { ecs : ecs . NewFromConfig ( cfg ), cache : gocache . New ( time . Minute * 1 , time . Minute * 1 )} )
243+ go func () {
244+ endpointservice .RegisterEndpointDiscoveryServiceServer (grpcServer , srv )
177245
178- reflection .Register (grpcServer )
246+ reflection .Register (grpcServer )
179247
180- log .Infof ("management server listening on %d" , 5678 )
181- if err = grpcServer .Serve (lis ); err != nil {
182- log .Error (err )
183- }
248+ log .Infof ("management server listening on %s" , edsListen )
249+ if err = grpcServer .Serve (lis ); err != nil {
250+ log .Error (err )
251+ os .Exit (- 1 )
252+ }
253+ }()
254+
255+ sig := <- sigs
256+ log .Printf ("Caught Signal %v" , sig )
257+ go grpcServer .GracefulStop ()
258+ time .Sleep (time .Second * 5 )
259+ grpcServer .Stop ()
260+ os .Exit (0 )
184261}
0 commit comments