@@ -53,10 +53,21 @@ public abstract class TermScorer
5353 protected double corpusIdf ;
5454 /** Focus-side accumulator. */
5555 protected double acc ;
56+ /** Number of documents observed on either side for the current term. */
57+ protected int termDocs ;
58+
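59+ /** {@code true} once {@link #focus} has been called after {@link #corpus}; {@code false} means no contrastive (focus vs. rest) scoring. */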
60+ protected boolean hasFocus ;
61+ /** Sum of focus-side term frequencies for the current term. */
62+ protected long focusTermFreq ;
63+ /** Number of focus documents containing the current term. */
64+ protected int focusTermDocs ;
65+ /** Sum of other-side term frequencies for the current term. */
66+ protected long otherTermFreq ;
67+ /** Number of other-side (non-focus) documents containing the current term. */
68+ protected int otherTermDocs ;
5669 /** Rest-side accumulator. Stays zero for non-contrastive use. */
5770 protected double otherAcc ;
58- /** Number of documents observed on either side. */
59- protected int collectCount ;
6071
6172 /**
6273 * Set corpus-level statistics. Called once before any {@link #term}.
@@ -66,6 +77,7 @@ public abstract class TermScorer
6677 */
6778 public final void corpus (final long corpusTokens , final int corpusDocs )
6879 {
80+ this .hasFocus = false ;
6981 this .corpusTokens = corpusTokens ;
7082 this .corpusDocs = corpusDocs ;
7183 this .docTokensAvg = (corpusDocs > 0 )
@@ -82,6 +94,7 @@ public final void corpus(final long corpusTokens, final int corpusDocs)
8294 */
8395 public void focus (final long focusTokens , final int focusDocs )
8496 {
97+ this .hasFocus = true ;
8598 this .focusTokens = focusTokens ;
8699 this .focusDocs = focusDocs ;
87100 }
@@ -97,15 +110,32 @@ public void termStart(final long corpusTermFreq, final int corpusTermDocs)
97110 {
98111 this .corpusTermFreq = corpusTermFreq ;
99112 this .corpusTermDocs = corpusTermDocs ;
113+ this .focusTermFreq = 0L ;
114+ this .focusTermDocs = 0 ;
115+ this .otherTermFreq = 0L ;
116+ this .otherTermDocs = 0 ;
100117 this .corpusTermRate = (corpusTokens > 0L )
101118 ? (double ) corpusTermFreq / (double ) corpusTokens
102119 : 0d ;
103120 this .corpusIdf = 0d ;
104121 this .acc = 0d ;
105122 this .otherAcc = 0d ;
106- this .collectCount = 0 ;
123+ this .termDocs = 0 ;
107124 }
108125
126+ /**
127+ * Compute the local score for one document counted on the focus side;
128+ * delegates to the three-argument overload with {@code inFocus = true}.
129+ *
130+ * @param docTermFreq occurrences of the term in the document
131+ * @param docTokens total token count of the document
132+ * @return local per-document score
133+ */
134+ public double termDocAdd (final long docTermFreq , final long docTokens ) {
135+ return termDocAdd (docTermFreq , docTokens , true );
136+ }
137+
138+
109139 /**
110140 * Compute the local score for one document and fold it into
111141 * {@link #acc} (focus) or {@link #otherAcc} (rest).
@@ -115,7 +145,17 @@ public void termStart(final long corpusTermFreq, final int corpusTermDocs)
115145 * @param inFocus {@code true} if the document belongs to the focus subset
116146 * @return local per-document score
117147 */
118- public abstract double termDocAdd (final long docTermFreq , final long docTokens , final boolean inFocus );
148+ public double termDocAdd (final long docTermFreq , final long docTokens , final boolean inFocus ) {
149+ termDocs ++;
150+ if (inFocus ) {
151+ focusTermFreq += docTermFreq ;
152+ focusTermDocs ++;
153+ } else {
154+ otherTermFreq += docTermFreq ;
155+ otherTermDocs ++;
156+ }
157+ return 0d ;
158+ }
119159
120160 /**
121161 * Returns the aggregated score for the current term.
@@ -156,6 +196,7 @@ public static class G extends TermScorer
156196 @ Override
157197 public double termDocAdd (final long docTermFreq , final long docTokens , final boolean inFocus )
158198 {
199+ super .termDocAdd (docTermFreq , docTokens , inFocus );
159200 if (docTokens <= 0L || corpusTermRate <= 0d || docTermFreq <= 0L )
160201 return 0d ;
161202 final double expected = corpusTermRate * (double ) docTokens ;
@@ -166,7 +207,6 @@ public double termDocAdd(final long docTermFreq, final long docTokens, final boo
166207 acc += local ;
167208 else
168209 otherAcc += local ;
169- collectCount ++;
170210 return local ;
171211 }
172212 }
@@ -179,6 +219,7 @@ public static class Jaccard extends TermScorer
179219 @ Override
180220 public double termDocAdd (final long docTermFreq , final long docTokens , final boolean inFocus )
181221 {
222+ super .termDocAdd (docTermFreq , docTokens , inFocus );
182223 if (docTermFreq <= 0L || docTokens <= 0L || corpusTermFreq <= 0L )
183224 return 0d ;
184225 final long union = docTokens + corpusTermFreq - docTermFreq ;
@@ -189,7 +230,6 @@ public double termDocAdd(final long docTermFreq, final long docTokens, final boo
189230 acc += local ;
190231 else
191232 otherAcc += local ;
192- collectCount ++;
193233 return local ;
194234 }
195235 }
@@ -220,10 +260,6 @@ public static class BM25 extends TermScorer
220260 protected double b = 0.75d ;
221261 /** Exponent applied to raw IDF. */
222262 protected final double idfExp ;
223- /** Sum of focus-side term frequencies for the current term. */
224- protected long focusTermFreqAcc ;
225- /** Number of focus documents containing the current term. */
226- protected int focusTermDocs ;
227263
228264 /** Different score mode */
229265 public enum Mode
@@ -236,12 +272,12 @@ public enum Mode
236272
237273 public BM25 ()
238274 {
239- this (0.9 , Mode .MINUS );
275+ this (0.9 , Mode .IRDF );
240276 }
241277
242278 public BM25 (final double idfExp )
243279 {
244- this (idfExp , Mode .MINUS );
280+ this (idfExp , Mode .IRDF );
245281 }
246282
247283
@@ -255,8 +291,6 @@ public BM25(final double idfExp, Mode mode)
255291 public void termStart (final long corpusTermFreq , final int corpusTermDocs )
256292 {
257293 super .termStart (corpusTermFreq , corpusTermDocs );
258- this .focusTermFreqAcc = 0L ;
259- this .focusTermDocs = 0 ;
260294 if (corpusDocs <= 0 ) {
261295 this .corpusIdf = 0d ;
262296 return ;
@@ -271,17 +305,16 @@ public double termDocAdd(final long docTermFreq, final long docTokens, final boo
271305 {
272306 if (docTermFreq <= 0L || docTokens <= 0L || docTokensAvg <= 0d )
273307 return 0d ;
308+
309+ super .termDocAdd (docTermFreq , docTokens , inFocus );
310+
274311 final double tf = (double ) docTermFreq ;
275312 final double norm = k1 * (1d - b + b * ((double ) docTokens / docTokensAvg ));
276313 final double local = (tf * (k1 + 1d )) / (tf + norm );
277- if (inFocus ) {
278- acc += local ;
279- focusTermFreqAcc += docTermFreq ;
280- focusTermDocs ++;
281- } else {
282- otherAcc += local ;
283- }
284- collectCount ++;
314+
315+ if (inFocus ) acc += local ;
316+ else otherAcc += local ;
317+
285318 return local ;
286319 }
287320
@@ -322,25 +355,18 @@ public double termDocAdd(final long docTermFreq, final long docTokens, final boo
322355 @ Override
323356 public double termScore ()
324357 {
358+ // no contrast, return classical BM25
359+ if (!hasFocus ) {
360+ return corpusIdf * acc ;
361+ }
362+ final int otherDocs = corpusDocs - focusDocs ; // N - R
363+ final int otherTermDocs = corpusTermDocs - focusTermDocs ; // n - r
364+ Mode mode = this .mode ;
365+ if (mode == null ) mode = Mode .IRDF ;
325366 switch (mode ) {
326- case IRDF : {
327- int otherDocs = corpusDocs - focusDocs ;
328- int otherTermDocs = corpusTermDocs - focusTermDocs ;
329- if (otherDocs <= 0 )
330- return corpusIdf * acc ;
331- double irdf = Math .pow (
332- Math .log (1.0d + (otherDocs - otherTermDocs + 0.5d ) / (otherTermDocs + 0.5d )),
333- idfExp );
334- return irdf * acc ;
335- }
336367 case RSJ : {
337- final int otherDocs = corpusDocs - focusDocs ; // N - R
338- final int otherTermDocs = corpusTermDocs - focusTermDocs ; // n - r
339368 final int focusNonTermDocs = focusDocs - focusTermDocs ; // R - r
340369 final int otherNonTermDocs = otherDocs - otherTermDocs ; // (N - R) - (n - r)
341- // no part
342- if (otherDocs <= 0 )
343- return corpusIdf * acc ;
344370 if (otherDocs < 0 || otherTermDocs < 0 || focusNonTermDocs < 0 || otherNonTermDocs < 0 ) {
345371 // should throw exception here, no?
346372 return 0d ;
@@ -353,14 +379,20 @@ public double termScore()
353379 final double rsjWeighted = Math .copySign (Math .pow (Math .abs (rsj ), idfExp ), rsj );
354380 return rsjWeighted * acc ;
355381 }
382+ case IRDF : {
383+ double irdf = Math .pow (
384+ Math .log (1.0d + (otherDocs - otherTermDocs + 0.5d ) / (otherTermDocs + 0.5d )),
385+ idfExp );
386+ return irdf * acc ;
387+ }
356388 case FACTOR :
357- if (focusTermFreqAcc == 0 || focusTokens <= 0 )
389+ if (focusTermFreq == 0 || focusTokens <= 0 )
358390 return 0d ;
359- double relFocus = (double ) focusTermFreqAcc / focusTokens ;
391+ double relFocus = (double ) focusTermFreq / focusTokens ;
360392 double relCorpus = (double ) corpusTermFreq / corpusTokens ;
361393 if (relCorpus <= 0d )
362394 return 0d ;
363- return corpusIdf * acc * Math .log (relFocus / relCorpus ) * Math .log (focusTermFreqAcc );
395+ return corpusIdf * acc * Math .log (relFocus / relCorpus ) * Math .log (focusTermFreq );
364396 case WEIGHTED :
365397 final double wFocus = 1.0 ;
366398 final double wRest = -2.0 ; // or whatever you want to try
@@ -385,13 +417,15 @@ public boolean equals(final Object o)
385417 return true ;
386418 if (!(o instanceof BM25 other ))
387419 return false ;
388- return Double .compare (this .idfExp , other .idfExp ) == 0 ;
420+ return Double .compare (this .idfExp , other .idfExp ) == 0 && this . mode == other . mode ;
389421 }
390422
391423 @ Override
392424 public int hashCode ()
393425 {
394- return Double .hashCode (idfExp );
426+ int h = Double .hashCode (idfExp );
427+ h = 31 * h + mode .hashCode ();
428+ return h ;
395429 }
396430 }
397431
@@ -455,7 +489,7 @@ public double termDocAdd(final long docTermFreq, final long docTokens, final boo
455489 restTermDocsCount ++;
456490 }
457491
458- collectCount ++;
492+ termDocs ++;
459493 return local ;
460494 }
461495
0 commit comments