|
| 1 | +package org.beehive.gpullama3.auxiliary; |
| 2 | + |
import org.beehive.gpullama3.auxiliary.metrics.GitHubMetricsRenderer;
import org.beehive.gpullama3.auxiliary.metrics.HumanMetricsRenderer;
import org.beehive.gpullama3.auxiliary.metrics.JsonMetricsRenderer;
import org.beehive.gpullama3.auxiliary.metrics.MetricsRenderer;
import org.beehive.gpullama3.auxiliary.metrics.RunMetricsSnapshot;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
| 13 | + |
| 14 | +/** |
| 15 | + * Singleton that accumulates fine-grained performance metrics across one inference run. |
| 16 | + * |
| 17 | + * <p>Metrics are set incrementally by different layers of the stack:</p> |
| 18 | + * <ul> |
| 19 | + * <li>{@link #setLoadDuration} — called from {@code ModelLoader}</li> |
| 20 | + * <li>{@link #setTornadoMetrics} — called from TornadoVM plan constructors</li> |
| 21 | + * <li>{@link #setInferenceMetrics} — called from InferenceEngine variants at end of generation</li> |
| 22 | + * <li>{@link #setHasPrefillPhase} — called from prefill-decode engine variants</li> |
| 23 | + * </ul> |
| 24 | + * |
| 25 | + * <p>All durations are stored in nanoseconds. {@link #printMetrics()} builds an immutable |
| 26 | + * {@link RunMetricsSnapshot}, selects a {@link MetricsRenderer}, and writes to the configured sink.</p> |
| 27 | + * |
| 28 | + * <p>Configurable via system properties:</p> |
| 29 | + * <ul> |
| 30 | + * <li>{@code llama.metrics.format} — {@code human} (default) | {@code json} | {@code github}</li> |
| 31 | + * <li>{@code llama.metrics.output} — {@code stderr} (default) | {@code stdout} | {@code file}</li> |
| 32 | + * <li>{@code llama.metrics.file} — target path when {@code output=file}</li> |
| 33 | + * </ul> |
| 34 | + */ |
| 35 | +public final class RunMetrics { |
| 36 | + |
| 37 | + // ── Core metrics (nanoseconds) ──────────────────────────────────────────── |
| 38 | + private long totalDurationNs; |
| 39 | + private long loadDurationNs; |
| 40 | + private int promptEvalCount; |
| 41 | + private long promptEvalDurationNs; |
| 42 | + private int evalCount; |
| 43 | + private long evalDurationNs; |
| 44 | + private boolean hasPrefillPhase; |
| 45 | + |
| 46 | + // ── TornadoVM-specific metrics (nanoseconds) ────────────────────────────── |
| 47 | + private long tornadoPlanCreationNs; |
| 48 | + private long tornadoJitNs; |
| 49 | + private long readOnlyWeightsCopyInNs; |
| 50 | + |
| 51 | + // ── Singleton ───────────────────────────────────────────────────────────── |
| 52 | + private static final RunMetrics INSTANCE = new RunMetrics(); |
| 53 | + |
| 54 | + private RunMetrics() {} |
| 55 | + |
| 56 | + // ── Setters ─────────────────────────────────────────────────────────────── |
| 57 | + |
| 58 | + /** Records the time spent loading the model file (not including TornadoVM initialisation). */ |
| 59 | + public static void setLoadDuration(long ns) { |
| 60 | + INSTANCE.loadDurationNs = ns; |
| 61 | + } |
| 62 | + |
| 63 | + /** |
| 64 | + * Records TornadoVM-specific initialisation durations. |
| 65 | + * |
| 66 | + * @param planCreationNs task-graph construction ({@code createExecutionPlan()}) |
| 67 | + * @param jitNs JIT compilation ({@code withPreCompilation()}) |
| 68 | + * @param weightCopyNs first-execution weight upload ({@code forceCopyInReadOnlyData()}) |
| 69 | + */ |
| 70 | + public static void setTornadoMetrics(long planCreationNs, long jitNs, long weightCopyNs) { |
| 71 | + INSTANCE.tornadoPlanCreationNs = planCreationNs; |
| 72 | + INSTANCE.tornadoJitNs = jitNs; |
| 73 | + INSTANCE.readOnlyWeightsCopyInNs = weightCopyNs; |
| 74 | + } |
| 75 | + |
| 76 | + /** |
| 77 | + * Records inference-phase durations at the end of a generation run. |
| 78 | + * |
| 79 | + * @param promptCount number of prompt tokens processed (prefill) |
| 80 | + * @param prefillNs wall-clock time spent in the prefill phase |
| 81 | + * @param generatedCount number of tokens generated (decode) |
| 82 | + * @param decodeNs wall-clock time spent in the decode phase |
| 83 | + * @param totalNs total wall-clock time for the full inference call |
| 84 | + */ |
| 85 | + public static void setInferenceMetrics(int promptCount, long prefillNs, |
| 86 | + int generatedCount, long decodeNs, |
| 87 | + long totalNs) { |
| 88 | + INSTANCE.promptEvalCount = promptCount; |
| 89 | + INSTANCE.promptEvalDurationNs = prefillNs; |
| 90 | + INSTANCE.evalCount = generatedCount; |
| 91 | + INSTANCE.evalDurationNs = decodeNs; |
| 92 | + INSTANCE.totalDurationNs = totalNs; |
| 93 | + } |
| 94 | + |
| 95 | + /** |
| 96 | + * Signals that prefill and decode are distinct timed phases. |
| 97 | + * Called by {@code InferenceEngineWithPrefillDecode} and |
| 98 | + * {@code InferenceEngineWithBatchPrefillDecode} before returning. |
| 99 | + */ |
| 100 | + public static void setHasPrefillPhase(boolean value) { |
| 101 | + INSTANCE.hasPrefillPhase = value; |
| 102 | + } |
| 103 | + |
| 104 | + // ── Snapshot ────────────────────────────────────────────────────────────── |
| 105 | + |
| 106 | + /** Returns an immutable snapshot of all currently collected metrics. */ |
| 107 | + public static RunMetricsSnapshot snapshot() { |
| 108 | + RunMetrics m = INSTANCE; |
| 109 | + return RunMetricsSnapshot.of( |
| 110 | + m.totalDurationNs, m.loadDurationNs, |
| 111 | + m.promptEvalCount, m.promptEvalDurationNs, |
| 112 | + m.evalCount, m.evalDurationNs, |
| 113 | + m.hasPrefillPhase, |
| 114 | + m.tornadoPlanCreationNs, m.tornadoJitNs, |
| 115 | + m.readOnlyWeightsCopyInNs); |
| 116 | + } |
| 117 | + |
| 118 | + // ── Output ──────────────────────────────────────────────────────────────── |
| 119 | + |
| 120 | + /** |
| 121 | + * Builds a snapshot, selects a renderer based on {@code llama.metrics.format}, |
| 122 | + * and writes the result to the sink configured by {@code llama.metrics.output}. |
| 123 | + */ |
| 124 | + public static void printMetrics() { |
| 125 | + RunMetricsSnapshot snap = snapshot(); |
| 126 | + |
| 127 | + MetricsRenderer renderer = switch (System.getProperty("llama.metrics.format", "human").toLowerCase()) { |
| 128 | + case "json" -> new JsonMetricsRenderer(); |
| 129 | + case "github" -> new GitHubMetricsRenderer(); |
| 130 | + default -> new HumanMetricsRenderer(); |
| 131 | + }; |
| 132 | + |
| 133 | + String rendered = renderer.render(snap); |
| 134 | + |
| 135 | + switch (System.getProperty("llama.metrics.output", "stderr").toLowerCase()) { |
| 136 | + case "stdout" -> System.out.print(rendered); |
| 137 | + case "file" -> writeToFile(rendered); |
| 138 | + default -> System.err.print(rendered); |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + private static void writeToFile(String content) { |
| 143 | + String filePath = System.getProperty("llama.metrics.file"); |
| 144 | + if (filePath == null || filePath.isBlank()) { |
| 145 | + throw new IllegalStateException( |
| 146 | + "llama.metrics.output=file requires llama.metrics.file to be set"); |
| 147 | + } |
| 148 | + Path path = Path.of(filePath); |
| 149 | + try { |
| 150 | + Path parent = path.getParent(); |
| 151 | + if (parent != null) Files.createDirectories(parent); |
| 152 | + Files.writeString(path, content); |
| 153 | + } catch (IOException e) { |
| 154 | + throw new UncheckedIOException("Failed to write metrics to " + filePath, e); |
| 155 | + } |
| 156 | + } |
| 157 | +} |
0 commit comments