@@ -132,6 +132,42 @@ object DistributedRegularizedObjective {
132132 PartialPrimalStats (ax.toArray, cx, xx)
133133 }
134134
135+ /**
136+ * This method accumulates sufficient statistics (i.e. ax, cx and xx) from a dataset of PartialPrimalStats.
137+ * It returns an Array[(Int, Array[Double])] where the first entry corresponds to the index of the dual vector.
138+ * We always store cx and xx in the last two indices. For example, if the dual is of dimension 20, the maximum value
139+ * of the returned indices would be 21 (0 to 19 for the indices of the dual vector and 20 and 21 for cx and xx
140+ * respectively).
141+ *
142+ * @param primalStats
143+ * @param lambdaDim
144+ * @param numPartitions
145+ * @param sparkSession
146+ * @return
147+ */
148+ def accumulateSufficientStatistics (primalStats : Dataset [PartialPrimalStats ], lambdaDim : Int , numPartitions : Int )
149+ (implicit sparkSession : SparkSession ): Array [(Int , Array [Double ])] = {
150+ import sparkSession .implicits ._
151+ primalStats.mapPartitions { partitionIterator =>
152+ val acxxAgg = new Array [Double ](lambdaDim + 2 )
153+ partitionIterator.foreach { stats =>
154+ val ax = stats.costs
155+ var i = 0
156+ val axLen = ax.length
157+ while (i < axLen) {
158+ val (axIndex, axValue) = ax(i)
159+ acxxAgg(axIndex) += axValue
160+ i += 1
161+ }
162+ acxxAgg(lambdaDim) += stats.objective
163+ acxxAgg(lambdaDim + 1 ) += stats.xx
164+ }
165+ // partition array
166+ val x = ArrayAggregation .partitionArray(acxxAgg, numPartitions)
167+ x.iterator
168+ }.rdd.reduceByKey(ArrayAggregation .aggregateArrays(_, _)).collect()
169+ }
170+
135171 /**
136172 * Does aggregation in the following way:
137173 * 1. each data partition performs aggregation of the gradients into java Array (dense)
@@ -158,37 +194,43 @@ object DistributedRegularizedObjective {
158194 */
159195 def twoStepGradientAggregator (primalStats : Dataset [PartialPrimalStats ], lambdaDim : Int , numPartitions : Int )
160196 (implicit sparkSession : SparkSession ): PartialPrimalStats = {
161- import sparkSession .implicits ._
162- val aggregate = primalStats.mapPartitions { partitionIterator =>
163- val acxxAgg = new Array [Double ](lambdaDim + 2 )
164- partitionIterator.foreach { stats =>
165- val ax = stats.costs
166- var i = 0
167- while (i < ax.length) {
168- val (axIndex, axValue) = ax(i)
169- acxxAgg(axIndex) += axValue
170- i += 1
171- }
172- acxxAgg(lambdaDim) += stats.objective
173- acxxAgg(lambdaDim + 1 ) += stats.xx
174- }
175- // partition array
176- val x = ArrayAggregation .partitionArray(acxxAgg, numPartitions)
177- x.iterator
178- }.rdd.reduceByKey(ArrayAggregation .aggregateArrays(_, _)).collect()
179197
198+ val aggregatedStats = accumulateSufficientStatistics(primalStats, lambdaDim, numPartitions)
180199 val ax = new Array [Double ](lambdaDim)
181200 var cx = 0.0
182201 var xx = 0.0
183- aggregate.foreach { case (partition, subarray) =>
184- val (start, end) = ArrayAggregation .partitionBounds(lambdaDim + 2 , numPartitions, partition)
185- if (partition == numPartitions - 1 ) {
186- // special case for last partition, as it holds 'xx' and 'cx' in the last two positions
187- cx = subarray(subarray.length - 2 )
188- xx = subarray(subarray.length - 1 )
189- System .arraycopy(subarray, 0 , ax, start, subarray.length - 2 )
190- } else {
191- System .arraycopy(subarray, 0 , ax, start, subarray.length)
202+ val axLen = ax.length
203+ aggregatedStats.foreach { case (partition, subarray) =>
204+ val (startIndex, endIndex) = ArrayAggregation .partitionBounds(lambdaDim + 2 , numPartitions, partition)
205+ if (partition < (numPartitions - 2 ) || (partition == numPartitions - 2 && endIndex <= axLen)) {
206+ // aggregation of the ax values for different indices of the dual when we haven't reached the last two
207+ // partitions or we have reached the second last partition and it still contains dual values only (i.e. no cx)
208+ System .arraycopy(subarray, 0 , ax, startIndex, subarray.length)
209+ }
210+ else if (partition == (numPartitions - 2 )) {
211+ // once we hit the second last partition with endIndex > axLen, we definitely have cx in here
212+ cx = subarray(subarray.length - 1 )
213+ if (subarray.length > 1 ) {
214+ // along with some of the remaining duals
215+ System .arraycopy(subarray, 0 , ax, startIndex, subarray.length - 1 )
216+ }
217+ }
218+ else {
219+ // when we hit the last partition
220+ if (subarray.length == 1 ) {
221+ // and it contains only one element, it has to be xx
222+ xx = subarray(0 )
223+ }
224+ else {
225+ // if the last partition has more than one element, then the last two elements must be cx and xx respectively
226+ cx = subarray(subarray.length - 2 )
227+ xx = subarray(subarray.length - 1 )
228+ if (subarray.length > 2 ) {
229+ // if the last partition has more than two elements then it must contain few of the remaining duals and cx
230+ // and xx
231+ System .arraycopy(subarray, 0 , ax, startIndex, subarray.length - 2 )
232+ }
233+ }
192234 }
193235 }
194236 PartialPrimalStats (ax.zipWithIndex.map { case (v, i) => (i, v) }, cx, xx)
0 commit comments