Skip to content

Commit 2bea301

Browse files
authored
Add duplicate argument in sample method (#893)
1 parent 8a8c8fa commit 2bea301

File tree

2 files changed

+69
-8
lines changed

2 files changed

+69
-8
lines changed

lib/util/matrix.js

+18-8
Original file line numberDiff line numberDiff line change
@@ -652,18 +652,28 @@ export default class Matrix {
652652
* Returns a matrix that sampled along the axis.
653653
* @param {number} n Sampled size
654654
* @param {number} [axis] Axis to be sampled
655+
* @param {number} [duplicate] Allow duplicate index or not
655656
* @returns {[Matrix, number[]]} Sampled matrix and its original indexes
656657
*/
657-
sample(n, axis = 0) {
658+
sample(n, axis = 0, duplicate = false) {
658659
const k = this.sizes[axis]
659660
const idx = []
660-
for (let i = 0; i < n; i++) {
661-
idx.push(Math.floor(Math.random() * (k - i)))
662-
}
663-
for (let i = n - 1; i >= 0; i--) {
664-
for (let j = n - 1; j > i; j--) {
665-
if (idx[i] <= idx[j]) {
666-
idx[j]++
661+
if (duplicate) {
662+
for (let i = 0; i < n; i++) {
663+
idx.push(Math.floor(Math.random() * k))
664+
}
665+
} else {
666+
if (n > k) {
667+
throw new MatrixException('Invalid sampled size.')
668+
}
669+
for (let i = 0; i < n; i++) {
670+
idx.push(Math.floor(Math.random() * (k - i)))
671+
}
672+
for (let i = n - 1; i >= 0; i--) {
673+
for (let j = n - 1; j > i; j--) {
674+
if (idx[i] <= idx[j]) {
675+
idx[j]++
676+
}
667677
}
668678
}
669679
}

tests/lib/util/matrix.test.js

+51
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,29 @@ describe('Matrix', () => {
989989
expect(expidx).toEqual(idx)
990990
})
991991

992+
test.each([undefined, 0])('row(%p) duplicate index', axis => {
993+
const n = 6
994+
const org = Matrix.randn(3, 5)
995+
const [mat, idx] = org.sample(n, axis, true)
996+
expect(idx).toHaveLength(n)
997+
998+
const expidx = []
999+
for (let k = 0; k < n; k++) {
1000+
for (let i = 0; i < org.rows; i++) {
1001+
let flg = true
1002+
for (let j = 0; j < org.cols; j++) {
1003+
flg &= mat.at(k, j) === org.at(i, j)
1004+
}
1005+
if (flg) {
1006+
expidx.push(i)
1007+
break
1008+
}
1009+
}
1010+
}
1011+
expect(expidx).toHaveLength(n)
1012+
expect(expidx).toEqual(idx)
1013+
})
1014+
9921015
test('col index', () => {
9931016
const n = 3
9941017
const org = Matrix.randn(10, 5)
@@ -1012,6 +1035,34 @@ describe('Matrix', () => {
10121035
expect(expidx).toEqual(idx)
10131036
})
10141037

1038+
test('col duplicate index', () => {
1039+
const n = 6
1040+
const org = Matrix.randn(3, 5)
1041+
const [mat, idx] = org.sample(n, 1, true)
1042+
expect(idx).toHaveLength(n)
1043+
1044+
const expidx = []
1045+
for (let k = 0; k < n; k++) {
1046+
for (let j = 0; j < org.cols; j++) {
1047+
let flg = true
1048+
for (let i = 0; i < org.rows; i++) {
1049+
flg &= mat.at(i, k) === org.at(i, j)
1050+
}
1051+
if (flg) {
1052+
expidx.push(j)
1053+
break
1054+
}
1055+
}
1056+
}
1057+
expect(expidx).toHaveLength(n)
1058+
expect(expidx).toEqual(idx)
1059+
})
1060+
1061+
test('fail invalid sampled size %p', () => {
1062+
const mat = Matrix.randn(5, 10)
1063+
expect(() => mat.sample(6, 0)).toThrow('Invalid sampled size.')
1064+
})
1065+
10151066
test.each([-1, 2])('fail invalid axis %p', axis => {
10161067
const mat = Matrix.randn(5, 10)
10171068
expect(() => mat.sample(4, axis)).toThrow('Invalid axis.')

0 commit comments

Comments
 (0)