Skip to content

Commit cff2265

Browse files
committed
feat: add some unit tests
1 parent 35ef45a commit cff2265

File tree

4 files changed

+844
-0
lines changed

4 files changed

+844
-0
lines changed
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
package org.grobid.core.engines.tagging.delft;
2+
3+
import org.junit.Before;
4+
import org.junit.Test;
5+
6+
import static org.junit.Assert.*;
7+
8+
/**
9+
* Unit tests for CRFDecoder Viterbi decoding.
10+
*
11+
* Tests verify:
12+
* - Basic Viterbi decoding finds optimal path
13+
* - Mask handling for variable length sequences
14+
* - Transition matrix usage
15+
* - Start/end transitions
16+
* - Batch decoding
17+
*/
18+
public class CRFDecoderTest {
19+
20+
private CRFDecoder decoder;
21+
22+
@Before
23+
public void setUp() {
24+
// Create a simple 3-tag CRF decoder (O, B-TITLE, I-TITLE)
25+
// Transition matrix [from_tag][to_tag]
26+
float[][] transitions = {
27+
// to: O, B-TITLE, I-TITLE
28+
{ 0.5f, 0.3f, -1.0f }, // from O: prefer O or B-TITLE, penalize I-TITLE
29+
{ 0.2f, 0.1f, 0.6f }, // from B-TITLE: prefer I-TITLE
30+
{ 0.3f, 0.1f, 0.5f } // from I-TITLE: prefer continuing I-TITLE
31+
};
32+
33+
// Start transitions: prefer starting with O or B-TITLE
34+
float[] startTransitions = { 0.5f, 0.4f, -1.0f };
35+
36+
// End transitions: all tags can end
37+
float[] endTransitions = { 0.0f, 0.0f, 0.0f };
38+
39+
decoder = new CRFDecoder(transitions, startTransitions, endTransitions);
40+
}
41+
42+
@Test
43+
public void testGetNumTags() {
44+
assertEquals(3, decoder.getNumTags());
45+
}
46+
47+
/**
48+
* Test basic decoding with strong emissions for each tag.
49+
*/
50+
@Test
51+
public void testDecode_followsStrongEmissions() {
52+
// Emissions strongly favor: [O, B-TITLE, I-TITLE, O]
53+
float[][] emissions = {
54+
{ 2.0f, -1.0f, -1.0f }, // Position 0: strongly O
55+
{ -1.0f, 2.0f, -1.0f }, // Position 1: strongly B-TITLE
56+
{ -1.0f, -1.0f, 2.0f }, // Position 2: strongly I-TITLE
57+
{ 2.0f, -1.0f, -1.0f }, // Position 3: strongly O
58+
};
59+
60+
int[] result = decoder.decode(emissions, null);
61+
62+
assertEquals(4, result.length);
63+
assertEquals(0, result[0]); // O
64+
assertEquals(1, result[1]); // B-TITLE
65+
assertEquals(2, result[2]); // I-TITLE
66+
assertEquals(0, result[3]); // O
67+
}
68+
69+
/**
70+
* Test that transitions influence decoding when emissions are ambiguous.
71+
*/
72+
@Test
73+
public void testDecode_transitionsInfluenceDecoding() {
74+
// Emissions are all equal - transitions should decide
75+
float[][] emissions = {
76+
{ 0.0f, 0.0f, 0.0f }, // Position 0: ambiguous
77+
{ 0.0f, 0.0f, 0.0f }, // Position 1: ambiguous
78+
};
79+
80+
int[] result = decoder.decode(emissions, null);
81+
82+
assertEquals(2, result.length);
83+
// With our transition matrix, starting with O is preferred
84+
// (startTransitions[0]=0.5)
85+
// and O->O has good transition (0.5)
86+
assertEquals(0, result[0]); // Should start with O
87+
}
88+
89+
/**
90+
* Test that I-TITLE cannot start a sequence (penalized by startTransitions).
91+
*/
92+
@Test
93+
public void testDecode_cannotStartWithContinuation() {
94+
// Position 0 emissions favor I-TITLE, but start transitions penalize it
95+
float[][] emissions = {
96+
{ 0.0f, 0.0f, 0.5f }, // Slightly favor I-TITLE
97+
};
98+
99+
int[] result = decoder.decode(emissions, null);
100+
101+
// Should NOT be I-TITLE (index 2) because start transitions penalize it
102+
assertNotEquals(2, result[0]);
103+
}
104+
105+
/**
106+
* Test decoding with mask - only valid positions are decoded.
107+
*/
108+
@Test
109+
public void testDecode_respectsMask() {
110+
float[][] emissions = {
111+
{ 2.0f, -1.0f, -1.0f }, // Position 0: O
112+
{ -1.0f, 2.0f, -1.0f }, // Position 1: B-TITLE
113+
{ -1.0f, -1.0f, 2.0f }, // Position 2: I-TITLE (masked out)
114+
{ 2.0f, -1.0f, -1.0f }, // Position 3: O (masked out)
115+
};
116+
117+
boolean[] mask = { true, true, false, false };
118+
119+
int[] result = decoder.decode(emissions, mask);
120+
121+
// Only first 2 positions should be decoded
122+
assertEquals(2, result.length);
123+
assertEquals(0, result[0]); // O
124+
assertEquals(1, result[1]); // B-TITLE
125+
}
126+
127+
/**
128+
* Test decoding empty sequence (all masked out).
129+
*/
130+
@Test
131+
public void testDecode_emptySequence() {
132+
float[][] emissions = {
133+
{ 2.0f, -1.0f, -1.0f },
134+
{ -1.0f, 2.0f, -1.0f },
135+
};
136+
137+
boolean[] mask = { false, false };
138+
139+
int[] result = decoder.decode(emissions, mask);
140+
141+
assertEquals(0, result.length);
142+
}
143+
144+
/**
145+
* Test batch decoding.
146+
*/
147+
@Test
148+
public void testDecodeBatch() {
149+
float[][][] emissions = {
150+
// Sequence 0: [O, B-TITLE]
151+
{
152+
{ 2.0f, -1.0f, -1.0f },
153+
{ -1.0f, 2.0f, -1.0f },
154+
},
155+
// Sequence 1: [B-TITLE, I-TITLE]
156+
{
157+
{ -1.0f, 2.0f, -1.0f },
158+
{ -1.0f, -1.0f, 2.0f },
159+
}
160+
};
161+
162+
int[][] results = decoder.decodeBatch(emissions, null);
163+
164+
assertEquals(2, results.length);
165+
166+
assertEquals(2, results[0].length);
167+
assertEquals(0, results[0][0]); // O
168+
assertEquals(1, results[0][1]); // B-TITLE
169+
170+
assertEquals(2, results[1].length);
171+
assertEquals(1, results[1][0]); // B-TITLE
172+
assertEquals(2, results[1][1]); // I-TITLE
173+
}
174+
175+
/**
176+
* Test batch decoding with masks.
177+
*/
178+
@Test
179+
public void testDecodeBatch_withMasks() {
180+
float[][][] emissions = {
181+
// Sequence 0: 3 positions, but only 2 valid
182+
{
183+
{ 2.0f, -1.0f, -1.0f },
184+
{ -1.0f, 2.0f, -1.0f },
185+
{ -1.0f, -1.0f, 2.0f },
186+
},
187+
// Sequence 1: 3 positions, only 1 valid
188+
{
189+
{ 2.0f, -1.0f, -1.0f },
190+
{ -1.0f, 2.0f, -1.0f },
191+
{ -1.0f, -1.0f, 2.0f },
192+
}
193+
};
194+
195+
boolean[][] masks = {
196+
{ true, true, false },
197+
{ true, false, false }
198+
};
199+
200+
int[][] results = decoder.decodeBatch(emissions, masks);
201+
202+
assertEquals(2, results[0].length); // 2 valid positions
203+
assertEquals(1, results[1].length); // 1 valid position
204+
}
205+
206+
/**
207+
* Test single position decoding.
208+
*/
209+
@Test
210+
public void testDecode_singlePosition() {
211+
float[][] emissions = {
212+
{ -1.0f, 2.0f, -1.0f }, // Only B-TITLE
213+
};
214+
215+
int[] result = decoder.decode(emissions, null);
216+
217+
assertEquals(1, result.length);
218+
assertEquals(1, result[0]); // B-TITLE
219+
}
220+
}

0 commit comments

Comments
 (0)