@@ -71,7 +71,10 @@ func (client *Client) Download(ctx context.Context, limit *pb.OrderLimit, pieceP
7171
7272 ctx , cancel := context2 .WithCustomCancel (ctx )
7373
74- var underlyingStream downloadStream
74+ var (
75+ underlyingStream downloadStream
76+ timedOut bool
77+ )
7578 sync2 .WithTimeout (client .config .MessageTimeout , func () {
7679 if client .replaySafe != nil {
7780 underlyingStream , err = client .replaySafe .Download (ctx )
@@ -80,11 +83,16 @@ func (client *Client) Download(ctx context.Context, limit *pb.OrderLimit, pieceP
8083 }
8184 }, func () {
8285 cancel (errMessageTimeout )
86+ timedOut = true
8387 })
8488 if err != nil {
8589 cancel (context .Canceled )
8690 return nil , err
8791 }
92+ if timedOut {
93+ return nil , errMessageTimeout
94+ }
95+
8896 stream := & timedDownloadStream {
8997 timeout : client .config .MessageTimeout ,
9098 stream : underlyingStream ,
@@ -386,25 +394,46 @@ func (stream *timedDownloadStream) cancelTimeout() {
386394 stream .cancel (errMessageTimeout )
387395}
388396
389- func (stream * timedDownloadStream ) Close () (err error ) {
390- sync2 .WithTimeout (stream .timeout , func () {
391- err = stream .stream .Close ()
392- }, stream .cancelTimeout )
393- return CloseError .Wrap (err )
397+ func (stream * timedDownloadStream ) Close () error {
398+ return stream .withTimeout (stream .stream .Close )
394399}
395400
396- func (stream * timedDownloadStream ) Send (req * pb.PieceDownloadRequest ) (err error ) {
397- sync2 .WithTimeout (stream .timeout , func () {
398- err = stream .stream .Send (req )
399- }, stream .cancelTimeout )
400- return err
401+ func (stream * timedDownloadStream ) Send (req * pb.PieceDownloadRequest ) error {
402+ return stream .withTimeout (func () error {
403+ return stream .stream .Send (req )
404+ })
401405}
402406
403407func (stream * timedDownloadStream ) Recv () (resp * pb.PieceDownloadResponse , err error ) {
408+ err = stream .withTimeout (func () error {
409+ var recvErr error
410+ resp , recvErr = stream .stream .Recv ()
411+ return recvErr
412+ })
413+ if err != nil {
414+ return nil , err
415+ }
416+ return resp , nil
417+ }
418+
419+ func (stream * timedDownloadStream ) withTimeout (fn func () error ) error {
420+ var (
421+ err error
422+ timedOut bool
423+ )
404424 sync2 .WithTimeout (stream .timeout , func () {
405- resp , err = stream .stream .Recv ()
406- }, stream .cancelTimeout )
407- return resp , err
425+ err = fn ()
426+ }, func () {
427+ stream .cancelTimeout ()
428+ timedOut = true
429+ })
430+ if err != nil {
431+ return CloseError .Wrap (err )
432+ }
433+ if timedOut {
434+ return CloseError .Wrap (errMessageTimeout )
435+ }
436+ return nil
408437}
409438
410439// syncError synchronizes access to an error and keeps
0 commit comments