
Commit 200c896

Mesilenceki authored and liutongxuan committed

[Embedding] Add group_strategy to control parallelism of group_embedding.

Signed-off-by: JunqiHu <[email protected]>
1 parent 6c03afc commit 200c896
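In user-facing terms, the commit threads a new `params_num_per_group` argument from `tf.feature_column.group_embedding_column_scope` down into `embedding_ops.group_embedding_lookup_sparse`, capping how many embedding tables are fused into a single grouped lookup; the default `sys.maxsize` keeps the previous one-group behavior. A minimal usage sketch against the patched API (the column names and sizes are illustrative, and it assumes embedding columns created inside the scope are collected by it, as in DeepRec):

import tensorflow as tf

cols = []
# Fuse at most 2 embedding tables per grouped lookup.
with tf.feature_column.group_embedding_column_scope(
        name="user_features", params_num_per_group=2):
    for name in ["uid", "item_id", "tag"]:  # hypothetical features
        cat = tf.feature_column.categorical_column_with_hash_bucket(
            name, hash_bucket_size=1000)
        cols.append(tf.feature_column.embedding_column(cat, dimension=8))
# With 3 columns and params_num_per_group=2, the lookups run as
# 2 groups (2 tables + 1 table) instead of one fused group.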

File tree: 6 files changed, +310 −208 lines

Diff for: tensorflow/python/BUILD (+3 −3)

@@ -968,7 +968,7 @@ py_library(
         ":type_spec",
         ":util",
         ":versions",
-        ":group_embedding_ops_utils",
+        ":group_embedding_types",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:core",
@@ -985,8 +985,8 @@ py_library(
 )
 
 py_library(
-    name = "group_embedding_ops_utils",
-    srcs = ["framework/group_embedding_ops_utils.py"],
+    name = "group_embedding_types",
+    srcs = ["framework/group_embedding_types.py"],
     srcs_version = "PY2AND3",
     deps = [
         ":pywrap_tensorflow",

Diff for: tensorflow/python/feature_column/feature_column_v2.py (+7 −5)

@@ -131,6 +131,7 @@
 import contextlib
 import math
 
+import sys
 import numpy as np
 import six
 import json
@@ -4235,21 +4236,21 @@ def _from_config(cls, config, custom_objects=None, columns_by_name=None):
 
 @tf_export('feature_column.group_embedding_column_scope')
 @contextlib.contextmanager
-def group_embedding_column_scope(name=''):
+def group_embedding_column_scope(name='', params_num_per_group=sys.maxsize):
   global_group_embedding_scope = group_embedding_column._global_group_embedding_scope_list()
   group_id = group_embedding_column._current_group_id()
   if name == '':
     name = "group_embedding_column_scope_{}".format(group_id)
     group_id +=1
   else:
     name = "group_embedding_column_scope_{}".format(name)
-  fusion_embedding_scope = GroupEmbeddingScope(name)
+  fusion_embedding_scope = GroupEmbeddingScope(name, params_num_per_group)
   global_group_embedding_scope.append(fusion_embedding_scope)
   yield global_group_embedding_scope
 
 class GroupEmbeddingScope(group_embedding_column.GroupEmbeddingScopeBase):
-  def __init__(self, name=None):
-    super(GroupEmbeddingScope, self).__init__(name=name)
+  def __init__(self, name=None, params_num_per_group=sys.maxsize):
+    super(GroupEmbeddingScope, self).__init__(name=name, params_num_per_group=params_num_per_group)
 
   def add_column(self, embedding_column):
     VALID_EMBEDDING_COLUMN_TYPES = (
@@ -4288,7 +4289,8 @@ def _get_dense_tensor(self, filter_ec, inputs,
       embedding_weights.append(embedding_weight)
 
     output_tensors.extend(embedding_ops.group_embedding_lookup_sparse(
-        embedding_weights, sp_ids, combiners, is_sequence=is_sequence))
+        embedding_weights, sp_ids, combiners,
+        is_sequence=is_sequence, params_num_per_group=self.params_num_per_group))
     return output_tensors, sequence_lengths
 
 class EmbeddingColumn(
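The only behavioral change in this file is that `params_num_per_group` now rides along into `embedding_ops.group_embedding_lookup_sparse`. The grouping it implies can be pictured with a small hypothetical helper (the real splitting lives inside `embedding_ops`, which is not part of this diff): N tables are processed in ceil(N / params_num_per_group) fused lookups instead of one.

# Illustration only -- not the DeepRec embedding_ops implementation.
def chunk_params(embedding_weights, params_num_per_group):
    """Split a list of embedding tables into lookup groups."""
    for start in range(0, len(embedding_weights), params_num_per_group):
        yield embedding_weights[start:start + params_num_per_group]

groups = list(chunk_params(["w0", "w1", "w2", "w3", "w4"],
                           params_num_per_group=2))
assert groups == [["w0", "w1"], ["w2", "w3"], ["w4"]]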

Diff for: tensorflow/python/feature_column/group_embedding_column.py (+4 −1)

@@ -2,6 +2,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import sys
+
 _global_fusion_embedding_scope = []
 _group_id = 0
 _group_embedding_tensor = dict()
@@ -40,8 +42,9 @@ def _current_group_id():
   return _group_id
 
 class GroupEmbeddingScopeBase(object):
-  def __init__(self, name=None):
+  def __init__(self, name=None, params_num_per_group=sys.maxsize):
     self.name = name
+    self.params_num_per_group = params_num_per_group
     self.embedding_columns = []
 
   def add_column(self, embedding_column):
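`GroupEmbeddingScopeBase` is a plain value holder; the scope bookkeeping lives in the module-level registry at the top of the file. A quick, self-contained check of the new default (the class body is mirrored from the patch above), showing that an unspecified `params_num_per_group` leaves grouping effectively unlimited:

import sys

class GroupEmbeddingScopeBase(object):  # mirrors the patched class
    def __init__(self, name=None, params_num_per_group=sys.maxsize):
        self.name = name
        self.params_num_per_group = params_num_per_group
        self.embedding_columns = []

scope = GroupEmbeddingScopeBase(name="demo")
assert scope.params_num_per_group == sys.maxsize  # default: no splitting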

Diff for: tensorflow/python/framework/config.py (+2 −2)

@@ -21,7 +21,7 @@
 from tensorflow.python.eager import context
 from tensorflow.python import _pywrap_tensor_float_32_execution
 from tensorflow.python.util.tf_export import tf_export
-from tensorflow.python.framework import group_embedding_ops_utils
+from tensorflow.python.framework import group_embedding_types
 
 @tf_export('config.experimental.tensor_float_32_execution_enabled')
 def tensor_float_32_execution_enabled():
@@ -646,7 +646,7 @@ def enable_distributed_strategy(strategy="collective"):
   try:
     from sparse_operation_kit import experiment as sok
     sok.init()
-    group_embedding_ops_utils.set_group_lookup_strategy(strategy)
+    group_embedding_types.set_group_lookup_strategy(strategy)
   except:
     raise ImportError("While param `strategy` in enable_distributed_strategy"
                       "is given `collective`, sok module initialize error,"
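Callers are unaffected by the module rename; the entry point is still `enable_distributed_strategy`. A minimal usage sketch, assuming a DeepRec build with sparse_operation_kit (SOK) installed, since the "collective" path initializes it:

from tensorflow.python.framework import config

# "collective" imports and initializes SOK, then records the lookup
# strategy via group_embedding_types; if SOK is missing or fails to
# initialize, the except clause above re-raises as ImportError.
config.enable_distributed_strategy(strategy="collective")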

Diff for: tensorflow/python/framework/group_embedding_ops_utils.py renamed to tensorflow/python/framework/group_embedding_types.py (+6 −6)

@@ -21,26 +21,26 @@
 from enum import Enum, unique
 
 @unique
-class STRATEGY(Enum):
+class DistStrategy(Enum):
   COLLECTIVE = "collective"
   DISTRIBUTED = "ps"
   LOCALIZED = "localized"
   UNKNOWN = "unknown"
 
-_group_lookup_strategy = STRATEGY.LOCALIZED
+_group_lookup_strategy = DistStrategy.LOCALIZED
 
 def set_group_lookup_strategy(strategy):
   def str_to_strategy(strategy):
     if strategy == "collective":
-      return STRATEGY.COLLECTIVE
+      return DistStrategy.COLLECTIVE
     elif strategy == "ps":
-      return STRATEGY.DISTRIBUTED
+      return DistStrategy.DISTRIBUTED
     elif strategy == "localized":
-      return STRATEGY.LOCALIZED
+      return DistStrategy.LOCALIZED
 
   global _group_lookup_strategy
   _group_lookup_strategy = str_to_strategy(strategy)
 
 def get_group_lookup_strategy():
   global _group_lookup_strategy
-  return _group_lookup_strategy
+  return _group_lookup_strategy
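The STRATEGY-to-DistStrategy rename is mechanical, but one pitfall is visible in the code above: `str_to_strategy` has no fallback branch, so an unrecognized string sets the module-level strategy to `None` rather than `DistStrategy.UNKNOWN`. A round-trip check against the renamed module (runnable inside a DeepRec checkout):

from tensorflow.python.framework import group_embedding_types

group_embedding_types.set_group_lookup_strategy("ps")
assert (group_embedding_types.get_group_lookup_strategy()
        is group_embedding_types.DistStrategy.DISTRIBUTED)

# Unrecognized strings fall through: str_to_strategy returns None.
group_embedding_types.set_group_lookup_strategy("oops")
assert group_embedding_types.get_group_lookup_strategy() is None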
