Skip to content

Commit 56b4781

Browse files
committed
make learner + OT batch queries to better support parallel oracles
1 parent 160b7b8 commit 56b4781

File tree

2 files changed

+171
-138
lines changed

2 files changed

+171
-138
lines changed

algorithms/active/lstar/src/main/java/de/learnlib/algorithm/lstar/mmlt/ExtensibleLStarMMLT.java

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
*/
1616
package de.learnlib.algorithm.lstar.mmlt;
1717

18-
import java.util.Collection;
18+
import java.util.ArrayList;
1919
import java.util.Collections;
2020
import java.util.HashMap;
2121
import java.util.List;
@@ -40,6 +40,7 @@
4040
import de.learnlib.filter.symbol.AcceptAllSymbolFilter;
4141
import de.learnlib.oracle.TimedQueryOracle;
4242
import de.learnlib.query.DefaultQuery;
43+
import de.learnlib.query.Query;
4344
import de.learnlib.statistic.Statistics;
4445
import de.learnlib.statistic.StatisticsCollector;
4546
import de.learnlib.time.MMLTModelParams;
@@ -51,6 +52,7 @@
5152
import net.automatalib.automaton.mmlt.MMLT;
5253
import net.automatalib.automaton.mmlt.TimerInfo;
5354
import net.automatalib.common.util.HashUtil;
55+
import net.automatalib.common.util.collection.IterableUtil;
5456
import net.automatalib.symbol.time.InputSymbol;
5557
import net.automatalib.symbol.time.TimeStepSequence;
5658
import net.automatalib.symbol.time.TimedInput;
@@ -235,22 +237,22 @@ private List<Row<TimedInput<I>>> selectClosingRows(List<List<Row<TimedInput<I>>>
235237

236238
private void updateOutputs() {
237239
// Query output of newly-added transitions:
238-
updateOutputs(this.hypData.getTable().getShortPrefixRows());
239-
updateOutputs(this.hypData.getTable().getLongPrefixRows());
240-
}
240+
MMLTObservationTable<I, O> ot = this.hypData.getTable();
241+
List<OutputQuery<I, O>> queries = new ArrayList<>();
242+
243+
for (Row<TimedInput<I>> row : IterableUtil.concat(ot.getShortPrefixRows(), ot.getLongPrefixRows())) {
244+
Word<TimedInput<I>> label = row.getLabel();
241245

242-
private void updateOutputs(Collection<Row<TimedInput<I>>> rows) {
243-
for (Row<TimedInput<I>> row : rows) {
244-
if (row.getLabel().isEmpty()) {
246+
if (label.isEmpty()) {
245247
continue; // initial state
246248
}
247249

248-
if (this.hypData.getTransitionOutputMap().containsKey(row.getLabel())) {
250+
if (this.hypData.getTransitionOutputMap().containsKey(label)) {
249251
continue; // already queried
250252
}
251253

252-
Word<TimedInput<I>> prefix = row.getLabel().prefix(-1);
253-
TimedInput<I> inputSym = row.getLabel().suffix(1).lastSymbol();
254+
Word<TimedInput<I>> prefix = label.prefix(-1);
255+
TimedInput<I> inputSym = label.lastSymbol();
254256

255257
TimedOutput<O> output;
256258
if (inputSym instanceof TimeStepSequence<I> ws) {
@@ -259,12 +261,17 @@ private void updateOutputs(Collection<Row<TimedInput<I>>> rows) {
259261
assert timerInfo != null;
260262
O combinedOutput = this.hypData.getModelParams().outputCombiner().combineSymbols(timerInfo.outputs());
261263
output = new TimedOutput<>(combinedOutput);
264+
this.hypData.getTransitionOutputMap().put(label, output);
262265
} else {
263-
output = this.timeOracle.answerQuery(prefix, Word.fromLetter(inputSym)).lastSymbol();
266+
queries.add(new OutputQuery<>(label, prefix));
264267
}
268+
}
265269

266-
if (output != null) {
267-
this.hypData.getTransitionOutputMap().put(row.getLabel(), output);
270+
if (!queries.isEmpty()) {
271+
timeOracle.processQueries(queries);
272+
273+
for (OutputQuery<I, O> q : queries) {
274+
q.process(this.hypData.getTransitionOutputMap());
268275
}
269276
}
270277
}
@@ -402,8 +409,8 @@ private void handleMissingTimeoutChange(Row<TimedInput<I>> spRow, TimerInfo<?, O
402409
// If it is a fringe prefix, we need to remove it:
403410
TimerInfo<?, O> lastTimer = locationTimerInfo.getLastTimer();
404411
assert lastTimer != null;
405-
Word<TimedInput<I>> lastTimerTransPrefix = spRow.getLabel().append(TimedInput.step(lastTimer.initial()));
406412
if (!lastTimer.periodic()) {
413+
Word<TimedInput<I>> lastTimerTransPrefix = spRow.getLabel().append(TimedInput.step(lastTimer.initial()));
407414
Row<TimedInput<I>> row = hypData.getTable().getRow(lastTimerTransPrefix);
408415
assert row != null;
409416
if (!row.isShortPrefixRow()) {
@@ -413,16 +420,15 @@ private void handleMissingTimeoutChange(Row<TimedInput<I>> spRow, TimerInfo<?, O
413420
}
414421

415422
// Prefix for timeout-transition of new one-shot timer:
416-
Word<TimedInput<I>> timerTransPrefix = spRow.getLabel().append(TimedInput.step(timeout.initial()));
417-
assert this.hypData.getTable().getRow(timerTransPrefix) == null : "Timer already appears to be one-shot.";
423+
assert this.hypData.getTable().getRow(spRow.getLabel().append(TimedInput.step(timeout.initial()))) == null :
424+
"Timer already appears to be one-shot.";
418425

419426
// Remove all timers with greater timeout (are now redundant):
420-
List<String> subsequentTimers = locationTimerInfo.getSortedTimers()
421-
.stream()
422-
.filter(t -> t.initial() > timeout.initial())
423-
.map(TimerInfo::name)
424-
.toList();
425-
subsequentTimers.forEach(locationTimerInfo::removeTimer);
427+
for (TimerInfo<?, O> t : new ArrayList<>(locationTimerInfo.getSortedTimers())) {
428+
if (t.initial() > timeout.initial()) {
429+
locationTimerInfo.removeTimer(t.name());
430+
}
431+
}
426432

427433
// Change from periodic to one-shot:
428434
locationTimerInfo.setOneShotTimer(timeout.name());
@@ -501,9 +507,7 @@ private static <I, O> MMLTHypothesis<I, O> constructHypothesis(MMLTHypDataContai
501507
}
502508
}
503509
// Ensure initial location:
504-
if (hypothesis.getInitialState() == null) {
505-
throw new IllegalArgumentException("Automaton must have an initial location.");
506-
}
510+
assert hypothesis.getInitialState() != null : "Automaton must have an initial location.";
507511

508512
// 5. Create outgoing transitions for non-delaying inputs:
509513
for (Entry<Integer, Integer> e : stateMap.entrySet()) {
@@ -573,6 +577,44 @@ private static <I, O> MMLTHypothesis<I, O> constructHypothesis(MMLTHypDataContai
573577
return hypothesis;
574578
}
575579

580+
private static final class OutputQuery<I, O> extends Query<TimedInput<I>, Word<TimedOutput<O>>> {
581+
582+
private final Word<TimedInput<I>> label;
583+
private final Word<TimedInput<I>> prefix;
584+
private TimedOutput<O> output;
585+
586+
private OutputQuery(Word<TimedInput<I>> label, Word<TimedInput<I>> prefix) {
587+
this.label = label;
588+
this.prefix = prefix;
589+
}
590+
591+
@Override
592+
public void answer(Word<TimedOutput<O>> output) {
593+
assert output.size() == 1;
594+
this.output = output.firstSymbol();
595+
}
596+
597+
@Override
598+
public Word<TimedInput<I>> getPrefix() {
599+
return prefix;
600+
}
601+
602+
@Override
603+
public Word<TimedInput<I>> getSuffix() {
604+
return Word.fromLetter(label.lastSymbol());
605+
}
606+
607+
/**
608+
* Processes the query result by mapping the given label to the (single) response.
609+
*
610+
* @param outputs
611+
* the output map to write the mapping to
612+
*/
613+
void process(Map<Word<TimedInput<I>>, TimedOutput<O>> outputs) {
614+
outputs.put(label, output);
615+
}
616+
}
617+
576618
static final class BuilderDefaults {
577619

578620
private BuilderDefaults() {
@@ -592,7 +634,7 @@ static <I> MutableSymbolFilter<TimedInput<I>, InputSymbol<I>> symbolFilter() {
592634
}
593635

594636
static AcexAnalyzer analyzer() {
595-
return AcexAnalyzers.LINEAR_BWD;
637+
return AcexAnalyzers.BINARY_SEARCH_BWD;
596638
}
597639
}
598640

0 commit comments

Comments
 (0)