99 "github.com/hashicorp/yamux"
1010 "github.com/sirupsen/logrus"
1111 "github.com/skycoin/skywire/pkg/skywire-utilities/pkg/netutil"
12+ "github.com/xtaci/smux"
1213
1314 "github.com/skycoin/dmsg/internal/servermetrics"
1415 "github.com/skycoin/dmsg/pkg/noise"
@@ -44,30 +45,54 @@ func (ss *ServerSession) Close() error {
4445func (ss * ServerSession ) Serve () {
4546 ss .m .RecordSession (servermetrics .DeltaConnect ) // record successful connection
4647 defer ss .m .RecordSession (servermetrics .DeltaDisconnect ) // record disconnection
47-
48- for {
49- yStr , err := ss .ys .AcceptStream ()
50- if err != nil {
51- switch err {
52- case yamux .ErrSessionShutdown , io .EOF :
53- ss .log .WithError (err ).Info ("Stopping session..." )
54- default :
55- ss .log .WithError (err ).Warn ("Failed to accept stream, stopping session..." )
48+ if ss .sm .smux != nil {
49+ for {
50+ sStr , err := ss .sm .smux .AcceptStream ()
51+ if err != nil {
52+ switch err {
53+ case io .EOF :
54+ ss .log .WithError (err ).Info ("Stopping session..." )
55+ default :
56+ ss .log .WithError (err ).Warn ("Failed to accept stream, stopping session..." )
57+ }
58+ return
5659 }
57- return
60+
61+ log := ss .log .WithField ("smux_id" , sStr .ID ())
62+ log .Info ("Initiating stream." )
63+
64+ go func (sStr * smux.Stream ) {
65+ err := ss .serveStream (log , sStr , ss .sm .addr )
66+ log .WithError (err ).Info ("Stopped stream." )
67+ }(sStr )
5868 }
69+ } else {
70+ for {
71+ yStr , err := ss .sm .yamux .AcceptStream ()
72+ if err != nil {
73+ switch err {
74+ case yamux .ErrSessionShutdown , io .EOF :
75+ ss .log .WithError (err ).Info ("Stopping session..." )
76+ default :
77+ ss .log .WithError (err ).Warn ("Failed to accept stream, stopping session..." )
78+ }
79+ return
80+ }
5981
60- log := ss .log .WithField ("yamux_id" , yStr .StreamID ())
61- log .Info ("Initiating stream." )
82+ log := ss .log .WithField ("yamux_id" , yStr .StreamID ())
83+ log .Info ("Initiating stream." )
6284
63- go func (yStr * yamux.Stream ) {
64- err := ss .serveStream (log , yStr )
65- log .WithError (err ).Info ("Stopped stream." )
66- }(yStr )
85+ go func (yStr * yamux.Stream ) {
86+ err := ss .serveStream (log , yStr , ss .sm .addr )
87+ log .WithError (err ).Info ("Stopped stream." )
88+ }(yStr )
89+ }
6790 }
6891}
6992
70- func (ss * ServerSession ) serveStream (log logrus.FieldLogger , yStr * yamux.Stream ) error {
93+ // struct
94+
95+ func (ss * ServerSession ) serveStream (log logrus.FieldLogger , yStr io.ReadWriteCloser , addr net.Addr ) error {
7196 readRequest := func () (StreamRequest , error ) {
7297 obj , err := ss .readObject (yStr )
7398 if err != nil {
@@ -102,7 +127,7 @@ func (ss *ServerSession) serveStream(log logrus.FieldLogger, yStr *yamux.Stream)
102127 if req .IPinfo && req .DstAddr .PK == ss .entity .LocalPK () {
103128 log .Debug ("Received IP stream request." )
104129
105- ip , err := addrToIP (yStr . RemoteAddr () )
130+ ip , err := addrToIP (addr )
106131 if err != nil {
107132 ss .m .RecordStream (servermetrics .DeltaFailed ) // record failed stream
108133 return err
@@ -164,22 +189,27 @@ func addrToIP(addr net.Addr) (net.IP, error) {
164189 }
165190}
166191
167- func (ss * ServerSession ) forwardRequest (req StreamRequest ) (yStr * yamux. Stream , respObj SignedObject , err error ) {
192+ func (ss * ServerSession ) forwardRequest (req StreamRequest ) (mStr io. ReadWriteCloser , respObj SignedObject , err error ) {
168193 defer func () {
169- if err != nil && yStr != nil {
194+ if err != nil && mStr != nil {
170195 ss .log .
171- WithError (yStr .Close ()).
196+ WithError (mStr .Close ()).
172197 Debugf ("After forwardRequest failed, the yamux stream is closed." )
173198 }
174199 }()
175-
176- if yStr , err = ss .ys .OpenStream (); err != nil {
177- return nil , nil , err
200+ if ss .sm .smux != nil {
201+ if mStr , err = ss .sm .smux .OpenStream (); err != nil {
202+ return nil , nil , err
203+ }
204+ } else {
205+ if mStr , err = ss .sm .yamux .OpenStream (); err != nil {
206+ return nil , nil , err
207+ }
178208 }
179- if err = ss .writeObject (yStr , req .raw ); err != nil {
209+ if err = ss .writeObject (mStr , req .raw ); err != nil {
180210 return nil , nil , err
181211 }
182- if respObj , err = ss .readObject (yStr ); err != nil {
212+ if respObj , err = ss .readObject (mStr ); err != nil {
183213 return nil , nil , err
184214 }
185215 var resp StreamResponse
@@ -189,5 +219,5 @@ func (ss *ServerSession) forwardRequest(req StreamRequest) (yStr *yamux.Stream,
189219 if err = resp .Verify (req ); err != nil {
190220 return nil , nil , err
191221 }
192- return yStr , respObj , nil
222+ return mStr , respObj , nil
193223}
0 commit comments