Skip to content

Commit b2e6e95

Browse files
authored
Fix bug in twoStepGradientAggregator method. (#43)
1 parent 01d63ad commit b2e6e95

4 files changed

Lines changed: 116 additions & 42 deletions

File tree

dualip/src/main/scala/com/linkedin/dualip/objective/distributedobjective/DistributedRegularizedObjective.scala

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,42 @@ object DistributedRegularizedObjective {
132132
PartialPrimalStats(ax.toArray, cx, xx)
133133
}
134134

135+
/**
136+
* This method accumulates sufficient statistics (i.e. ax, cx and xx) from a dataset of PartialPrimalStats.
137+
* It returns an Array[(Int, Array[Double])] where the first entry corresponds to the index of the dual vector.
138+
* We always store cx and xx in the last two indices. For example, if the dual is of dimension 20, the maximum value
139+
* of the returned indices would be 21 (0 to 19 for the indices of the dual vector and 20 and 21 for cx and xx
140+
* respectively).
141+
*
142+
* @param primalStats
143+
* @param lambdaDim
144+
* @param numPartitions
145+
* @param sparkSession
146+
* @return
147+
*/
148+
def accumulateSufficientStatistics(primalStats: Dataset[PartialPrimalStats], lambdaDim: Int, numPartitions: Int)
149+
(implicit sparkSession: SparkSession): Array[(Int, Array[Double])] = {
150+
import sparkSession.implicits._
151+
primalStats.mapPartitions { partitionIterator =>
152+
val acxxAgg = new Array[Double](lambdaDim + 2)
153+
partitionIterator.foreach { stats =>
154+
val ax = stats.costs
155+
var i = 0
156+
val axLen = ax.length
157+
while (i < axLen) {
158+
val (axIndex, axValue) = ax(i)
159+
acxxAgg(axIndex) += axValue
160+
i += 1
161+
}
162+
acxxAgg(lambdaDim) += stats.objective
163+
acxxAgg(lambdaDim + 1) += stats.xx
164+
}
165+
// partition array
166+
val x = ArrayAggregation.partitionArray(acxxAgg, numPartitions)
167+
x.iterator
168+
}.rdd.reduceByKey(ArrayAggregation.aggregateArrays(_, _)).collect()
169+
}
170+
135171
/**
136172
* Does aggregation in the following way:
137173
* 1. each data partition performs aggregation of the gradients into java Array (dense)
@@ -158,37 +194,43 @@ object DistributedRegularizedObjective {
158194
*/
159195
def twoStepGradientAggregator(primalStats: Dataset[PartialPrimalStats], lambdaDim: Int, numPartitions: Int)
160196
(implicit sparkSession: SparkSession): PartialPrimalStats = {
161-
import sparkSession.implicits._
162-
val aggregate = primalStats.mapPartitions { partitionIterator =>
163-
val acxxAgg = new Array[Double](lambdaDim + 2)
164-
partitionIterator.foreach { stats =>
165-
val ax = stats.costs
166-
var i = 0
167-
while (i < ax.length) {
168-
val (axIndex, axValue) = ax(i)
169-
acxxAgg(axIndex) += axValue
170-
i += 1
171-
}
172-
acxxAgg(lambdaDim) += stats.objective
173-
acxxAgg(lambdaDim + 1) += stats.xx
174-
}
175-
// partition array
176-
val x = ArrayAggregation.partitionArray(acxxAgg, numPartitions)
177-
x.iterator
178-
}.rdd.reduceByKey(ArrayAggregation.aggregateArrays(_, _)).collect()
179197

198+
val aggregatedStats = accumulateSufficientStatistics(primalStats, lambdaDim, numPartitions)
180199
val ax = new Array[Double](lambdaDim)
181200
var cx = 0.0
182201
var xx = 0.0
183-
aggregate.foreach { case (partition, subarray) =>
184-
val (start, end) = ArrayAggregation.partitionBounds(lambdaDim + 2, numPartitions, partition)
185-
if (partition == numPartitions - 1) {
186-
// special case for last partition, as it holds 'xx' and 'cx' in the last two positions
187-
cx = subarray(subarray.length - 2)
188-
xx = subarray(subarray.length - 1)
189-
System.arraycopy(subarray, 0, ax, start, subarray.length - 2)
190-
} else {
191-
System.arraycopy(subarray, 0, ax, start, subarray.length)
202+
val axLen = ax.length
203+
aggregatedStats.foreach { case (partition, subarray) =>
204+
val (startIndex, endIndex) = ArrayAggregation.partitionBounds(lambdaDim + 2, numPartitions, partition)
205+
if (partition < (numPartitions - 2) || (partition == numPartitions - 2 && endIndex <= axLen)) {
206+
// aggregation of the ax values for different indices of the dual when we haven't reached the last two
207+
// partitions or we have reached the second last partition and it still contains dual values only (i.e. no cx)
208+
System.arraycopy(subarray, 0, ax, startIndex, subarray.length)
209+
}
210+
else if (partition == (numPartitions - 2)) {
211+
// once we hit the second last partition with endIndex > axLen, we definitely have cx in here
212+
cx = subarray(subarray.length - 1)
213+
if (subarray.length > 1) {
214+
// along with some of the remaining duals
215+
System.arraycopy(subarray, 0, ax, startIndex, subarray.length - 1)
216+
}
217+
}
218+
else {
219+
// when we hit the last partition
220+
if (subarray.length == 1) {
221+
// and it contains only one element, it has to be xx
222+
xx = subarray(0)
223+
}
224+
else {
225+
// if the last partition has more than one element, then the last two elements must be cx and xx respectively
226+
cx = subarray(subarray.length - 2)
227+
xx = subarray(subarray.length - 1)
228+
if (subarray.length > 2) {
229+
// if the last partition has more than two elements then it must contain few of the remaining duals and cx
230+
// and xx
231+
System.arraycopy(subarray, 0, ax, startIndex, subarray.length - 2)
232+
}
233+
}
192234
}
193235
}
194236
PartialPrimalStats(ax.zipWithIndex.map { case (v, i) => (i, v) }, cx, xx)

dualip/src/main/scala/com/linkedin/dualip/util/ArrayAggregation.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ object ArrayAggregation {
4949
* Method to find [start, end) positions in the array of a given partition.
5050
* Array is partitioned into roughly identical partitions, the length of a partition may differ by one.
5151
*
52-
* @param arrayLength
53-
* @param numPartitions
54-
* @param partition
52+
* @param arrayLength - length of the array whose slicing positions are being considered
53+
* @param numPartitions - the total number of partitions that the array of length arrayLength is stored in
54+
* @param partition - the index value for the partitions that can assume values from 0 to numPartitions - 1
5555
* @return
5656
*/
5757
def partitionBounds(arrayLength: Int, numPartitions: Int, partition: Int): (Int, Int) = {
@@ -61,13 +61,17 @@ object ArrayAggregation {
6161
// we will stack larger partitions in the beginning of the array
6262
val basePartitionSize = arrayLength / numPartitions
6363
val numLargerPartitions = arrayLength - basePartitionSize * numPartitions
64-
// second term accounts for larger partitions stacked prior to partition in question
65-
val startIndex = partition * basePartitionSize + math.min(partition, numLargerPartitions)
64+
6665
val partitionSize = if (partition < numLargerPartitions) {
6766
basePartitionSize + 1
6867
} else {
6968
basePartitionSize
7069
}
70+
val startIndex = if (partition < numLargerPartitions) {
71+
partition * partitionSize
72+
} else {
73+
numLargerPartitions + partition * partitionSize
74+
}
7175
val endIndex = startIndex + partitionSize
7276
(startIndex, endIndex)
7377
}

dualip/src/test/scala/com/linkedin/dualip/objective/distributedobjective/DistributedRegularizedObjectiveTest.scala

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package com.linkedin.dualip.objective.distributedobjective
22

33
import com.linkedin.dualip.objective.PartialPrimalStats
4+
import com.linkedin.dualip.objective.distributedobjective.DistributedRegularizedObjective.accumulateSufficientStatistics
5+
import com.linkedin.dualip.util.ArrayAggregation.partitionBounds
46
import com.linkedin.spark.common.lib.TestUtils
57
import org.apache.spark.SparkConf
68
import org.apache.spark.sql.SparkSession
@@ -15,15 +17,38 @@ class DistributedRegularizedObjectiveTest {
1517
val expectedAx: Map[Int, Double] = partialGradientsTestData.flatMap(_.costs).groupBy(_._1).mapValues(_.map(_._2).sum)
1618
val expectedXx: Double = partialGradientsTestData.map(_.xx).sum
1719

20+
@Test
21+
def testAccumulateSufficientStatistics(): Unit = {
22+
implicit val spark: SparkSession = TestUtils.createSparkSession("testAccumulateSufficientStatistics")
23+
import spark.implicits._
24+
25+
val lambdaDim = 9
26+
val numPartitions = 4
27+
val primalStats = spark.createDataset(partialGradientsTestData)
28+
val aggregatedStats = accumulateSufficientStatistics(primalStats, lambdaDim, 4).toMap
29+
30+
(0 until numPartitions - 1).foreach { partitionNumber =>
31+
val (startIndex, endIndex) = partitionBounds(arrayLength = lambdaDim + 2, numPartitions = numPartitions,
32+
partition = partitionNumber)
33+
(startIndex until endIndex).zipWithIndex.foreach { case (arrayIndex, index) =>
34+
Assert.assertEquals(aggregatedStats(partitionNumber)(index), expectedAx(arrayIndex), 0.01)
35+
}
36+
}
37+
}
38+
1839
@Test
1940
def testTwoStepGradientAggregation(): Unit = {
20-
implicit val spark: SparkSession = TestUtils.createSparkSession()
41+
implicit val spark: SparkSession = TestUtils.createSparkSession("testTwoStepGradientAggregation")
2142
import spark.implicits._
22-
val ds = spark.createDataset(partialGradientsTestData).repartition(2)
23-
val aggPrimalStats = DistributedRegularizedObjective.twoStepGradientAggregator(ds, 9, 2)
24-
assertAlmostEqual(aggPrimalStats.costs.toMap, expectedAx)
25-
assertAlmostEqual(aggPrimalStats.objective, expectedCx)
26-
assertAlmostEqual(aggPrimalStats.xx, expectedXx)
43+
44+
Array(2, 5, 9).foreach { numPartitions =>
45+
print("number of partitions " + numPartitions + "\n")
46+
val ds = spark.createDataset(partialGradientsTestData).repartition(numPartitions)
47+
val aggPrimalStats = DistributedRegularizedObjective.twoStepGradientAggregator(ds, 9, numPartitions)
48+
assertAlmostEqual(aggPrimalStats.costs.toMap, expectedAx)
49+
assertAlmostEqual(aggPrimalStats.objective, expectedCx)
50+
assertAlmostEqual(aggPrimalStats.xx, expectedXx)
51+
}
2752
}
2853

2954
@Test

dualip/src/test/scala/com/linkedin/dualip/util/ArrayAggregationTest.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.linkedin.dualip.util
22

3+
import com.linkedin.dualip.objective.distributedobjective.DistributedRegularizedObjective.accumulateSufficientStatistics
34
import com.linkedin.spark.common.lib.TestUtils
45
import org.apache.spark.sql.{Dataset, SparkSession}
56
import org.testng.Assert
@@ -31,11 +32,13 @@ class ArrayAggregationTest {
3132
@Test
3233
def testPartitionBounds(): Unit = {
3334
// even split into partitions
34-
Assert.assertEquals(partitionBounds(arrayLength = 8, numPartitions = 2, partition = 1), (4,8))
35+
Assert.assertEquals(partitionBounds(arrayLength = 8, numPartitions = 2, partition = 1), (4, 8))
3536
// uneven split, larger partitions should be stacked first
36-
Assert.assertEquals(partitionBounds(arrayLength = 8, numPartitions = 3, partition = 0), (0,3))
37-
Assert.assertEquals(partitionBounds(arrayLength = 8, numPartitions = 3, partition = 1), (3,6))
38-
Assert.assertEquals(partitionBounds(arrayLength = 8, numPartitions = 3, partition = 2), (6,8))
37+
Assert.assertEquals(partitionBounds(arrayLength = 8, numPartitions = 3, partition = 0), (0, 3))
38+
Assert.assertEquals(partitionBounds(arrayLength = 8, numPartitions = 3, partition = 1), (3, 6))
39+
Assert.assertEquals(partitionBounds(arrayLength = 8, numPartitions = 3, partition = 2), (6, 8))
40+
Assert.assertEquals(partitionBounds(arrayLength = 40, numPartitions = 20, partition = 10), (20, 22))
41+
Assert.assertEquals(partitionBounds(arrayLength = 40, numPartitions = 38, partition = 37), (39, 40))
3942
}
4043

4144
@Test(

0 commit comments

Comments
 (0)