Skip to content

Commit e773bad

Browse files
committed
Add shared field to adanet.Subnetwork.
This deprecates, replaces, and is more flexible than `persisted_tensors`. TODO: Replace `persisted_tensors` with `shared` in examples and tutorials. PiperOrigin-RevId: 223382387
1 parent 694ab99 commit e773bad

File tree

5 files changed

+89
-33
lines changed

5 files changed

+89
-33
lines changed

Diff for: RELEASE.md

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
# Current version (0.4.0-dev)
1717
* Under development.
18+
* Add `shared` field to `adanet.Subnetwork` to deprecate, replace, and be more flexible than `persisted_tensors`.
1819
* Officially support multi-head learning with or without dict labels.
1920
* Rebuild the ensemble across iterations in Python without a frozen graph. This allows users to share more than `Tensors` between iterations including Python primitives, objects, and lambdas for greater flexibility. Eliminating reliance on a `MetaGraphDef` proto also eliminates I/O allowing for faster training, and better future-proofing.
2021
* Allow users to pass custom eval metrics when constructing an `adanet.Estimator`.

Diff for: adanet/core/estimator_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def build_subnetwork(self,
120120
last_layer=last_layer if self._return_penultimate_layer else logits,
121121
logits=logits,
122122
complexity=3,
123-
persisted_tensors=persisted_tensors)
123+
persisted_tensors=persisted_tensors,
124+
shared=persisted_tensors)
124125

125126
def build_subnetwork_train_op(self, subnetwork, loss, var_list, labels,
126127
iteration_step, summary, previous_ensemble):

Diff for: adanet/core/subnetwork/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ py_library(
1818
py_library(
1919
name = "generator",
2020
srcs = ["generator.py"],
21+
deps = [
22+
],
2123
)
2224

2325
py_test(

Diff for: adanet/core/subnetwork/generator.py

+34-20
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import abc
2323
import collections
2424

25+
from tensorflow.python.util import deprecation
26+
2527

2628
def _validate_nested_persisted_tensors(persisted_tensors):
2729
"""Raises a ValueError when a nested dict is empty in persisted_tensors."""
@@ -37,15 +39,23 @@ def _validate_nested_persisted_tensors(persisted_tensors):
3739
class Subnetwork(
3840
collections.namedtuple(
3941
"Subnetwork",
40-
["last_layer", "logits", "complexity", "persisted_tensors"])):
42+
["last_layer", "logits", "complexity", "persisted_tensors", "shared"])):
4143
"""An AdaNet subnetwork.
4244
4345
In the AdaNet paper, an `adanet.Subnetwork` is called a 'subnetwork',
4446
and indicated by 'h'. A collection of weighted subnetworks form an AdaNet
4547
ensemble.
4648
"""
4749

48-
def __new__(cls, last_layer, logits, complexity, persisted_tensors):
50+
@deprecation.deprecated_args(
51+
None, "`persisted_tensors` is deprecated, please use `shared` instead.",
52+
"persisted_tensors")
53+
def __new__(cls,
54+
last_layer,
55+
logits,
56+
complexity,
57+
persisted_tensors=None,
58+
shared=None):
4959
"""Creates a validated `Subnetwork` instance.
5060
5161
Args:
@@ -58,23 +68,25 @@ def __new__(cls, last_layer, logits, complexity, persisted_tensors):
5868
This field is represented by 'h' in the AdaNet paper.
5969
logits: `Tensor` logits or dict of string to `Tensor` logits (for
6070
multi-head) for training the subnetwork. NOTE: These logits are not used
61-
in the ensemble's outputs if the mixture weight type is `MATRIX`,
62-
instead AdaNet learns its own logits (mixture weights) from the
63-
subnetwork's `last_layers` with complexity regularization. The logits
64-
are used in the ensemble only when the mixture weights type is `SCALAR`
65-
or `VECTOR`. Even though the logits are not used in the ensemble in some
66-
cases, they should always be supplied as adanet uses the logits to train
67-
the subnetworks.
71+
in the ensemble's outputs if the mixture weight type is `MATRIX`,
72+
instead AdaNet learns its own logits (mixture weights) from the
73+
subnetwork's `last_layer` with complexity regularization. The logits
74+
are used in the ensemble only when the mixture weights type is
75+
`SCALAR` or `VECTOR`. Even though the logits are not used in the
76+
ensemble in some cases, they should always be supplied as adanet uses
77+
the logits to train the subnetworks.
6878
complexity: A scalar `Tensor` representing the complexity of the
6979
subnetwork's architecture. It is used for choosing the best subnetwork
7080
at each iteration, and for regularizing the weighted outputs of more
7181
complex subnetworks.
72-
persisted_tensors: Nested dictionary of string to `Tensor` to persist
73-
across iterations. At the end of an iteration, the `Tensors` will be
74-
available to subnetworks in the next iterations, whereas others that are
75-
not part of the `Subnetwork` will be pruned. This allows later
76-
`Subnetworks` to dynamically build upon arbitrary `Tensors` from
77-
previous `Subnetworks`.
82+
persisted_tensors: DEPRECATED: see `shared`. Optional nested dictionary of
83+
string to `Tensor` to persist across iterations. At the end of an
84+
iteration, the `Tensors` will be available to subnetworks in the next
85+
iterations, whereas others that are not part of the `Subnetwork` will be
86+
pruned. This allows later `Subnetworks` to dynamically build upon
87+
arbitrary `Tensors` from previous `Subnetworks`.
88+
shared: Optional Python object, primitive, or function to share with
89+
subnetworks within the same iteration or in future iterations.
7890
7991
Returns:
8092
A validated `Subnetwork` object.
@@ -85,7 +97,7 @@ def __new__(cls, last_layer, logits, complexity, persisted_tensors):
8597
ValueError: If logits is a dict but last_layer is not.
8698
ValueError: If last_layer is a dict but logits is not.
8799
ValueError: If complexity is None.
88-
ValueError: If persisted_tensors is not a dictionary.
100+
ValueError: If persisted_tensors is present but not a dictionary.
89101
ValueError: If persisted_tensors contains an empty nested dictionary.
90102
"""
91103

@@ -99,15 +111,17 @@ def __new__(cls, last_layer, logits, complexity, persisted_tensors):
99111
raise ValueError("if last_layer is a dict logits must also be a dict")
100112
if complexity is None:
101113
raise ValueError("complexity not provided")
102-
if not isinstance(persisted_tensors, dict):
103-
raise ValueError("persisted_tensors must be a dict")
104-
_validate_nested_persisted_tensors(persisted_tensors)
114+
if persisted_tensors is not None:
115+
if not isinstance(persisted_tensors, dict):
116+
raise ValueError("persisted_tensors must be a dict")
117+
_validate_nested_persisted_tensors(persisted_tensors)
105118
return super(Subnetwork, cls).__new__(
106119
cls,
107120
last_layer=last_layer,
108121
logits=logits,
109122
complexity=complexity,
110-
persisted_tensors=persisted_tensors)
123+
persisted_tensors=persisted_tensors,
124+
shared=shared)
111125

112126

113127
class Builder(object):

Diff for: adanet/core/subnetwork/generator_test.py

+50-12
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,24 @@ def build_mixture_weights_train_op(self, loss, var_list, logits, labels,
6363
class SubnetworkTest(parameterized.TestCase, tf.test.TestCase):
6464

6565
@parameterized.named_parameters({
66+
"testcase_name": "no_persisted_tensors_nor_shared",
67+
"last_layer": dummy_tensor(),
68+
"logits": dummy_tensor(),
69+
"complexity": dummy_tensor(),
70+
}, {
6671
"testcase_name": "empty_persisted_tensors",
6772
"last_layer": dummy_tensor(),
6873
"logits": dummy_tensor(),
6974
"complexity": dummy_tensor(),
7075
"persisted_tensors": {},
7176
}, {
7277
"testcase_name": "dict_logits_and_last_layer",
73-
"last_layer": {"head1": dummy_tensor()},
74-
"logits": {"head1": dummy_tensor()},
78+
"last_layer": {
79+
"head1": dummy_tensor()
80+
},
81+
"logits": {
82+
"head1": dummy_tensor()
83+
},
7584
"complexity": dummy_tensor(),
7685
"persisted_tensors": {},
7786
}, {
@@ -96,14 +105,45 @@ class SubnetworkTest(parameterized.TestCase, tf.test.TestCase):
96105
},
97106
},
98107
},
108+
}, {
109+
"testcase_name": "shared_primitive",
110+
"last_layer": dummy_tensor(),
111+
"logits": dummy_tensor(),
112+
"complexity": dummy_tensor(),
113+
"shared": 1,
114+
}, {
115+
"testcase_name": "shared_dict",
116+
"last_layer": dummy_tensor(),
117+
"logits": dummy_tensor(),
118+
"complexity": dummy_tensor(),
119+
"shared": {},
120+
}, {
121+
"testcase_name": "shared_lambda",
122+
"last_layer": dummy_tensor(),
123+
"logits": dummy_tensor(),
124+
"complexity": dummy_tensor(),
125+
"shared": lambda x: x,
126+
}, {
127+
"testcase_name": "shared_object",
128+
"last_layer": dummy_tensor(),
129+
"logits": dummy_tensor(),
130+
"complexity": dummy_tensor(),
131+
"shared": dummy_tensor(),
99132
})
100-
def test_new(self, last_layer, logits, complexity, persisted_tensors):
133+
def test_new(self,
134+
last_layer,
135+
logits,
136+
complexity,
137+
persisted_tensors=None,
138+
shared=None):
101139
with self.test_session():
102-
got = Subnetwork(last_layer, logits, complexity, persisted_tensors)
140+
got = Subnetwork(last_layer, logits, complexity, persisted_tensors,
141+
shared)
103142
self.assertEqual(got.last_layer, last_layer)
104143
self.assertEqual(got.logits, logits)
105144
self.assertEqual(got.complexity, complexity)
106145
self.assertEqual(got.persisted_tensors, persisted_tensors)
146+
self.assertEqual(got.shared, shared)
107147

108148
@parameterized.named_parameters({
109149
"testcase_name": "none_last_layer",
@@ -123,12 +163,6 @@ def test_new(self, last_layer, logits, complexity, persisted_tensors):
123163
"logits": dummy_tensor(),
124164
"complexity": None,
125165
"persisted_tensors": {},
126-
}, {
127-
"testcase_name": "none_persisted_tensors",
128-
"last_layer": dummy_tensor(),
129-
"logits": dummy_tensor(),
130-
"complexity": dummy_tensor(),
131-
"persisted_tensors": None,
132166
}, {
133167
"testcase_name": "empty_list_persisted_tensors",
134168
"last_layer": dummy_tensor(),
@@ -168,12 +202,16 @@ def test_new(self, last_layer, logits, complexity, persisted_tensors):
168202
}, {
169203
"testcase_name": "only_dict_logits",
170204
"last_layer": dummy_tensor(),
171-
"logits": {"head": dummy_tensor()},
205+
"logits": {
206+
"head": dummy_tensor()
207+
},
172208
"complexity": dummy_tensor(),
173209
"persisted_tensors": {},
174210
}, {
175211
"testcase_name": "only_dict_last_layer",
176-
"last_layer": {"head": dummy_tensor()},
212+
"last_layer": {
213+
"head": dummy_tensor()
214+
},
177215
"logits": dummy_tensor(),
178216
"complexity": dummy_tensor(),
179217
"persisted_tensors": {},

0 commit comments

Comments
 (0)