diff --git a/animated-transformer/src/lib/gtensor/gtensor.spec.ts b/animated-transformer/src/lib/gtensor/gtensor.spec.ts index cf15b70..4e4bb39 100644 --- a/animated-transformer/src/lib/gtensor/gtensor.spec.ts +++ b/animated-transformer/src/lib/gtensor/gtensor.spec.ts @@ -905,4 +905,125 @@ describe('gtensor', () => { [0, 0, 0], ]); }); + + fit('where', async () => { + const g1 = new gtensor.GTensor( + tf.tensor([ + [ + [1, 2], + [3, 4], + [5, 6], + ], + [ + [1, 2], + [3, 4], + [5, 6], + ], + ]), + ['example', 'pos', 'repSize'], + ); + + const g2 = new gtensor.GTensor( + tf.tensor([ + [0, 0], + [0, 0], + [0, 0], + ]), + ['pos', 'repSize'], + ); + + const condition = new gtensor.GTensor( + tf.tensor2d( + [ + [1, 0], + [0, 1], + [1, 0], + ], + [3, 2], + 'bool', + ), + ['pos', 'repSize'], + ); + + const g1WhereCondition = g1.where(condition, g2); + + expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']); + tf.test_util.expectArraysEqual(g1WhereCondition.tensor.arraySync(), [ + [ + [1, 0], + [0, 4], + [5, 0], + ], // example = 1 + [ + [1, 0], + [0, 4], + [5, 0], + ], + ]); + }); + + fit('where no broadcast over g2', async () => { + const g1 = new gtensor.GTensor( + tf.tensor([ + [ + [1, 2], + [3, 4], + [5, 6], + ], + [ + [1, 2], + [3, 4], + [5, 6], + ], + ]), + ['example', 'pos', 'repSize'], + ); + + const g2 = new gtensor.GTensor( + tf.tensor( + [ + [ + [0, 0], + [0, 0], + [0, 0], + ], // example = 1 + [ + [0, 0], + [0, 0], + [0, 0], + ], + ], // example = 2 + ), + ['example', 'pos', 'repSize'], + ); + + const condition = new gtensor.GTensor( + tf.tensor2d( + [ + [1, 0], + [0, 1], + [1, 0], + ], + [3, 2], + 'bool', + ), + ['pos', 'repSize'], + ); + + const g1WhereCondition = g1.where(condition, g2); + + expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']); + tf.test_util.expectArraysEqual(g1WhereCondition.tensor.arraySync(), [ + [ + [1, 0], + [0, 4], + [5, 0], + ], // example = 1 + [ + [1, 0], + [0, 4], + [5, 0], + ], + ]); + }); }); diff --git a/animated-transformer/src/lib/gtensor/gtensor.ts b/animated-transformer/src/lib/gtensor/gtensor.ts index e67e24e..2030075 100644 --- a/animated-transformer/src/lib/gtensor/gtensor.ts +++ b/animated-transformer/src/lib/gtensor/gtensor.ts @@ -840,6 +840,37 @@ export class GTensor { this.dimNames, ); } + + /* Returns the elements, from this of the gtensor or g2 depending on the condition. + If the condition is true, select from the gtensor, otherwise select from g2. + if gtensor.dims != g2.dims g2 is broadcasted to this.dimensions! + if gtensor.dims != cond.dims condition is broadcasted to this.dimensions! */ + public where( + condition: GTensor, + g2: GTensor, + ): GTensor { + // Verify that D and G2 are smaller than G or return an error + if (condition.dimNames.length > this.dimNames.length) { + throw new ValueError('The rank of condition cannot be higher than the rank of this tensor'); + } + if (g2.dimNames.length > this.dimNames.length) { + throw new ValueError('The rank of g2 cannot be higher than the rank of this tensor'); + } + // Broadcast G2 to this tensor's dims + const g2big = g2.broadcastToCombinedShape(this); + const g1big = this.broadcastToCombinedShape(g2); + const g2bigLikeG1 = g2big.transposeLike(g1big); + + // Broadcast condition to this tensor's dims + const conditionBig = condition.broadcastToCombinedShape(this); + const g1bigC = this.broadcastToCombinedShape(condition); + const conditionBigLikeG1 = conditionBig.transposeLike(g1bigC); + + return new GTensor( + this.tensor.where(conditionBigLikeG1.tensor, g2bigLikeG1.tensor), + this.dimNames, + ); + } } export class GVariable extends GTensor { @@ -964,13 +995,17 @@ export function makeRange( * - dtype : The type of an element in the resulting tensor. Defaults to 'float32' * // TODO add optianal broadcastTo dimensions/GTensor * */ -export function makeTriangularMatrix( +export function makeTriangularMatrix< + N1 extends string, + N2 extends string, + T extends string | number, +>( size: number, d1Name: N1, d2Name: N2, lowerLeftValue: T, upperRightValue: T, - dtype: 'float32' | 'int32' | 'bool' | 'complex64' | 'string' = 'float32' + dtype: 'float32' | 'int32' | 'bool' | 'complex64' | 'string' = 'float32', ): GTensor { // Create a range tensor for row indices const rowIndices = tf.range(0, size, 1, 'int32');