@@ -35,6 +35,14 @@ type Resolver struct {
3535 IdentityGetter func () (view.Identity , []byte , error )
3636}
3737
38+ func (r * Resolver ) GetName () string { return r .Name }
39+
40+ func (r * Resolver ) GetId () view.Identity { return r .Id }
41+
42+ func (r * Resolver ) GetAddress (port driver.PortName ) string { return r .Addresses [port ] }
43+
44+ func (r * Resolver ) GetAddresses () map [driver.PortName ]string { return r .Addresses }
45+
3846func (r * Resolver ) GetIdentity () (view.Identity , error ) {
3947 if r .IdentityGetter != nil {
4048 id , _ , err := r .IdentityGetter ()
@@ -75,38 +83,36 @@ func NewService(binder Binder) (*Service, error) {
7583 return er , nil
7684}
7785
78- func (r * Service ) Endpoint (party view.Identity ) (map [driver.PortName ]string , error ) {
79- _ , e , _ , err := r .resolve (party )
80- return e , err
81- }
82-
83- func (r * Service ) Resolve (party view.Identity ) (string , view.Identity , map [driver.PortName ]string , []byte , error ) {
84- cursor , e , resolver , err := r .resolve (party )
86+ func (r * Service ) Resolve (party view.Identity ) (driver.Resolver , []byte , error ) {
87+ resolver , err := r .resolver (party )
8588 if err != nil {
86- return "" , nil , nil , nil , err
89+ return nil , nil , err
8790 }
88- return resolver . Name , cursor , e , r .pkiResolve (resolver ), nil
91+ return resolver , r .pkiResolve (resolver ), nil
8992}
9093
91- func (r * Service ) resolve (party view.Identity ) (view.Identity , map [driver.PortName ]string , * Resolver , error ) {
94+ func (r * Service ) GetResolver (party view.Identity ) (driver.Resolver , error ) {
95+ return r .resolver (party )
96+ }
9297
98+ func (r * Service ) resolver (party view.Identity ) (* Resolver , error ) {
9399 // We can skip this check, but in case the long term was passed directly, this is going to spare us a DB lookup
94- resolver , e , err := r .rootEndpoint (party )
100+ resolver , err := r .rootEndpoint (party )
95101 if err == nil {
96- return party , e , resolver , nil
102+ return resolver , nil
97103 }
98104 logger .Debugf ("resolving via binding for %s" , party )
99105 party , err = r .binder .GetLongTerm (party )
100106 if err != nil {
101- return nil , nil , nil , err
107+ return nil , err
102108 }
103109 logger .Debugf ("continue to [%s]" , party )
104- resolver , e , err = r .rootEndpoint (party )
110+ resolver , err = r .rootEndpoint (party )
105111 if err != nil {
106- return nil , nil , nil , err
112+ return nil , err
107113 }
108114
109- return party , e , resolver , nil
115+ return resolver , nil
110116}
111117
112118func (r * Service ) Bind (longTerm view.Identity , ephemeral view.Identity ) error {
@@ -138,14 +144,7 @@ func (r *Service) GetIdentity(endpoint string, pkID []byte) (view.Identity, erro
138144
139145 // search in the resolver list
140146 for _ , resolver := range r .resolvers {
141- resolverPKID := r .pkiResolve (resolver )
142-
143- if endpoint == resolver .Name ||
144- endpoint == resolver .Name + "." + resolver .Domain ||
145- collections .ContainsValue (resolver .Addresses , endpoint ) ||
146- slices .Contains (resolver .Aliases , endpoint ) ||
147- bytes .Equal (pkID , resolver .Id ) ||
148- bytes .Equal (pkID , resolverPKID ) {
147+ if r .matchesResolver (endpoint , pkID , resolver ) {
149148
150149 id , err := resolver .GetIdentity ()
151150 if err != nil {
@@ -160,6 +159,18 @@ func (r *Service) GetIdentity(endpoint string, pkID []byte) (view.Identity, erro
160159 return nil , errors .Errorf ("identity not found at [%s,%s]" , endpoint , view .Identity (pkID ))
161160}
162161
162+ func (r * Service ) matchesResolver (endpoint string , pkID []byte , resolver * Resolver ) bool {
163+ if len (endpoint ) > 0 && (endpoint == resolver .Name ||
164+ endpoint == resolver .Name + "." + resolver .Domain ||
165+ collections .ContainsValue (resolver .Addresses , endpoint ) ||
166+ slices .Contains (resolver .Aliases , endpoint )) {
167+ return true
168+ }
169+
170+ return len (pkID ) > 0 && (bytes .Equal (pkID , resolver .Id ) ||
171+ bytes .Equal (pkID , r .pkiResolve (resolver )))
172+ }
173+
163174func (r * Service ) AddResolver (name string , domain string , addresses map [string ]string , aliases []string , id []byte ) (view.Identity , error ) {
164175 if logger .IsEnabledFor (zapcore .DebugLevel ) {
165176 logger .Debugf ("adding resolver [%s,%s,%v,%v,%s]" , name , domain , addresses , aliases , view .Identity (id ).String ())
@@ -251,17 +262,17 @@ func (r *Service) ExtractPKI(id []byte) []byte {
251262 return nil
252263}
253264
254- func (r * Service ) rootEndpoint (party view.Identity ) (* Resolver , map [driver. PortName ] string , error ) {
265+ func (r * Service ) rootEndpoint (party view.Identity ) (* Resolver , error ) {
255266 r .resolversMutex .RLock ()
256267 defer r .resolversMutex .RUnlock ()
257268
258269 for _ , resolver := range r .resolvers {
259270 if bytes .Equal (resolver .Id , party ) {
260- return resolver , resolver . Addresses , nil
271+ return resolver , nil
261272 }
262273 }
263274
264- return nil , nil , errors .Errorf ("endpoint not found for identity %s" , party .UniqueID ())
275+ return nil , errors .Errorf ("endpoint not found for identity %s" , party .UniqueID ())
265276}
266277
267278var portNameMap = map [string ]driver.PortName {
0 commit comments