Skip to content

Commit ab3d6b2

Browse files
committed
TermScorer works
1 parent a39b929 commit ab3d6b2

3 files changed

Lines changed: 393 additions & 350 deletions

File tree

Lines changed: 161 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,229 +1,267 @@
11
package com.github.oeuvres.alix.lucene.terms;
22

33
/**
4-
* Local scorer for one term on one part.
4+
* Local scorer for one term across documents (or parts).
55
*
66
* <p>Intended lifecycle:</p>
77
* <ol>
8-
* <li>prepare one scorer instance for one term with corpus-level statistics,</li>
9-
* <li>call {@link #score(long, long)} for each part,</li>
10-
* <li>aggregate local part scores outside this class.</li>
8+
* <li>call {@link #corpus(long, int)} once with corpus-level statistics,</li>
9+
* <li>for each term:
10+
* <ol>
11+
* <li>call {@link #term(long, int)} — resets the accumulator,</li>
12+
* <li>call {@link #collect(long, long)} for each document/part,</li>
13+
* <li>call {@link #result()} to obtain the aggregated score.</li>
14+
* </ol>
15+
* </li>
1116
* </ol>
1217
*
13-
* <p>This class is stateful. One instance must not be reused concurrently
14-
* for different terms.</p>
18+
* <p>{@link #score(long, long)} remains available for callers who need
19+
* the raw local score without accumulation.</p>
20+
*
21+
* <p>This class is stateful. One instance must not be reused concurrently.</p>
1522
*/
1623
public abstract class TermScorer {
17-
/**
18-
* Aggregation rule used to reduce local part scores to one score per term.
19-
*/
20-
public enum Aggregation {
21-
/** Sum local scores over all parts. */
22-
SUM,
23-
24-
/** Sum only positive local scores. */
25-
SUM_POSITIVE,
2624

27-
/** Maximum local score over all parts. */
28-
MAX,
25+
// =========================================================================
26+
// Corpus-level state (set once)
27+
// =========================================================================
2928

30-
/** Maximum positive local score; negative local scores are ignored. */
31-
MAX_POSITIVE,
32-
33-
/** Arithmetic mean of local scores over all parts. */
34-
MEAN
35-
}
3629
/** Total token count of the full corpus/field. */
3730
protected long corpusTokens;
38-
39-
/** */
40-
protected int corpusPartCount;
41-
4231

43-
/** Cached idf-like value derived from corpus statistics. */
44-
protected double corpusIdf;
32+
/** Number of scoring units (documents or parts). */
33+
protected int corpusPartCount;
4534

46-
/** Average token count of one part. */
35+
/** Average token count per scoring unit. */
4736
protected double partTokensAvg;
48-
49-
/** Total occurrences of the current term in the full corpus/field. */
37+
38+
// =========================================================================
39+
// Term-level state (reset per term)
40+
// =========================================================================
41+
42+
/** Total occurrences of the current term in the corpus. */
5043
protected long corpusTermFreq;
5144

52-
/** Number of corpus documents containing the current term. */
45+
/** Number of scoring units containing the current term. */
5346
protected int corpusTermDocs;
5447

55-
/** Global relative frequency of the current term in the corpus. */
48+
/** Relative frequency of the current term: corpusTermFreq / corpusTokens. */
5649
protected double corpusTermRate;
5750

51+
/** Cached IDF-like factor, computed per term by subclasses that need it. */
52+
protected double corpusIdf;
53+
54+
// =========================================================================
55+
// Accumulator state (reset per term, updated per collect)
56+
// =========================================================================
57+
58+
/** Running accumulator. Semantics depend on the subclass. */
59+
protected double acc;
60+
61+
/** Number of scoring units observed via {@link #collect}. */
62+
protected int collectCount;
63+
64+
// =========================================================================
65+
// Corpus-level setup
66+
// =========================================================================
5867

5968
/**
60-
* Prepare this scorer for one term.
69+
* Set corpus-level statistics. Must be called once before any
70+
* {@link #term(long, int)} call.
6171
*
62-
* @param corpusTermFreq total occurrences of the term in the corpus
63-
* @param corpusTermDocs number of corpus documents containing the term
64-
* @param corpusTokens total token count in the corpus
65-
* @param corpusDocs total live document count in the corpus
66-
* @param avgPartTokens average token count of one part
72+
* @param corpusTokens total token count in the corpus
73+
* @param corpusPartCount number of scoring units (documents or parts)
6774
*/
68-
public final void corpus(
69-
final long corpusTokens,
70-
final int corpusPartCount
71-
) {
75+
public final void corpus(final long corpusTokens, final int corpusPartCount) {
7276
this.corpusTokens = corpusTokens;
7377
this.corpusPartCount = corpusPartCount;
74-
this.partTokensAvg = (double) corpusTokens / (double) corpusPartCount;
75-
76-
this.corpusTermRate = 0d;
77-
this.corpusIdf = 0d;
78-
79-
configure();
78+
this.partTokensAvg = (corpusPartCount > 0)
79+
? (double) corpusTokens / (double) corpusPartCount
80+
: 0d;
8081
}
8182

83+
// =========================================================================
84+
// Term-level setup (resets accumulator)
85+
// =========================================================================
86+
8287
/**
83-
* Initialize the global term rate of the current term:
84-
* corpusTermFreq / corpusTokens.
88+
* Prepare this scorer for a new term. Resets the accumulator.
89+
*
90+
* <p>Subclasses that compute per-term derived values (e.g. IDF) should
91+
* override this method, call {@code super.term()} first, then set
92+
* their derived fields.</p>
93+
*
94+
* @param corpusTermFreq total occurrences of the term in the corpus
95+
* @param corpusTermDocs number of scoring units containing the term
8596
*/
86-
public void term(
87-
final long corpusTermFreq,
88-
final int corpusTermDocs
89-
) {
97+
public void term(final long corpusTermFreq, final int corpusTermDocs) {
9098
this.corpusTermFreq = corpusTermFreq;
9199
this.corpusTermDocs = corpusTermDocs;
92-
if (corpusTokens <= 0L) {
93-
this.corpusTermRate = 0d;
94-
return;
95-
}
96-
this.corpusTermRate = (double) corpusTermFreq / (double) corpusTokens;
100+
this.corpusTermRate = (corpusTokens > 0L)
101+
? (double) corpusTermFreq / (double) corpusTokens
102+
: 0d;
103+
this.corpusIdf = 0d;
104+
this.acc = accInit();
105+
this.collectCount = 0;
97106
}
98107

108+
// =========================================================================
109+
// Accumulation protocol: accInit / collect / result
110+
// =========================================================================
111+
99112
/**
100-
* Optional hook after corpusTermRate and corpusIdf have been initialized.
113+
* Initial accumulator value before any observation.
114+
* Default is {@code 0.0} (suitable for sum-based aggregation).
115+
*
116+
* @return seed value for the accumulator
101117
*/
102-
protected void configure() {
103-
// no-op
118+
protected double accInit() {
119+
return 0d;
104120
}
105121

106122
/**
107-
* Score one part for the prepared term.
123+
* Finalize and return the aggregated score for the current term.
108124
*
109-
* @param partTermFreq occurrences of the term in the part
110-
* @param partTokens total token count of the part
111-
* @return local score for that part
125+
* <p>Default returns the raw accumulator (= sum).
126+
* Subclasses may override for mean, clamping, etc.</p>
127+
*
128+
* @return aggregated corpus-level score for the current term
112129
*/
113-
public abstract double score(final long partTermFreq, final long partTokens);
130+
public double result() {
131+
return acc;
132+
}
133+
134+
// =========================================================================
135+
// Pure local score (no side effect on accumulator)
136+
// =========================================================================
114137

115138
/**
116-
* Signed G-style contribution against the global corpus expectation.
139+
* Compute the local score for one document/part and fold it into the
140+
* accumulator.
117141
*
118-
* <p>Local expectation in one part:</p>
119-
* <pre>
120-
* partExpectedTermFreq = corpusTermRate * partTokens
121-
* </pre>
142+
* @param partTermFreq occurrences of the term in the document/part
143+
* @param partTokens total token count of the document/part
144+
*/
145+
public abstract double score(final long partTermFreq, final long partTokens);
146+
147+
// =========================================================================
148+
// Concrete scorers
149+
// =========================================================================
150+
151+
/**
152+
* Signed G-test contribution against the corpus expectation.
122153
*
123-
* <p>Score:</p>
124154
* <pre>
125-
* 2 * partTermFreq * ln(partTermFreq / partExpectedTermFreq)
155+
* score = 2 × partTermFreq × ln(partTermFreq / expectedFreq)
126156
* </pre>
127157
*
128-
* <p>Positive when the term is over-represented in the part,
129-
* negative when under-represented.</p>
158+
* <p>Positive when over-represented, negative when under-represented.
159+
* Default aggregation: sum of positive contributions only.</p>
130160
*/
131-
public static final class G extends TermScorer {
161+
public static class G extends TermScorer {
162+
163+
/**
164+
* Only accumulate positive contributions (over-represented parts).
165+
* Negative G values indicate under-representation; including them
166+
* in the sum would dilute the keyword signal.
167+
*/
132168
@Override
133169
public double score(final long partTermFreq, final long partTokens) {
134-
if (partTokens <= 0L || corpusTermRate <= 0d) {
170+
171+
if (partTokens <= 0L || corpusTermRate <= 0d || partTermFreq <= 0L) {
135172
return 0d;
136173
}
137-
138-
final double partExpectedTermFreq = corpusTermRate * (double) partTokens;
139-
140-
if (partExpectedTermFreq <= 0d || partTermFreq <= 0L) {
174+
final double expected = corpusTermRate * (double) partTokens;
175+
if (expected <= 0d) {
141176
return 0d;
142177
}
143-
144-
return 2d * (double) partTermFreq
145-
* Math.log((double) partTermFreq / partExpectedTermFreq);
178+
final double local = 2d * (double) partTermFreq * Math.log((double) partTermFreq / expected);
179+
acc += local;
180+
collectCount++;
181+
return local;
146182
}
147183
}
148184

149185
/**
150186
* Count-form Jaccard coefficient.
151187
*
152-
* <p>This is not an expectation scorer. It treats:</p>
153188
* <pre>
154-
* intersection = partTermFreq
155-
* union = partTokens + corpusTermFreq - partTermFreq
189+
* score = partTermFreq / (partTokens + corpusTermFreq - partTermFreq)
156190
* </pre>
157191
*
158-
* <p>Result is in [0, 1] when inputs are coherent.</p>
192+
* <p>Default aggregation: sum.</p>
159193
*/
160-
public static final class Jaccard extends TermScorer {
194+
public static class Jaccard extends TermScorer {
195+
161196
@Override
162197
public double score(final long partTermFreq, final long partTokens) {
163198
if (partTermFreq <= 0L || partTokens <= 0L || corpusTermFreq <= 0L) {
164199
return 0d;
165200
}
166-
167201
final long union = partTokens + corpusTermFreq - partTermFreq;
168202
if (union <= 0L) {
169203
return 0d;
170204
}
171-
172-
return (double) partTermFreq / (double) union;
205+
final double local = partTermFreq / (double) union;
206+
acc += local;
207+
collectCount++;
208+
return local;
173209
}
174210
}
175211

176212
/**
177-
* BM25-like local score on one part.
213+
* BM25-style scorer.
178214
*
179-
* <p>Length normalization uses avgPartTokens.</p>
215+
* <pre>
216+
* score = IDF × tf × (k1 + 1) / (tf + k1 × (1 - b + b × dl / avgdl))
217+
* </pre>
218+
*
219+
* <p>IDF is computed per term in {@link #term(long, int)}.
220+
* Default aggregation: sum (the "summed BM25" corpus keyword score).</p>
180221
*/
181-
public static final class BM25 extends TermScorer {
182-
private final double k1;
183-
private final double b;
222+
public static class BM25 extends TermScorer {
184223

224+
/** Default IR parameters: k1=1.2, poor effect with aggregation */
225+
private final double k1 = 1.2d;
226+
/** Default IR parameters: b=0.75, poor effect with aggregation */
227+
private final double b = 0.75d;
228+
private final double idfExp;
229+
185230
public BM25() {
186-
this(1.2d, 0.75d);
231+
this(1);
187232
}
188233

189-
public BM25(final double k1, final double b) {
190-
if (k1 < 0d) {
191-
throw new IllegalArgumentException("k1 must be >= 0, got " + k1);
192-
}
193-
if (b < 0d || b > 1d) {
194-
throw new IllegalArgumentException("b must be in [0,1], got " + b);
195-
}
196-
this.k1 = k1;
197-
this.b = b;
234+
/**
235+
* @param k1 term frequency saturation (≥ 0). Lower = faster saturation.
236+
*/
237+
public BM25(final double idfExp) {
238+
this.idfExp = idfExp;
198239
}
199240

200241
@Override
201-
public final void term(
202-
final long corpusTermFreq,
203-
final int corpusTermDocs
204-
) {
242+
public void term(final long corpusTermFreq, final int corpusTermDocs) {
205243
super.term(corpusTermFreq, corpusTermDocs);
206244
if (corpusPartCount <= 0) {
207245
this.corpusIdf = 0d;
208246
return;
209247
}
210-
211-
this.corpusIdf = Math.log(
212-
1.0d + ((double) corpusPartCount - (double) corpusTermDocs + 0.5d)
213-
/ ((double) corpusTermDocs + 0.5d)
214-
);
248+
final double n = corpusPartCount;
249+
final double df = corpusTermDocs;
250+
double rawIdf = Math.log(1.0d + (n - df + 0.5d) / (df + 0.5d));
251+
this.corpusIdf = Math.pow(rawIdf, idfExp);
215252
}
216253

217254
@Override
218255
public double score(final long partTermFreq, final long partTokens) {
219256
if (partTermFreq <= 0L || partTokens <= 0L || partTokensAvg <= 0d || corpusIdf <= 0d) {
220257
return 0d;
221258
}
222-
223259
final double tf = (double) partTermFreq;
224260
final double norm = k1 * (1d - b + b * ((double) partTokens / partTokensAvg));
225-
226-
return corpusIdf * (tf * (k1 + 1d)) / (tf + norm);
261+
final double local = corpusIdf * (tf * (k1 + 1d)) / (tf + norm);
262+
acc += local;
263+
collectCount++;
264+
return local;
227265
}
228266
}
229-
}
267+
}

0 commit comments

Comments
 (0)