Skip to content

Commit f178454

Browse files
Mesilencekiliutongxuan
authored andcommitted
[Docs] Update documentation description of GroupEmbedding.
Signed-off-by: JunqiHu <[email protected]>
1 parent 200c896 commit f178454

File tree

3 files changed

+18
-9
lines changed

3 files changed

+18
-9
lines changed

Diff for: docs/docs_en/Group-Embedding.md

+6-2
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,20 @@ GroupEmbedding provides two levels of API.The one is `tf.nn.group_embedding_look
2222
def group_embedding_lookup_sparse(params,
2323
sp_ids,
2424
combiners,
25-
partition_strategy="mod",
2625
sp_weights=None,
26+
partition_strategy="mod",
27+
is_sequence=False,
28+
params_num_per_group=sys.maxsize,
2729
name=None):
2830
```
2931

3032
- `params` : List, This parameter could receive one or more EmbeddingVariables or native Tensorflow Variable.
3133
- `sp_ids` : List | Tuple , SparseTensor sp_ids ​​is the ID used for EmbeddingLookup, the length must be consistent with params.
3234
- `combiners` : List | Tuple,The pooling method of embedding values.Currently support `mean` and `sum`.
33-
- `partition_strategy` : str,Currently not supported.
3435
- `sp_weights` : List | Typle the weight of sp_ids values.
36+
- `partition_strategy` : str,Currently not supported.
37+
- `is_sequence` : bool, Op would return Tensor shape of [B, T, D] if True
38+
- `params_num_per_group` : int, This parameter indicates the number of Variables inside each Op. The default setting is the maximum value. The default value is suitable for GPU scenarios; when using the CPU, it is recommended to set the smaller the better.
3539
- `name` : str group name
3640

3741
**group_embedding_lookup**

Diff for: docs/docs_zh/Group-Embedding.md

+7-3
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,20 @@ Group Embedding功能支持同时对多个EmbeddingVariable 聚合查询,将
2626
def group_embedding_lookup_sparse(params,
2727
sp_ids,
2828
combiners,
29-
partition_strategy="mod",
3029
sp_weights=None,
30+
partition_strategy="mod",
31+
is_sequence=False,
32+
params_num_per_group=sys.maxsize,
3133
name=None):
3234
```
3335

3436
- `params` : List, 该参数可以接收一个或者多个EmbeddingVariable或者是原生Tensorflow Variable
3537
- `sp_ids` : List | Tuple , SparseTensor ,values是用于查找的ID 长度必须和params保持一致
3638
- `combiners` : List | Tuple 查找完得到的embedding tensor聚合的方式,支持 `mean``sum`
37-
- `partition_strategy` : str 目前暂时不支持
3839
- `sp_weights` : List | Typle sp_ids 的 values 的权重。
40+
- `partition_strategy` : str 目前暂时不支持
41+
- `is_sequence` : bool 如果设置为True,则返回的embedding形状为(B, T, D)。
42+
- `params_num_per_group` : int 该参数表示每个Op内部Variable的个数,默认设置为最大值,默认值适用于GPU的场景;当使用CPU的时候建议设置越小越好
3943
- `name` : str group的名称
4044

4145
**group_embedding_lookup**
@@ -55,7 +59,7 @@ def group_embedding_lookup(params,
5559
**group_embedding_column_scope**
5660

5761
```python
58-
def group_embedding_column_scope(name=None):
62+
def group_embedding_column_scope(name=None, params_num_per_group=sys.maxsize):
5963
```
6064

6165
- `name` : scope的名称

Diff for: tensorflow/python/ops/embedding_ops.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1608,10 +1608,11 @@ def group_embedding_lookup_sparse(params,
16081608
is_sequence: bool
16091609
return list of `Tensor` of shape `[batch_size, D]` when is False
16101610
return list of `Tensor` of shape `[batch_size, T, D]` when is True
1611-
sub_group_size: int
1612-
A string specifying the grouping strategy of group embedding op.["gpu", "cpu"]
1613-
are supported. Setting "gpu" will group all embeddings to maximize the GPU utilization.
1614-
"cpu" will split embeddings so as to maximize intra parallelism.
1611+
params_num_per_group: int
1612+
The number of params in GroupEmbedding op.Function will schedule len(params) // params_num_per_group + 1
1613+
GroupEmbedding Op. Default setting would launch one Op containing all params which is suitable for GPU scenarios
1614+
to maximize the GPU utilization.On the contrast, you could set value to 1 when Op
1615+
is placed on CPU so as to maximize inter parallelism.
16151616
name: The operations name
16161617
Returns
16171618
-------

0 commit comments

Comments
 (0)