Skip to content

Commit d90b4e3

Browse files
authored
Merge pull request #1073 from synthetichealth/rng
Fix random divergence between runs with the same seeds.
2 parents 58225e0 + aa49445 commit d90b4e3

31 files changed

Lines changed: 295 additions & 251 deletions

src/main/java/org/mitre/synthea/engine/Generator.java

Lines changed: 32 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
import java.util.LinkedList;
2020
import java.util.List;
2121
import java.util.Map;
22-
import java.util.Random;
2322
import java.util.UUID;
2423
import java.util.concurrent.ExecutorService;
2524
import java.util.concurrent.Executors;
@@ -36,6 +35,7 @@
3635
import org.mitre.synthea.export.CDWExporter;
3736
import org.mitre.synthea.export.Exporter;
3837
import org.mitre.synthea.helpers.Config;
38+
import org.mitre.synthea.helpers.DefaultRandomNumberGenerator;
3939
import org.mitre.synthea.helpers.RandomNumberGenerator;
4040
import org.mitre.synthea.helpers.TransitionMetrics;
4141
import org.mitre.synthea.helpers.Utilities;
@@ -56,15 +56,16 @@
5656
/**
5757
* Generator creates a population by running the generic modules each timestep per Person.
5858
*/
59-
public class Generator implements RandomNumberGenerator {
59+
public class Generator {
6060

6161
/**
6262
* Unique ID for this instance of the Generator.
6363
* Even if the same settings are used multiple times, this ID should be unique.
6464
*/
6565
public final UUID id = UUID.randomUUID();
6666
public GeneratorOptions options;
67-
private Random random;
67+
private DefaultRandomNumberGenerator populationRandom;
68+
private DefaultRandomNumberGenerator clinicianRandom;
6869
public long timestep;
6970
public long stop;
7071
public long referenceTime;
@@ -218,7 +219,8 @@ private void init() {
218219
CDWExporter.getInstance().setKeyStart((stateIndex * 1_000_000) + 1);
219220
}
220221

221-
this.random = new Random(options.seed);
222+
this.populationRandom = new DefaultRandomNumberGenerator(options.seed);
223+
this.clinicianRandom = new DefaultRandomNumberGenerator(options.clinicianSeed);
222224
this.timestep = Long.parseLong(Config.get("generate.timestep"));
223225
this.stop = options.endTime;
224226
this.referenceTime = options.referenceTime;
@@ -262,7 +264,7 @@ private void init() {
262264
}
263265

264266
// initialize hospitals
265-
Provider.loadProviders(location, options.clinicianSeed);
267+
Provider.loadProviders(location, this.clinicianRandom);
266268
// Initialize Payers
267269
Payer.loadPayers(location);
268270
// ensure modules load early
@@ -366,7 +368,7 @@ public void run() {
366368
// Generate patients up to the specified population size.
367369
for (int i = 0; i < this.options.population; i++) {
368370
final int index = i;
369-
final long seed = this.random.nextLong();
371+
final long seed = this.populationRandom.randLong();
370372
threadPool.submit(() -> generatePerson(index, seed));
371373
}
372374
}
@@ -398,6 +400,8 @@ public void run() {
398400

399401
System.out.printf("Records: total=%d, alive=%d, dead=%d\n", totalGeneratedPopulation.get(),
400402
stats.get("alive").get(), stats.get("dead").get());
403+
System.out.printf("RNG=%d\n", this.populationRandom.getCount());
404+
System.out.printf("Clinician RNG=%d\n", this.clinicianRandom.getCount());
401405

402406
if (this.metrics != null) {
403407
metrics.printStats(totalGeneratedPopulation.get(), Module.getModules(getModulePredicate()));
@@ -440,6 +444,7 @@ public List<FixedRecordGroup> importFixedPatientDemographicsFile() {
440444
* @param index Target index in the whole set of people to generate
441445
* @return generated Person
442446
*/
447+
@Deprecated
443448
public Person generatePerson(int index) {
444449
// System.currentTimeMillis is not unique enough
445450
long personSeed = UUID.randomUUID().getMostSignificantBits() & Long.MAX_VALUE;
@@ -461,16 +466,15 @@ public Person generatePerson(int index) {
461466
*/
462467
public Person generatePerson(int index, long personSeed) {
463468

464-
Person person = null;
469+
Person person = new Person(personSeed);
465470

466471
try {
467472
int tryNumber = 0; // Number of tries to create these demographics
468-
Random randomForDemographics = new Random(personSeed);
469473

470-
Map<String, Object> demoAttributes = randomDemographics(randomForDemographics);
474+
Map<String, Object> demoAttributes = randomDemographics(person);
471475
if (this.recordGroups != null) {
472476
// Pick fixed demographics if a fixed demographics record file is used.
473-
demoAttributes = pickFixedDemographics(index, random);
477+
demoAttributes = pickFixedDemographics(index, person);
474478
}
475479

476480
boolean patientMeetsCriteria;
@@ -509,7 +513,7 @@ public Person generatePerson(int index, long personSeed) {
509513
// when we want to export this patient, but keep trying to produce one meeting criteria
510514
if (!check.exportAnyway()) {
511515
// rotate the seed so the next attempt gets a consistent but different one
512-
personSeed = randomForDemographics.nextLong();
516+
personSeed = person.randLong();
513517
continue;
514518
// skip the other stuff if the patient doesn't meet our goals
515519
// note that this skips ahead to the while check
@@ -521,19 +525,19 @@ public Person generatePerson(int index, long personSeed) {
521525

522526
if (!isAlive) {
523527
// rotate the seed so the next attempt gets a consistent but different one
524-
personSeed = randomForDemographics.nextLong();
528+
personSeed = person.randLong();
525529

526530
// if we've tried and failed > 10 times to generate someone over age 90
527531
// and the options allow for ages as low as 85
528532
// reduce the age to increase the likelihood of success
529533
if (tryNumber > 10 && (int)person.attributes.get(TARGET_AGE) > 90
530534
&& (!options.ageSpecified || options.minAge <= 85)) {
531535
// pick a new target age between 85 and 90
532-
int newTargetAge = randomForDemographics.nextInt(5) + 85;
536+
int newTargetAge = person.randInt(5) + 85;
533537
// the final age bracket is 85-110, but our patients rarely break 100
534538
// so reducing a target age to 85-90 shouldn't affect numbers too much
535539
demoAttributes.put(TARGET_AGE, newTargetAge);
536-
long birthdate = birthdateFromTargetAge(newTargetAge, randomForDemographics);
540+
long birthdate = birthdateFromTargetAge(newTargetAge, person);
537541
demoAttributes.put(Person.BIRTHDATE, birthdate);
538542
}
539543
}
@@ -705,7 +709,7 @@ public void updatePerson(Person person) {
705709
* @param random The random number generator to use.
706710
* @return demographics
707711
*/
708-
public Map<String, Object> randomDemographics(Random random) {
712+
public Map<String, Object> randomDemographics(RandomNumberGenerator random) {
709713
Demographics city = location.randomCity(random);
710714
Map<String, Object> demoAttributes = pickDemographics(random, city);
711715
return demoAttributes;
@@ -722,11 +726,12 @@ private synchronized void writeToConsole(Person person, int index, long time, bo
722726
// this is synchronized to ensure all lines for a single person are always printed
723727
// consecutively
724728
String deceased = isAlive ? "" : "DECEASED";
725-
System.out.format("%d -- %s (%d y/o %s) %s, %s %s\n", index + 1,
729+
System.out.format("%d -- %s (%d y/o %s) %s, %s %s (%d)\n", index + 1,
726730
person.attributes.get(Person.NAME), person.ageInYears(time),
727731
person.attributes.get(Person.GENDER),
728732
person.attributes.get(Person.CITY), person.attributes.get(Person.STATE),
729-
deceased);
733+
deceased,
734+
person.getCount());
730735

731736
if (this.logLevel.equals("detailed")) {
732737
System.out.println("ATTRIBUTES");
@@ -750,7 +755,7 @@ private synchronized void writeToConsole(Person person, int index, long time, bo
750755
* @param city The city to base the demographics off of.
751756
* @return the person's picked demographics.
752757
*/
753-
private Map<String, Object> pickDemographics(Random random, Demographics city) {
758+
private Map<String, Object> pickDemographics(RandomNumberGenerator random, Demographics city) {
754759
// Output map of the generated demographic data.
755760
Map<String, Object> demographicsOutput = new HashMap<>();
756761

@@ -794,7 +799,7 @@ private Map<String, Object> pickDemographics(Random random, Demographics city) {
794799
double povertyRatio = city.povertyRatio(income);
795800
demographicsOutput.put(Person.POVERTY_RATIO, povertyRatio);
796801

797-
double occupation = random.nextDouble();
802+
double occupation = random.rand();
798803
demographicsOutput.put(Person.OCCUPATION_LEVEL, occupation);
799804

800805
double sesScore = city.socioeconomicScore(incomeLevel, educationLevel, occupation);
@@ -809,7 +814,7 @@ private Map<String, Object> pickDemographics(Random random, Demographics city) {
809814
int targetAge;
810815
if (options.ageSpecified) {
811816
targetAge =
812-
(int) (options.minAge + ((options.maxAge - options.minAge) * random.nextDouble()));
817+
(int) (options.minAge + ((options.maxAge - options.minAge) * random.rand()));
813818
} else {
814819
targetAge = city.pickAge(random);
815820
}
@@ -827,7 +832,7 @@ private Map<String, Object> pickDemographics(Random random, Demographics city) {
827832
* @param index The index to use.
828833
* @param random Random object.
829834
*/
830-
private Map<String, Object> pickFixedDemographics(int index, Random random) {
835+
private Map<String, Object> pickFixedDemographics(int index, RandomNumberGenerator random) {
831836

832837
// Get the first FixedRecord from the current RecordGroup
833838
FixedRecordGroup recordGroup = this.recordGroups.get(index);
@@ -861,11 +866,11 @@ private Map<String, Object> pickFixedDemographics(int index, Random random) {
861866
* @param random A random object.
862867
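* @return a birthdate timestamp picked at random between the earliest and latest birthdates that match the target age, measured back from the reference time.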
863868
*/
864-
private long birthdateFromTargetAge(long targetAge, Random random) {
869+
private long birthdateFromTargetAge(long targetAge, RandomNumberGenerator random) {
865870
long earliestBirthdate = referenceTime - TimeUnit.DAYS.toMillis((targetAge + 1) * 365L + 1);
866871
long latestBirthdate = referenceTime - TimeUnit.DAYS.toMillis(targetAge * 365L);
867872
return
868-
(long) (earliestBirthdate + ((latestBirthdate - earliestBirthdate) * random.nextDouble()));
873+
(long) (earliestBirthdate + ((latestBirthdate - earliestBirthdate) * random.rand()));
869874
}
870875

871876
/**
@@ -908,52 +913,10 @@ private Predicate<String> getModulePredicate() {
908913
}
909914

910915
/**
911-
* Returns a random double.
912-
*/
913-
public double rand() {
914-
return random.nextDouble();
915-
}
916-
917-
/**
918-
* Returns a random boolean.
919-
*/
920-
public boolean randBoolean() {
921-
return random.nextBoolean();
922-
}
923-
924-
/**
925-
* Returns a random integer.
926-
*/
927-
public int randInt() {
928-
return random.nextInt();
929-
}
930-
931-
/**
932-
* Returns a random integer in the given bound.
916+
* Get the seeded random number generator used by this Generator.
917+
* @return the random number generator.
933918
*/
934-
public int randInt(int bound) {
935-
return random.nextInt(bound);
919+
public RandomNumberGenerator getRandomizer() {
920+
return this.populationRandom;
936921
}
937-
938-
/**
939-
* Returns a double from a normal distribution.
940-
*/
941-
public double randGaussian() {
942-
return random.nextGaussian();
943-
}
944-
945-
/**
946-
* Return a random long.
947-
*/
948-
public long randLong() {
949-
return random.nextLong();
950-
}
951-
952-
/**
953-
* Return a random UUID.
954-
*/
955-
public UUID randUUID() {
956-
return new UUID(randLong(), randLong());
957-
}
958-
959922
}

src/main/java/org/mitre/synthea/export/Exporter.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -434,20 +434,20 @@ public static void runPostCompletionExports(Generator generator, ExporterRuntime
434434

435435
// Before we force bulk data to be off...
436436
try {
437-
FhirGroupExporterR4.exportAndSave(generator, generator.stop);
437+
FhirGroupExporterR4.exportAndSave(generator.getRandomizer(), generator.stop);
438438
} catch (Exception e) {
439439
e.printStackTrace();
440440
}
441441

442442
Config.set("exporter.fhir.bulk_data", "false");
443443
try {
444-
HospitalExporterR4.export(generator, generator.stop);
444+
HospitalExporterR4.export(generator.getRandomizer(), generator.stop);
445445
} catch (Exception e) {
446446
e.printStackTrace();
447447
}
448448

449449
try {
450-
FhirPractitionerExporterR4.export(generator, generator.stop);
450+
FhirPractitionerExporterR4.export(generator.getRandomizer(), generator.stop);
451451
} catch (Exception e) {
452452
e.printStackTrace();
453453
}

src/main/java/org/mitre/synthea/export/FhirDstu2.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -478,7 +478,7 @@ private static Entry basicInfo(Person person, Bundle bundle, long stopTime) {
478478

479479
String generatedBySynthea = "Generated by <a href=\"https://github.com/synthetichealth/synthea\">Synthea</a>."
480480
+ "Version identifier: " + Utilities.SYNTHEA_VERSION + " . "
481-
+ " Person seed: " + person.seed
481+
+ " Person seed: " + person.getSeed()
482482
+ " Population seed: " + person.populationSeed;
483483

484484
patientResource.setText(new NarrativeDt(

src/main/java/org/mitre/synthea/export/FhirR4.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -623,7 +623,7 @@ private static BundleEntryComponent basicInfo(Person person, Bundle bundle, long
623623
String generatedBySynthea =
624624
"Generated by <a href=\"https://github.com/synthetichealth/synthea\">Synthea</a>."
625625
+ "Version identifier: " + Utilities.SYNTHEA_VERSION + " . "
626-
+ " Person seed: " + person.seed
626+
+ " Person seed: " + person.getSeed()
627627
+ " Population seed: " + person.populationSeed;
628628

629629
patientResource.setText(new Narrative().setStatus(NarrativeStatus.GENERATED)

src/main/java/org/mitre/synthea/export/FhirStu3.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -572,7 +572,7 @@ private static BundleEntryComponent basicInfo(Person person, Bundle bundle, long
572572

573573
String generatedBySynthea = "Generated by <a href=\"https://github.com/synthetichealth/synthea\">Synthea</a>."
574574
+ "Version identifier: " + Utilities.SYNTHEA_VERSION + " . "
575-
+ " Person seed: " + person.seed
575+
+ " Person seed: " + person.getSeed()
576576
+ " Population seed: " + person.populationSeed;
577577

578578
patientResource.setText(new Narrative().setStatus(NarrativeStatus.GENERATED)

src/main/java/org/mitre/synthea/export/JSONExporter.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ public PersonSerializer(boolean includeModuleHistory) {
7676
@Override
7777
public JsonElement serialize(Person src, Type typeOfSrc, JsonSerializationContext context) {
7878
JsonObject personOut = new JsonObject();
79-
personOut.add("seed", new JsonPrimitive(src.seed));
79+
personOut.add("seed", new JsonPrimitive(src.getSeed()));
8080
personOut.add("lastUpdated", new JsonPrimitive(src.lastUpdated));
8181
personOut.add("coverage", context.serialize(src.coverage));
8282
JsonObject attributes = new JsonObject();

src/main/java/org/mitre/synthea/export/ValueSetCodeResolver.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ private Code resolveCode(@Nullable Code code) {
158158
return null;
159159
}
160160
return code.valueSet != null
161-
? RandomCodeGenerator.getCode(code.valueSet, person.seed, code)
161+
? RandomCodeGenerator.getCode(code.valueSet, person.getSeed(), code)
162162
: code;
163163
}
164164

0 commit comments

Comments
 (0)