1
1
import React from 'react' ;
2
+ import ReactDom from 'react-dom' ;
2
3
import { PropTypes } from 'prop-types' ;
3
4
import debounce from 'lodash.debounce' ;
4
5
import throttle from 'lodash.throttle' ;
5
6
import isIntersectionObserverAvailable from '../utils/intersection-observer' ;
7
+ import getScrollElement from '../utils/get-scroll-element' ;
8
+
9
+ const getScrollX = ( ) => typeof window === 'undefined' ?
10
+ 0 : ( window . scrollX || window . pageXOffset ) ;
11
+ const getScrollY = ( ) => typeof window === 'undefined' ?
12
+ 0 : ( window . scrollY || window . pageYOffset ) ;
6
13
7
14
const trackWindowScroll = ( BaseComponent ) => {
8
15
class ScrollAwareComponent extends React . Component {
@@ -23,48 +30,76 @@ const trackWindowScroll = (BaseComponent) => {
23
30
24
31
this . state = {
25
32
scrollPosition : {
26
- x : ( typeof window === 'undefined' ?
27
- 0 :
28
- ( window . scrollX || window . pageXOffset )
29
- ) ,
30
- y : ( typeof window === 'undefined' ?
31
- 0 :
32
- ( window . scrollY || window . pageYOffset )
33
- ) ,
33
+ x : getScrollX ( ) ,
34
+ y : getScrollY ( ) ,
34
35
} ,
35
36
} ;
37
+
38
+ this . baseComponentRef = React . createRef ( ) ;
36
39
}
37
40
38
41
componentDidMount ( ) {
39
- if ( typeof window == 'undefined' || isIntersectionObserverAvailable ( ) ) {
42
+ this . addListeners ( ) ;
43
+ }
44
+
45
+ componentWillUnmount ( ) {
46
+ this . removeListeners ( ) ;
47
+ }
48
+
49
+ componentDidUpdate ( ) {
50
+ if ( typeof window === 'undefined' || isIntersectionObserverAvailable ( ) ) {
51
+ return ;
52
+ }
53
+
54
+ const scrollElement = getScrollElement (
55
+ ReactDom . findDOMNode ( this . baseComponentRef . current )
56
+ ) ;
57
+
58
+ if ( scrollElement !== this . scrollElement ) {
59
+ this . removeListeners ( ) ;
60
+ this . addListeners ( ) ;
61
+ }
62
+ }
63
+
64
+ addListeners ( ) {
65
+ if ( typeof window === 'undefined' || isIntersectionObserverAvailable ( ) ) {
40
66
return ;
41
67
}
42
- window . addEventListener ( 'scroll' , this . delayedScroll ) ;
68
+
69
+ this . scrollElement = getScrollElement (
70
+ ReactDom . findDOMNode ( this . baseComponentRef . current )
71
+ ) ;
72
+
73
+ this . scrollElement . addEventListener ( 'scroll' , this . delayedScroll ) ;
43
74
window . addEventListener ( 'resize' , this . delayedScroll ) ;
75
+
76
+ if ( this . scrollElement !== window ) {
77
+ window . addEventListener ( 'scroll' , this . delayedScroll ) ;
78
+ }
44
79
}
45
80
46
- componentWillUnmount ( ) {
81
+ removeListeners ( ) {
47
82
if ( typeof window == 'undefined' || isIntersectionObserverAvailable ( ) ) {
48
83
return ;
49
84
}
50
- window . removeEventListener ( 'scroll' , this . delayedScroll ) ;
85
+
86
+ this . scrollElement . removeEventListener ( 'scroll' , this . delayedScroll ) ;
51
87
window . removeEventListener ( 'resize' , this . delayedScroll ) ;
88
+
89
+ if ( this . scrollElement !== window ) {
90
+ window . removeEventListener ( 'scroll' , this . delayedScroll ) ;
91
+ }
52
92
}
53
93
54
94
onChangeScroll ( ) {
55
95
if ( isIntersectionObserverAvailable ( ) ) {
56
96
return ;
57
97
}
98
+
58
99
this . setState ( {
59
100
scrollPosition : {
60
- x : ( typeof window == 'undefined' ?
61
- 0 :
62
- ( window . scrollX || window . pageXOffset )
63
- ) ,
64
- y : ( typeof window === 'undefined' ?
65
- 0 :
66
- ( window . scrollY || window . pageYOffset )
67
- ) ,
101
+ x : getScrollX ( ) ,
102
+ y : getScrollY ( ) ,
68
103
} ,
69
104
} ) ;
70
105
}
@@ -76,6 +111,7 @@ const trackWindowScroll = (BaseComponent) => {
76
111
77
112
return (
78
113
< BaseComponent
114
+ ref = { this . baseComponentRef }
79
115
scrollPosition = { scrollPosition }
80
116
{ ...props } />
81
117
) ;
0 commit comments