@@ -3,7 +3,6 @@ package registration
33import (
44 "bytes"
55 "encoding/json"
6- "errors"
76 "fmt"
87 "io/ioutil"
98 "net/http"
@@ -21,8 +20,9 @@ const (
2120 BZeroConfigStorage = "BZeroConfig"
2221 BZeroRegErrorExitCode = 234
2322
24- registrationEndpoint = "api/v2/targets/ssm/register"
25- prodServiceUrl = "https://cloud.bastionzero.com/" // default
23+ registrationEndpoint = "targets/ssm/register"
24+ prodServiceUrl = "https://cloud.bastionzero.com/" // default
25+ getConnectionServiceEndpoint = "api/v2/connection-service/url"
2626)
2727
2828// This is the data sent to the Reg API
@@ -43,6 +43,10 @@ type BZeroRegResponse struct {
4343 OrgProvider string `json:"externalOrganizationProvider"`
4444}
4545
46+ type GetConnectionServiceResponse struct {
47+ ConnectionServiceUrl string `json:"connectionServiceUrl"`
48+ }
49+
4650// Attempts to register as many times as is acceptable
4751func Register (log logger.T , apiKey string , envName string , envId string , targetName string , serviceUrl string ) (BZeroRegResponse , error ) {
4852 var response BZeroRegResponse
@@ -60,19 +64,8 @@ func Register(log logger.T, apiKey string, envName string, envId string, targetN
6064 EnvName : envName ,
6165 }
6266
63- // Build Registration Endpoint
64-
65- if serviceUrl == "" {
66- serviceUrl = prodServiceUrl
67- }
68- u , err := url .Parse (serviceUrl )
69- if err != nil {
70- return response , fmt .Errorf ("could not parse service url: %s error: %s" , serviceUrl , err )
71- }
72- u .Path = path .Join (u .Path , registrationEndpoint )
73-
7467 // Register with BastionZero
75- resp , err := post (log , regInfo , u . String () )
68+ resp , err := sendRegisterRequest (log , regInfo , serviceUrl )
7669 if err != nil {
7770 return response , err
7871 }
@@ -93,7 +86,75 @@ func Register(log logger.T, apiKey string, envName string, envId string, targetN
9386 return response , nil
9487}
9588
96- func post (log logger.T , regInfo BZeroRegRequest , regUrl string ) (* http.Response , error ) {
89+ func sendRegisterRequest (log logger.T , regInfo BZeroRegRequest , serviceUrl string ) (* http.Response , error ) {
90+ // Declare our variables
91+ var response * http.Response
92+
93+ // Marshal the regInfo data so we don't do it every time
94+ regInfoBytes , err := json .Marshal (regInfo )
95+ if err != nil {
96+ return response , fmt .Errorf ("could not marshal registration request" )
97+ }
98+
99+ // Build Registration Endpoint
100+ if serviceUrl == "" {
101+ serviceUrl = prodServiceUrl
102+ }
103+
104+ log .Infof ("Using service url %s" , serviceUrl )
105+
106+ // Get connection service url from bastion
107+ connectionServiceUrl , connectionServiceUrlErr := getConnectionServiceUrlFromServiceUrl (log , serviceUrl )
108+ if connectionServiceUrlErr != nil {
109+ return & http.Response {}, connectionServiceUrlErr
110+ }
111+
112+ u , err := url .Parse (connectionServiceUrl )
113+ if err != nil {
114+ return response , fmt .Errorf ("could not parse connection service url: %s error: %s" , connectionServiceUrl , err )
115+ }
116+ u .Path = path .Join (u .Path , registrationEndpoint )
117+
118+ log .Infof ("Registration Request Body: %s" , string (regInfoBytes ))
119+ req , err := http .NewRequest ("POST" , u .String (), bytes .NewBuffer (regInfoBytes ))
120+ if err != nil {
121+ return response , err
122+ }
123+
124+ resp , err := sendRequestWithRetry (log , req )
125+ if err != nil {
126+ return resp , err
127+ }
128+
129+ return resp , nil
130+ }
131+
132+ func missingResponseFields (resp BZeroRegResponse ) ([]string , bool ) {
133+ // Print out a specific message if missing registration data
134+ missing := []string {}
135+ if resp .ActivationId == "" {
136+ missing = append (missing , "Activation ID" )
137+ }
138+ if resp .ActivationCode == "" {
139+ missing = append (missing , "Activation Code" )
140+ }
141+ if resp .ActivationRegion == "" {
142+ missing = append (missing , "Activation Region" )
143+ }
144+ if resp .SSMTargetId == "" {
145+ missing = append (missing , "SSM Target ID" )
146+ }
147+ if resp .OrgID == "" {
148+ missing = append (missing , "Organization ID" )
149+ }
150+ if resp .OrgProvider == "" {
151+ missing = append (missing , "Organization Provider" )
152+ }
153+
154+ return missing , len (missing ) == 0
155+ }
156+
157+ func sendRequestWithRetry (log logger.T , req * http.Request ) (* http.Response , error ) {
97158 // Default params
98159 // Ref: https://github.com/cenkalti/backoff/blob/a78d3804c2c84f0a3178648138442c9b07665bda/exponential.go#L76
99160 // DefaultInitialInterval = 500 * time.Millisecond
@@ -110,33 +171,19 @@ func post(log logger.T, regInfo BZeroRegRequest, regUrl string) (*http.Response,
110171 // Make our ticker
111172 ticker := backoff .NewTicker (backoffParams )
112173
113- // Declare our variables
114- var response * http.Response
115-
116- // Marshal the regInfo data so we don't do it every time
117- regInfoBytes , err := json .Marshal (regInfo )
118- if err != nil {
119- return response , fmt .Errorf ("could not marshal registration request" )
120- }
121-
122- // Keep looping through our ticker, waiting for it to tell us when to retry
123174 for range ticker .C {
124- // Make our Client
175+
176+ // Make our http Client
125177 var httpClient = & http.Client {
126178 Timeout : time .Second * 10 ,
127179 }
128180
129- // Build request
130- req , err := http .NewRequest ("POST" , regUrl , bytes .NewBuffer (regInfoBytes ))
131- if err != nil {
132- return response , fmt .Errorf ("Error creating new http request: %v" , err )
133- }
134-
135181 // Headers
136182 req .Header .Add ("Accept" , "application/json" )
137183 req .Header .Add ("Content-Type" , "application/json" )
138184
139- response , err = httpClient .Do (req )
185+ log .Infof ("Sending request to: %s" , req .URL )
186+ response , err := httpClient .Do (req )
140187
141188 // If the status code is unauthorized, do not attempt to retry
142189 if response .StatusCode == http .StatusInternalServerError ||
@@ -146,8 +193,6 @@ func post(log logger.T, regInfo BZeroRegRequest, regUrl string) (*http.Response,
146193 response .StatusCode == http .StatusUnsupportedMediaType {
147194
148195 ticker .Stop ()
149- log .Infof ("Registration Endpoint: %s" , regUrl )
150- log .Infof ("Registration Request Body: %s" , string (regInfoBytes ))
151196 return response , fmt .Errorf ("received response code: %d, not retrying" , response .StatusCode )
152197 }
153198
@@ -159,30 +204,37 @@ func post(log logger.T, regInfo BZeroRegRequest, regUrl string) (*http.Response,
159204 return response , err
160205 }
161206
162- return nil , errors . New ( "unable to make post request" )
207+ return nil , fmt . Errorf ( "Failed to successfully make request to: %s" , req . URL )
163208}
164209
165- func missingResponseFields (resp BZeroRegResponse ) ([]string , bool ) {
166- // Print out a specific message if missing registration data
167- missing := []string {}
168- if resp .ActivationId == "" {
169- missing = append (missing , "Activation ID" )
170- }
171- if resp .ActivationCode == "" {
172- missing = append (missing , "Activation Code" )
210+ func getConnectionServiceUrlFromServiceUrl (log logger.T , serviceUrl string ) (string , error ) {
211+ // Make request to bastion to get connection service url
212+ u , err := url .Parse (serviceUrl )
213+ if err != nil {
214+ return "" , fmt .Errorf ("could not parse service url: %s error: %s" , serviceUrl , err )
173215 }
174- if resp .ActivationRegion == "" {
175- missing = append (missing , "Activation Region" )
216+ u .Path = path .Join (u .Path , getConnectionServiceEndpoint )
217+
218+ req , err := http .NewRequest ("GET" , u .String (), nil )
219+ if err != nil {
220+ return "" , err
176221 }
177- if resp .SSMTargetId == "" {
178- missing = append (missing , "SSM Target ID" )
222+
223+ resp , err := sendRequestWithRetry (log , req )
224+ if err != nil {
225+ return "" , err
179226 }
180- if resp .OrgID == "" {
181- missing = append (missing , "Organization ID" )
227+
228+ // Unmarshal the response
229+ respBytes , readAllErr := ioutil .ReadAll (resp .Body )
230+ if readAllErr != nil {
231+ return "" , fmt .Errorf ("error reading body on get connection service url request: %v" , readAllErr )
182232 }
183- if resp .OrgProvider == "" {
184- missing = append (missing , "Organization Provider" )
233+
234+ var getConnectionServiceResponse GetConnectionServiceResponse
235+ if err := json .Unmarshal (respBytes , & getConnectionServiceResponse ); err != nil {
236+ return "" , fmt .Errorf ("malformed getConnectionService response: %s" , err )
185237 }
186238
187- return missing , len ( missing ) == 0
239+ return getConnectionServiceResponse . ConnectionServiceUrl , nil
188240}
0 commit comments