From a2cfda73199cf6b90f2e47cb31e03f71a1f8a701 Mon Sep 17 00:00:00 2001
From: Dirk Toewe
Date: Wed, 30 Jan 2019 21:46:14 +0100
Subject: [PATCH 1/6] tf.linalg.bandPart
---
src/ops/linalg_ops.ts | 92 ++++++++++++++++++++++++++++-
src/ops/linalg_ops_test.ts | 118 +++++++++++++++++++++++++++++++++++++
2 files changed, 209 insertions(+), 1 deletion(-)
diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts
index 7a15b94e16..5336fd6b16 100644
--- a/src/ops/linalg_ops.ts
+++ b/src/ops/linalg_ops.ts
@@ -22,13 +22,17 @@
import {ENV} from '../environment';
import {dispose} from '../globals';
import {Tensor, Tensor1D, Tensor2D} from '../tensor';
+import {convertToTensor} from '../tensor_util_env';
+import {TensorLike} from '../types';
import {assert} from '../util';
import {eye, squeeze, stack, unstack} from './array_ops';
+import {sub} from './binary_ops';
import {split} from './concat_split';
+import {logicalAnd, where} from './logical_ops';
import {norm} from './norm';
import {op} from './operation';
import {sum} from './reduction_ops';
-import {tensor2d} from './tensor_ops';
+import {range, scalar, tensor2d, zeros} from './tensor_ops';
/**
* Gram-Schmidt orthogonalization.
@@ -260,5 +264,91 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
}) as [Tensor2D, Tensor2D];
}
+/**
+ * Copies a tensor of matrices, setting everything outside a central band
+ * in each matrix to zero.
+ *
+ * ```js
+ * >>> const a = tf.tensor2d([[11, 12, 13, 14],
+ * ... [21, 22, 23, 24],
+ * ... [31, 32, 33, 34],
+ * ... [41, 42, 43, 44]]);
+ * >>> tf.linalg.bandPart(a,0,2).print();
+ * [[11, 12, 13, 0],
+ * [ 0, 22, 23, 24],
+ * [ 0, 0, 33, 34],
+ * [ 0, 0, 0, 44]]
+ *
+ * >>> tf.linalg.bandPart(a,1,-1).print();
+ * [[11, 12, 13, 14],
+ * [21, 22, 23, 24],
+ * [ 0, 32, 33, 34],
+ * [ 0, 0, 43, 44]]
+ * ```
+ *
+ * @param a Tensor of matrices from which the band part is extracted.
+ * @param numLower The number of subdiagonal lines to be copied.
+ * If set to `-1`, all entries below the diagonal are
+ * copied.
+ * @param numUpper The number of superdiagonal lines to be copied.
+ * If set to `-1`, all entries above the diagonal are
+ * copied.
+ */
+/**
+ * @doc {heading:'Operations',
+ * subheading:'Linear Algebra',
+ * namespace:'linalg'}
+ */
+function bandPart_(
+ a: T|TensorLike, numLower: number, numUpper: number
+): T
+{
+ if( numLower%1 !== 0 ){
+ throw new Error(`bandPart(): numLower=${numLower} not an integer.`);
+ }
+ if( numUpper%1 !== 0 ){
+ throw new Error(`bandPart(): numUpper=${numUpper} not an integer.`);
+ }
+
+ return ENV.engine.tidy( () => {
+ const $a = convertToTensor(a,'a','bandPart');
+ a = undefined;
+
+ if( $a.rank < 2 ) {
+ throw new Error(`bandPart(): a.rank = ${$a.rank} < 2.`);
+ }
+
+ const shape = $a.shape,
+ [M,N] = $a.shape.slice(-2);
+
+ if( !(numLower <= M) ) {
+ throw new Error(`bandPart() check failed: numLower <= #rows.` );
+ }
+ if( !(numUpper <= N) ) {
+ throw new Error(`bandPart() check failed: numUpper <= #columns.`);
+ }
+
+ if( numLower < 0 ) { numLower = M; }
+ if( numUpper < 0 ) { numUpper = N; }
+
+ const i = range(0,M, 1, 'int32').reshape([-1,1]),
+ j = range(0,N, 1, 'int32');
+
+ const inBand = logicalAnd(
+ sub(i,j).lessEqual( scalar(numLower,'int32') ),
+ sub(j,i).lessEqual( scalar(numUpper,'int32') )
+ );
+
+ const zero = zeros([M,N], $a.dtype);
+
+ return stack(
+ unstack( $a.reshape([-1,M,N]) ).map(
+ mat => where(inBand, mat, zero)
+ )
+ ).reshape(shape) as T;
+ });
+}
+
export const gramSchmidt = op({gramSchmidt_});
+export const bandPart = op({bandPart_});
export const qr = op({qr_});
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index bfbb5ef62b..a69af58cd3 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -241,3 +241,121 @@ describeWithFlags('qr', ALL_ENVS, () => {
expect(() => tf.linalg.qr(x2)).toThrowError(/rank >= 2.*got rank 1/);
});
});
+
+describeWithFlags('bandPart', ALL_ENVS, () => {
+ const la = tf.linalg;
+
+ // FIXME: shouldn't 1*x be lossless?
+ // It's even in the IEEE spec somewhere...
+ // Yet this fails on Travis with `expectArraysEqual`...
+ const expectArraysEqual = expectArraysClose;
+
+ it('works for 3x4 example', () => {
+ const a = tf.tensor2d([
+ [1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [9,10,11,12]
+ ]);
+ expectArraysEqual(
+ la.bandPart(a,0,0),
+ tf.tensor2d([[1, 0, 0, 0],
+ [0, 6, 0, 0],
+ [0, 0,11, 0]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,0,1),
+ tf.tensor2d([[1, 2, 0, 0],
+ [0, 6, 7, 0],
+ [0, 0,11,12]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,0,2),
+ tf.tensor2d([[1, 2, 3, 0],
+ [0, 6, 7, 8],
+ [0, 0,11,12]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,0,2),
+ tf.tensor2d([[1, 2, 3, 0],
+ [0, 6, 7, 8],
+ [0, 0,11,12]])
+ );
+ for( const numUpper of [3,4,-1,-2] ) {
+ expectArraysEqual(
+ la.bandPart(a,0,numUpper),
+ tf.tensor2d([[1, 2, 3, 4],
+ [0, 6, 7, 8],
+ [0, 0,11,12]])
+ );
+ }
+
+ expectArraysEqual(
+ la.bandPart(a,1,0),
+ tf.tensor2d([[1, 0, 0, 0],
+ [5, 6, 0, 0],
+ [0,10,11, 0]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,1,1),
+ tf.tensor2d([[1, 2, 0, 0],
+ [5, 6, 7, 0],
+ [0,10,11,12]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,1,2),
+ tf.tensor2d([[1, 2, 3, 0],
+ [5, 6, 7, 8],
+ [0,10,11,12]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,1,2),
+ tf.tensor2d([[1, 2, 3, 0],
+ [5, 6, 7, 8],
+ [0,10,11,12]])
+ );
+ for( const numUpper of [3,4,-1,-2] ) {
+ expectArraysEqual(
+ la.bandPart(a,1,numUpper),
+ tf.tensor2d([[1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [0,10,11,12]])
+ );
+ }
+
+ for( const numLower of [2,3,-1,-2])
+ {
+ expectArraysEqual(
+ la.bandPart(a,numLower,0),
+ tf.tensor2d([[1, 0, 0, 0],
+ [5, 6, 0, 0],
+ [9,10,11, 0]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,numLower,1),
+ tf.tensor2d([[1, 2, 0, 0],
+ [5, 6, 7, 0],
+ [9,10,11,12]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,numLower,2),
+ tf.tensor2d([[1, 2, 3, 0],
+ [5, 6, 7, 8],
+ [9,10,11,12]])
+ );
+ expectArraysEqual(
+ la.bandPart(a,numLower,2),
+ tf.tensor2d([[1, 2, 3, 0],
+ [5, 6, 7, 8],
+ [9,10,11,12]])
+ );
+ for( const numUpper of [3,4,-1,-2] ) {
+ expectArraysEqual(
+ la.bandPart(a,numLower,numUpper),
+ tf.tensor2d([[1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [9,10,11,12]])
+ );
+ }
+ }
+ });
+});
From 6e57e5cafba1db8be9e26965ad2a975c64ee8c91 Mon Sep 17 00:00:00 2001
From: Dirk T
Date: Thu, 31 Jan 2019 07:35:53 +0100
Subject: [PATCH 2/6] Remove workaround (previous GPU precision issues)
---
src/ops/linalg_ops_test.ts | 5 -----
1 file changed, 5 deletions(-)
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index a69af58cd3..4980bae843 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -245,11 +245,6 @@ describeWithFlags('qr', ALL_ENVS, () => {
describeWithFlags('bandPart', ALL_ENVS, () => {
const la = tf.linalg;
- // FIXME: shouldn't 1*x be lossless?
- // It's even in the IEEE spec somewhere...
- // Yet this fails on Travis with `expectArraysEqual`...
- const expectArraysEqual = expectArraysClose;
-
it('works for 3x4 example', () => {
const a = tf.tensor2d([
[1, 2, 3, 4],
From 0264f4dfac03d436dec436285063b760929518af Mon Sep 17 00:00:00 2001
From: Dirk T
Date: Thu, 31 Jan 2019 10:09:20 +0100
Subject: [PATCH 3/6] Update linalg_ops_test.ts
---
src/ops/linalg_ops_test.ts | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index 4980bae843..75749c47ba 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -18,7 +18,7 @@
import * as tf from '../index';
import {describeWithFlags} from '../jasmine_util';
import {Tensor1D, Tensor2D} from '../tensor';
-import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util';
+import {ALL_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util';
import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops';
From 8a9620a60f24084d00e0837c8a2e444d9be98e3a Mon Sep 17 00:00:00 2001
From: Dirk Toewe
Date: Thu, 31 Jan 2019 19:35:52 +0100
Subject: [PATCH 4/6] Made bandPart test float16-aware
---
src/ops/linalg_ops_test.ts | 39 +++++++++++++++++++++++---------------
1 file changed, 24 insertions(+), 15 deletions(-)
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index 75749c47ba..9ef2874a87 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -245,38 +245,47 @@ describeWithFlags('qr', ALL_ENVS, () => {
describeWithFlags('bandPart', ALL_ENVS, () => {
const la = tf.linalg;
+ const expectArrayEq = (() => {
+ switch( tf.ENV.backend.floatPrecision() )
+ {
+ default: return expectArraysClose;
+ case 32:
+ case 64: return expectArraysEqual;
+ }
+ })();
+
it('works for 3x4 example', () => {
const a = tf.tensor2d([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9,10,11,12]
]);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,0,0),
tf.tensor2d([[1, 0, 0, 0],
[0, 6, 0, 0],
[0, 0,11, 0]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,0,1),
tf.tensor2d([[1, 2, 0, 0],
[0, 6, 7, 0],
[0, 0,11,12]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,0,2),
tf.tensor2d([[1, 2, 3, 0],
[0, 6, 7, 8],
[0, 0,11,12]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,0,2),
tf.tensor2d([[1, 2, 3, 0],
[0, 6, 7, 8],
[0, 0,11,12]])
);
for( const numUpper of [3,4,-1,-2] ) {
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,0,numUpper),
tf.tensor2d([[1, 2, 3, 4],
[0, 6, 7, 8],
@@ -284,32 +293,32 @@ describeWithFlags('bandPart', ALL_ENVS, () => {
);
}
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,1,0),
tf.tensor2d([[1, 0, 0, 0],
[5, 6, 0, 0],
[0,10,11, 0]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,1,1),
tf.tensor2d([[1, 2, 0, 0],
[5, 6, 7, 0],
[0,10,11,12]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,1,2),
tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
[0,10,11,12]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,1,2),
tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
[0,10,11,12]])
);
for( const numUpper of [3,4,-1,-2] ) {
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,1,numUpper),
tf.tensor2d([[1, 2, 3, 4],
[5, 6, 7, 8],
@@ -319,32 +328,32 @@ describeWithFlags('bandPart', ALL_ENVS, () => {
for( const numLower of [2,3,-1,-2])
{
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,numLower,0),
tf.tensor2d([[1, 0, 0, 0],
[5, 6, 0, 0],
[9,10,11, 0]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,numLower,1),
tf.tensor2d([[1, 2, 0, 0],
[5, 6, 7, 0],
[9,10,11,12]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,numLower,2),
tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
[9,10,11,12]])
);
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,numLower,2),
tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
[9,10,11,12]])
);
for( const numUpper of [3,4,-1,-2] ) {
- expectArraysEqual(
+ expectArrayEq(
la.bandPart(a,numLower,numUpper),
tf.tensor2d([[1, 2, 3, 4],
[5, 6, 7, 8],
From 9fcbfd7b92d06682cfdc4808e0803b7b24cf4963 Mon Sep 17 00:00:00 2001
From: Dirk Toewe
Date: Sun, 3 Feb 2019 10:45:14 +0100
Subject: [PATCH 5/6] Test precision now depending on tested ENVS.
---
src/ops/linalg_ops_test.ts | 174 ++++++++++++++++++-------------------
1 file changed, 85 insertions(+), 89 deletions(-)
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index 9ef2874a87..9da088202f 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -18,7 +18,7 @@
import * as tf from '../index';
import {describeWithFlags} from '../jasmine_util';
import {Tensor1D, Tensor2D} from '../tensor';
-import {ALL_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util';
+import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS} from '../test_util';
import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops';
@@ -242,124 +242,120 @@ describeWithFlags('qr', ALL_ENVS, () => {
});
});
-describeWithFlags('bandPart', ALL_ENVS, () => {
- const la = tf.linalg;
+for( const ENV of [CPU_ENVS, WEBGL_ENVS] )
+{
+ describeWithFlags('bandPart', ALL_ENVS, () => {
+ const la = tf.linalg;
- const expectArrayEq = (() => {
- switch( tf.ENV.backend.floatPrecision() )
- {
- default: return expectArraysClose;
- case 32:
- case 64: return expectArraysEqual;
- }
- })();
+ const expectArrayEq = Object.is(ENV, CPU_ENVS)
+ ? expectArraysEqual
+ : expectArraysClose;
- it('works for 3x4 example', () => {
- const a = tf.tensor2d([
- [1, 2, 3, 4],
- [5, 6, 7, 8],
- [9,10,11,12]
- ]);
- expectArrayEq(
- la.bandPart(a,0,0),
- tf.tensor2d([[1, 0, 0, 0],
- [0, 6, 0, 0],
- [0, 0,11, 0]])
- );
- expectArrayEq(
- la.bandPart(a,0,1),
- tf.tensor2d([[1, 2, 0, 0],
- [0, 6, 7, 0],
- [0, 0,11,12]])
- );
- expectArrayEq(
- la.bandPart(a,0,2),
- tf.tensor2d([[1, 2, 3, 0],
- [0, 6, 7, 8],
- [0, 0,11,12]])
- );
- expectArrayEq(
- la.bandPart(a,0,2),
- tf.tensor2d([[1, 2, 3, 0],
- [0, 6, 7, 8],
- [0, 0,11,12]])
- );
- for( const numUpper of [3,4,-1,-2] ) {
+ it('works for 3x4 example', () => {
+ const a = tf.tensor2d([[1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [9,10,11,12]]);
+ expectArrayEq(
+ la.bandPart(a,0,0),
+ tf.tensor2d([[1, 0, 0, 0],
+ [0, 6, 0, 0],
+ [0, 0,11, 0]])
+ );
+ expectArrayEq(
+ la.bandPart(a,0,1),
+ tf.tensor2d([[1, 2, 0, 0],
+ [0, 6, 7, 0],
+ [0, 0,11,12]])
+ );
expectArrayEq(
- la.bandPart(a,0,numUpper),
- tf.tensor2d([[1, 2, 3, 4],
+ la.bandPart(a,0,2),
+ tf.tensor2d([[1, 2, 3, 0],
[0, 6, 7, 8],
[0, 0,11,12]])
);
- }
-
- expectArrayEq(
- la.bandPart(a,1,0),
- tf.tensor2d([[1, 0, 0, 0],
- [5, 6, 0, 0],
- [0,10,11, 0]])
- );
- expectArrayEq(
- la.bandPart(a,1,1),
- tf.tensor2d([[1, 2, 0, 0],
- [5, 6, 7, 0],
- [0,10,11,12]])
- );
- expectArrayEq(
- la.bandPart(a,1,2),
- tf.tensor2d([[1, 2, 3, 0],
- [5, 6, 7, 8],
- [0,10,11,12]])
- );
- expectArrayEq(
- la.bandPart(a,1,2),
- tf.tensor2d([[1, 2, 3, 0],
- [5, 6, 7, 8],
- [0,10,11,12]])
- );
- for( const numUpper of [3,4,-1,-2] ) {
expectArrayEq(
- la.bandPart(a,1,numUpper),
- tf.tensor2d([[1, 2, 3, 4],
- [5, 6, 7, 8],
- [0,10,11,12]])
+ la.bandPart(a,0,2),
+ tf.tensor2d([[1, 2, 3, 0],
+ [0, 6, 7, 8],
+ [0, 0,11,12]])
);
- }
+ for( const numUpper of [3,4,-1,-2] ) {
+ expectArrayEq(
+ la.bandPart(a,0,numUpper),
+ tf.tensor2d([[1, 2, 3, 4],
+ [0, 6, 7, 8],
+ [0, 0,11,12]])
+ );
+ }
- for( const numLower of [2,3,-1,-2])
- {
expectArrayEq(
- la.bandPart(a,numLower,0),
+ la.bandPart(a,1,0),
tf.tensor2d([[1, 0, 0, 0],
[5, 6, 0, 0],
- [9,10,11, 0]])
+ [0,10,11, 0]])
);
expectArrayEq(
- la.bandPart(a,numLower,1),
+ la.bandPart(a,1,1),
tf.tensor2d([[1, 2, 0, 0],
[5, 6, 7, 0],
- [9,10,11,12]])
+ [0,10,11,12]])
);
expectArrayEq(
- la.bandPart(a,numLower,2),
+ la.bandPart(a,1,2),
tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
- [9,10,11,12]])
+ [0,10,11,12]])
);
expectArrayEq(
- la.bandPart(a,numLower,2),
+ la.bandPart(a,1,2),
tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
- [9,10,11,12]])
+ [0,10,11,12]])
);
for( const numUpper of [3,4,-1,-2] ) {
expectArrayEq(
- la.bandPart(a,numLower,numUpper),
+ la.bandPart(a,1,numUpper),
tf.tensor2d([[1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [0,10,11,12]])
+ );
+ }
+
+ for( const numLower of [2,3,-1,-2])
+ {
+ expectArrayEq(
+ la.bandPart(a,numLower,0),
+ tf.tensor2d([[1, 0, 0, 0],
+ [5, 6, 0, 0],
+ [9,10,11, 0]])
+ );
+ expectArrayEq(
+ la.bandPart(a,numLower,1),
+ tf.tensor2d([[1, 2, 0, 0],
+ [5, 6, 7, 0],
+ [9,10,11,12]])
+ );
+ expectArrayEq(
+ la.bandPart(a,numLower,2),
+ tf.tensor2d([[1, 2, 3, 0],
[5, 6, 7, 8],
[9,10,11,12]])
);
+ expectArrayEq(
+ la.bandPart(a,numLower,2),
+ tf.tensor2d([[1, 2, 3, 0],
+ [5, 6, 7, 8],
+ [9,10,11,12]])
+ );
+ for( const numUpper of [3,4,-1,-2] ) {
+ expectArrayEq(
+ la.bandPart(a,numLower,numUpper),
+ tf.tensor2d([[1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [9,10,11,12]])
+ );
+ }
}
- }
+ });
});
-});
+}
From 8c798e604ebf208e79e53a7b995f93d3b1440bbe Mon Sep 17 00:00:00 2001
From: Dirk Toewe
Date: Sun, 3 Feb 2019 11:12:10 +0100
Subject: [PATCH 6/6] Fixed ENV used with describeWithFlags.
---
src/ops/linalg_ops_test.ts | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index 9da088202f..f6e731d7c0 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -244,12 +244,12 @@ describeWithFlags('qr', ALL_ENVS, () => {
for( const ENV of [CPU_ENVS, WEBGL_ENVS] )
{
- describeWithFlags('bandPart', ALL_ENVS, () => {
- const la = tf.linalg;
+ const expectArrayEq = Object.is(ENV, CPU_ENVS)
+ ? expectArraysEqual
+ : expectArraysClose;
- const expectArrayEq = Object.is(ENV, CPU_ENVS)
- ? expectArraysEqual
- : expectArraysClose;
+ describeWithFlags('bandPart', ENV, () => {
+ const la = tf.linalg;
it('works for 3x4 example', () => {
const a = tf.tensor2d([[1, 2, 3, 4],