diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 7b4d87fe..00fc2118 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -264,6 +264,13 @@ |Not mapped|ifft| |Not mapped|rfft| +## Operations - Strings + +|Tensorflow Op Name|Tensorflow.js Op Name| +|---|---| +|DecodeBase64|decodeBase64| +|EncodeBase64|encodeBase64| + ## Tensors - Transformations |Tensorflow Op Name|Tensorflow.js Op Name| diff --git a/src/operations/executors/string_executor.ts b/src/operations/executors/string_executor.ts new file mode 100644 index 00000000..2ae03e1f --- /dev/null +++ b/src/operations/executors/string_executor.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tfc from '@tensorflow/tfjs-core'; + +import {NamedTensorsMap} from '../../data/types'; +import {ExecutionContext} from '../../executor/execution_context'; +import {InternalOpExecutor, Node} from '../types'; + +import {getParamValue} from './utils'; + +export let executeOp: InternalOpExecutor = + (node: Node, tensorMap: NamedTensorsMap, + context: ExecutionContext): tfc.Tensor[] => { + switch (node.op) { + case 'DecodeBase64': { + const input = + getParamValue('str', node, tensorMap, context) as tfc.Tensor; + return [tfc.decodeBase64(input)]; + } + case 'EncodeBase64': { + const input = + getParamValue('str', node, tensorMap, context) as tfc.Tensor; + const pad = getParamValue('pad', node, tensorMap, context) as boolean; + return [tfc.encodeBase64(input, pad)]; + } + default: + throw TypeError(`Node type ${node.op} is not implemented`); + } + }; + +export const CATEGORY = 'string'; diff --git a/src/operations/executors/string_executor_test.ts b/src/operations/executors/string_executor_test.ts new file mode 100644 index 00000000..ff0ae480 --- /dev/null +++ b/src/operations/executors/string_executor_test.ts @@ -0,0 +1,62 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import * as tfc from '@tensorflow/tfjs-core'; + +import {ExecutionContext} from '../../executor/execution_context'; +import {Node} from '../types'; + +import {executeOp} from './string_executor'; +import {createBoolAttr, createTensorAttr} from './test_helper'; + +describe('string', () => { + let node: Node; + const input1 = [tfc.tensor(['a'], [1], 'string')]; + const context = new ExecutionContext({}, {}); + + beforeEach(() => { + node = { + name: 'test', + op: '', + category: 'string', + inputNames: ['input1'], + inputs: [], + inputParams: {str: createTensorAttr(0)}, + attrParams: {}, + children: [] + }; + }); + + describe('executeOp', () => { + describe('DecodeBase64', () => { + it('should call tfc.decodeBase64', () => { + spyOn(tfc, 'decodeBase64'); + node.op = 'DecodeBase64'; + executeOp(node, {input1}, context); + expect(tfc.decodeBase64).toHaveBeenCalledWith(input1[0]); + }); + }); + describe('EncodeBase64', () => { + it('should call tfc.encodeBase64', () => { + spyOn(tfc, 'encodeBase64'); + node.op = 'EncodeBase64'; + node.attrParams.pad = createBoolAttr(true); + executeOp(node, {input1}, context); + expect(tfc.encodeBase64).toHaveBeenCalledWith(input1[0], true); + }); + }); + }); +}); diff --git a/src/operations/op_list/string.ts b/src/operations/op_list/string.ts new file mode 100644 index 00000000..e891af94 --- /dev/null +++ b/src/operations/op_list/string.ts @@ -0,0 +1,32 @@ +import {OpMapper} from '../types'; + +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +export const json: OpMapper[] = [ + { + 'tfOpName': 'DecodeBase64', + 'category': 'string', + 'inputs': [{'start': 0, 'name': 'input', 'type': 'tensor'}] + }, + { + 'tfOpName': 'EncodeBase64', + 'category': 'string', + 'inputs': [{'start': 0, 'name': 'input', 'type': 'tensor'}], + 'attrs': [{'tfName': 'pad', 'name': 'pad', 'type': 'bool'}] + } +]; diff --git a/src/operations/types.ts b/src/operations/types.ts index 9109f0af..8df6154a 100644 --- a/src/operations/types.ts +++ b/src/operations/types.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -25,7 +25,8 @@ export type ParamType = 'number'|'string'|'string[]'|'number[]'|'bool'|'bool[]'| export type Category = 'arithmetic'|'basic_math'|'control'|'convolution'|'custom'|'dynamic'| 'evaluation'|'image'|'creation'|'graph'|'logical'|'matrices'| - 'normalization'|'reduction'|'slice_join'|'spectral'|'transformation'; + 'normalization'|'reduction'|'slice_join'|'spectral'|'string'| + 'transformation'; // For mapping input or attributes of NodeDef into TensorFlow.js op param. export declare interface ParamMapper {