|
| 1 | +package com.github.oeuvres.alix.util; |
| 2 | + |
| 3 | +import org.openjdk.jmh.annotations.*; |
| 4 | +import org.openjdk.jmh.infra.Blackhole; |
| 5 | + |
| 6 | + |
| 7 | +import java.util.*; |
| 8 | +import java.util.concurrent.TimeUnit; |
| 9 | + |
| 10 | +@BenchmarkMode(Mode.Throughput) |
| 11 | +@OutputTimeUnit(TimeUnit.MICROSECONDS) |
| 12 | +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) |
| 13 | +@Measurement(iterations = 8, time = 1, timeUnit = TimeUnit.SECONDS) |
| 14 | +@Fork(2) |
| 15 | +public class PosLookupBenchmark |
| 16 | +{ |
| 17 | + |
| 18 | + /** |
| 19 | + * If you want to "pass an array of tokens", put it here (one POS tag per |
| 20 | + * lexicon line). When non-null, it overrides file loading and synthetic |
| 21 | + * generation. |
| 22 | + * |
| 23 | + * Example: PosLookupWorkbench.POS_PER_LINE = new String[] { "NOUN", "PUNCT", |
| 24 | + * "VERB", ... }; |
| 25 | + */ |
| 26 | + public static volatile String[] POS_PER_LINE = null; |
| 27 | + |
| 28 | + @State(Scope.Benchmark) |
| 29 | + public static class BenchState |
| 30 | + { |
| 31 | + |
| 32 | + /** |
| 33 | + * If POS_PER_LINE is null and -DposFile is not set, we generate this many |
| 34 | + * lines. |
| 35 | + */ |
| 36 | + @Param({ "500000" }) |
| 37 | + public int lines; |
| 38 | + |
| 39 | + /** |
| 40 | + * Synthetic tag-set used only when neither POS_PER_LINE nor -DposFile is |
| 41 | + * provided. |
| 42 | + */ |
| 43 | + public static final String[] DEFAULT_TAGS = { |
| 44 | + "TOKEN", |
| 45 | + "UNKNOWN", |
| 46 | + "TEST", |
| 47 | + "XML", |
| 48 | + "STOP", |
| 49 | + "NOSTOP", |
| 50 | + "LOC", |
| 51 | + "PUNCT", |
| 52 | + "PUNCTsection", |
| 53 | + "PUNCTpara", |
| 54 | + "PUNCTsent", |
| 55 | + "PUNCTclause", |
| 56 | + "VERB", |
| 57 | + "AUX", |
| 58 | + "VERBinf", |
| 59 | + "VERBpartpast", |
| 60 | + "VERBpartpres", |
| 61 | + "NOUN", |
| 62 | + "ADJ", |
| 63 | + "PROPN", |
| 64 | + "PROPNprs", |
| 65 | + "PROPNgivmasc", |
| 66 | + "PROPNgivfem", |
| 67 | + "PROPNgeo", |
| 68 | + "PROPNorg", |
| 69 | + "PROPNevent", |
| 70 | + "PROPNauthor", |
| 71 | + "PROPNfict", |
| 72 | + "PROPNtitle", |
| 73 | + "PROPNspec", |
| 74 | + "PROPNpeople", |
| 75 | + "PROPNgod", |
| 76 | + "ADV", |
| 77 | + "ADVint", |
| 78 | + "ADVneg", |
| 79 | + "PART", |
| 80 | + "ADVsit", |
| 81 | + "ADVasp", |
| 82 | + "ADVdeg", |
| 83 | + "DET", |
| 84 | + "DETart", |
| 85 | + "DETdem", |
| 86 | + "DETind", |
| 87 | + "DETint", |
| 88 | + "DETneg", |
| 89 | + "DETprs", |
| 90 | + "ADP_DET", |
| 91 | + "PRON", |
| 92 | + "PRONdem", |
| 93 | + "PRONind", |
| 94 | + "PRONint", |
| 95 | + "PRONneg", |
| 96 | + "PRONprs", |
| 97 | + "PRONrel", |
| 98 | + "ADP_PRON", |
| 99 | + "ADP", |
| 100 | + "CCONJ", |
| 101 | + "SCONJ", |
| 102 | + "NUM", |
| 103 | + "NUMord", |
| 104 | + "SYM", |
| 105 | + "DIGIT", |
| 106 | + "MATH", |
| 107 | + "UNIT", |
| 108 | + "REF", |
| 109 | + "X", |
| 110 | + "INTJ", |
| 111 | + "ABBR", |
| 112 | + "MG", |
| 113 | + }; |
| 114 | + |
| 115 | + public static record TagCount(String tag, int count) {} |
| 116 | + static final TagCount[] TAG_DIST = { |
| 117 | + new TagCount("VERB", 306_226), |
| 118 | + new TagCount("NOUN", 112_686), |
| 119 | + new TagCount("ADJ", 79_593), |
| 120 | + new TagCount("VERBpartpast", 29_612), |
| 121 | + new TagCount("VERBpartpres", 8_207), |
| 122 | + new TagCount("ADV", 2_348), |
| 123 | + new TagCount("NUM", 214), |
| 124 | + new TagCount("INTJ", 166), |
| 125 | + new TagCount("AUX", 130), |
| 126 | + new TagCount("ADP", 68), |
| 127 | + new TagCount("PRONprs", 51), |
| 128 | + new TagCount("ADVsit", 33), |
| 129 | + new TagCount("PRONdem", 27), |
| 130 | + new TagCount("ADVasp", 24), |
| 131 | + new TagCount("ADVdeg", 23), |
| 132 | + new TagCount("PRONind", 22), |
| 133 | + new TagCount("DETind", 22), |
| 134 | + new TagCount("PRONrel", 18), |
| 135 | + new TagCount("SCONJ", 16), |
| 136 | + new TagCount("PRONint", 16), |
| 137 | + new TagCount("DETprs", 15), |
| 138 | + new TagCount("DETart", 11), |
| 139 | + new TagCount("DETdem", 10), |
| 140 | + new TagCount("CCONJ", 10), |
| 141 | + new TagCount("ADVneg", 9), |
| 142 | + new TagCount("DETneg", 8), |
| 143 | + new TagCount("ADP+DET", 7), |
| 144 | + new TagCount("ADP+PRON", 6), |
| 145 | + new TagCount("PRONneg", 5), |
| 146 | + new TagCount("ADVint", 4), |
| 147 | + new TagCount("PRON", 2), |
| 148 | + new TagCount("DETprep", 1), |
| 149 | + }; |
| 150 | + |
| 151 | + static String[] generateFromCounts(final int lines, final TagCount[] dist, final long seed) { |
| 152 | + if (dist == null || dist.length == 0) throw new IllegalArgumentException("empty dist"); |
| 153 | + |
| 154 | + final int n = dist.length; |
| 155 | + long total = 0; |
| 156 | + final long[] cdf = new long[n]; |
| 157 | + |
| 158 | + for (int i = 0; i < n; i++) { |
| 159 | + final int c = dist[i].count(); |
| 160 | + if (c <= 0) throw new IllegalArgumentException("count<=0 for " + dist[i].tag()); |
| 161 | + total += c; |
| 162 | + cdf[i] = total; |
| 163 | + } |
| 164 | + |
| 165 | + final java.util.Random rnd = new java.util.Random(seed); |
| 166 | + final String[] out = new String[lines]; |
| 167 | + |
| 168 | + for (int i = 0; i < lines; i++) { |
| 169 | + final long x = (long) (rnd.nextDouble() * total); // [0,total) |
| 170 | + int j = java.util.Arrays.binarySearch(cdf, x); |
| 171 | + if (j < 0) j = -j - 1; |
| 172 | + out[i] = dist[j].tag(); |
| 173 | + } |
| 174 | + return out; |
| 175 | + } |
| 176 | + |
| 177 | + |
| 178 | + // Dataset: one tag per line |
| 179 | + String[] posPerLine; |
| 180 | + |
| 181 | + // Same dataset as CSV-like slices in a single char buffer |
| 182 | + char[] buf; |
| 183 | + int[] off; |
| 184 | + short[] len; |
| 185 | + |
| 186 | + // Derived dictionary: tag -> int code |
| 187 | + Map<String, Integer> codeByName; |
| 188 | + |
| 189 | + // CharsDic + ord->code indirection |
| 190 | + CharsDic dic; |
| 191 | + int[] codeByOrd; |
| 192 | + |
| 193 | + // Stats (for sanity) |
| 194 | + int distinct; |
| 195 | + |
| 196 | + @Setup(Level.Trial) |
| 197 | + public void setup() throws Exception |
| 198 | + { |
| 199 | + posPerLine = generateFromCounts(lines, TAG_DIST, 1L); |
| 200 | + |
| 201 | + // Build slice representation (buf/off/len) once. |
| 202 | + buildSlices(posPerLine); |
| 203 | + |
| 204 | + // Derive dictSize and assign int codes from DISTINCT tags. |
| 205 | + final HashMap<String, Integer> tmp = new HashMap<>(64); |
| 206 | + for (String t : posPerLine) { |
| 207 | + // If your input can contain whitespace, trim it here: |
| 208 | + // t = t.trim(); |
| 209 | + if (!tmp.containsKey(t)) |
| 210 | + tmp.put(t, tmp.size()); |
| 211 | + } |
| 212 | + distinct = tmp.size(); |
| 213 | + codeByName = Map.copyOf(tmp); |
| 214 | + |
| 215 | + // Build CharsDic from distinct tags and create ord->code map |
| 216 | + dic = new CharsDic(Math.max(1, distinct)); |
| 217 | + codeByOrd = new int[distinct]; |
| 218 | + Arrays.fill(codeByOrd, -1); |
| 219 | + |
| 220 | + for (Map.Entry<String, Integer> e : tmp.entrySet()) { |
| 221 | + final String tag = e.getKey(); |
| 222 | + final int code = e.getValue(); |
| 223 | + final char[] a = tag.toCharArray(); |
| 224 | + final int ord = dic.add(a, 0, a.length); |
| 225 | + // add() returns ord>=0 for new, or -(ord)-1 for existing |
| 226 | + final int o = (ord >= 0) ? ord : (-ord - 1); |
| 227 | + if (o >= codeByOrd.length) { |
| 228 | + // defensive; should not happen |
| 229 | + throw new IllegalStateException("ord out of range: " + o + " distinct=" + distinct); |
| 230 | + } |
| 231 | + codeByOrd[o] = code; |
| 232 | + } |
| 233 | + dic.freeze(); |
| 234 | + |
| 235 | + // Sanity: every ord must map to some code |
| 236 | + for (int i = 0; i < distinct; i++) { |
| 237 | + if (codeByOrd[i] < 0) |
| 238 | + throw new IllegalStateException("Missing code for ord=" + i); |
| 239 | + } |
| 240 | + } |
| 241 | + |
| 242 | + private void buildSlices(String[] arr) |
| 243 | + { |
| 244 | + final int n = arr.length; |
| 245 | + off = new int[n]; |
| 246 | + len = new short[n]; |
| 247 | + |
| 248 | + int totalChars = 0; |
| 249 | + for (String s : arr) |
| 250 | + totalChars += s.length(); |
| 251 | + |
| 252 | + buf = new char[totalChars]; |
| 253 | + int p = 0; |
| 254 | + for (int i = 0; i < n; i++) { |
| 255 | + final String s = arr[i]; |
| 256 | + final int L = s.length(); |
| 257 | + off[i] = p; |
| 258 | + len[i] = (short) L; |
| 259 | + s.getChars(0, L, buf, p); |
| 260 | + p += L; |
| 261 | + } |
| 262 | + } |
| 263 | + } |
| 264 | + |
| 265 | + @State(Scope.Thread) |
| 266 | + public static class ThreadState |
| 267 | + { |
| 268 | + int pos; |
| 269 | + } |
| 270 | + |
| 271 | + private static final int OPS = 1024; |
| 272 | + |
| 273 | + /** |
| 274 | + * Baseline: parser already produced a String per line. No extra allocation per |
| 275 | + * lookup. |
| 276 | + */ |
| 277 | + @Benchmark |
| 278 | + @OperationsPerInvocation(OPS) |
| 279 | + public void hashmap_get_prebuiltString(BenchState s, ThreadState t, Blackhole bh) |
| 280 | + { |
| 281 | + int p = t.pos; |
| 282 | + final int n = s.posPerLine.length; |
| 283 | + for (int i = 0; i < OPS; i++) { |
| 284 | + final String tag = s.posPerLine[p++ % n]; |
| 285 | + final Integer code = s.codeByName.get(tag); // should hit |
| 286 | + bh.consume(code); |
| 287 | + } |
| 288 | + t.pos = p; |
| 289 | + } |
| 290 | + |
| 291 | + /** |
| 292 | + * Hot-path you are concerned about: CSV-like slice -> new String -> |
| 293 | + * HashMap.get. Allocates a new String (and backing storage) per token. |
| 294 | + */ |
| 295 | + @Benchmark |
| 296 | + @OperationsPerInvocation(OPS) |
| 297 | + public void hashmap_get_newStringFromSlice(BenchState s, ThreadState t, Blackhole bh) |
| 298 | + { |
| 299 | + int p = t.pos; |
| 300 | + final int n = s.off.length; |
| 301 | + for (int i = 0; i < OPS; i++) { |
| 302 | + final int idx = p++ % n; |
| 303 | + final int o = s.off[idx]; |
| 304 | + final int L = s.len[idx] & 0xFFFF; |
| 305 | + final Integer code = s.codeByName.get(new String(s.buf, o, L)); // alloc every time |
| 306 | + bh.consume(code); |
| 307 | + } |
| 308 | + t.pos = p; |
| 309 | + } |
| 310 | + |
| 311 | + /** Allocation-free: slice -> CharsDic.find -> ord->code. */ |
| 312 | + @Benchmark |
| 313 | + @OperationsPerInvocation(OPS) |
| 314 | + public void charsDic_findFromSlice(BenchState s, ThreadState t, Blackhole bh) |
| 315 | + { |
| 316 | + int p = t.pos; |
| 317 | + final int n = s.off.length; |
| 318 | + for (int i = 0; i < OPS; i++) { |
| 319 | + final int idx = p++ % n; |
| 320 | + final int o = s.off[idx]; |
| 321 | + final int L = s.len[idx] & 0xFFFF; |
| 322 | + |
| 323 | + final int ord = s.dic.find(s.buf, o, L); // should hit |
| 324 | + final int code = s.codeByOrd[ord]; |
| 325 | + bh.consume(code); |
| 326 | + } |
| 327 | + t.pos = p; |
| 328 | + } |
| 329 | + |
| 330 | + |
| 331 | +} |
0 commit comments