Skip to content

Commit be1b8e8

Browse files
author
hgarrereyn
committed
Merge remote-tracking branch 'upstream/master' into core
2 parents 1df33ac + 736488d commit be1b8e8

18 files changed

+470
-123
lines changed

Diff for: RELEASE.md

+1
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
check and prevent duplications. DO NOT rely on this behavior.
1818
* CsvTFXIO now allows skipping CSV headers (`set skip_header_lines`).
1919
* CsvTFXIO now requires `telemetry_descriptors` to construct.
20+
* Depends on `pyarrow>=0.17,<0.18`.
2021

2122
## Breaking changes
2223

Diff for: tfx_bsl/beam/run_inference.py

+22 -5
Original file line number | Diff line number | Diff line change
@@ -545,6 +545,8 @@ def setup(self):
545545
def _setup_model(
546546
self, inference_spec_type: model_spec_pb2.InferenceSpecType
547547
):
548+
self._ai_platform_prediction_model_spec = (
549+
inference_spec_type.ai_platform_prediction_model_spec)
548550
self._api_client = None
549551

550552
project_id = (
@@ -589,10 +591,10 @@ def _make_request(self, body: Mapping[Text, List[Any]]) -> http.HttpRequest:
589591
return self._api_client.projects().predict(
590592
name=self._full_model_name, body=body)
591593

592-
@classmethod
593-
def _prepare_instances(
594-
cls, elements: List[ExampleType]
594+
def _prepare_instances_dict(
595+
self, elements: List[ExampleType]
595596
) -> Generator[Mapping[Text, Any], None, None]:
597+
"""Prepare instances by converting features to dictionary."""
596598
for example in elements:
597599
# TODO(b/151468119): support tf.train.SequenceExample
598600
if not isinstance(example, tf.train.Example):
@@ -604,13 +606,28 @@ def _prepare_instances(
604606
if attr_name is None:
605607
continue
606608
attr = getattr(feature, attr_name)
607-
values = cls._parse_feature_content(attr.value, attr_name,
608-
cls._sending_as_binary(input_name))
609+
values = self._parse_feature_content(
610+
attr.value, attr_name, self._sending_as_binary(input_name))
609611
# Flatten a sequence if its length is 1
610612
values = (values[0] if len(values) == 1 else values)
611613
instance[input_name] = values
612614
yield instance
613615

616+
def _prepare_instances_serialized(
617+
self, elements: List[ExampleType]
618+
) -> Generator[Mapping[Text, Text], None, None]:
619+
"""Prepare instances by base64 encoding serialized examples."""
620+
for example in elements:
621+
yield {'b64': base64.b64encode(example.SerializeToString()).decode()}
622+
623+
def _prepare_instances(
624+
self, elements: List[ExampleType]
625+
) -> Generator[Mapping[Text, Any], None, None]:
626+
if self._ai_platform_prediction_model_spec.use_serialization_config:
627+
return self._prepare_instances_serialized(elements)
628+
else:
629+
return self._prepare_instances_dict(elements)
630+
614631
@staticmethod
615632
def _sending_as_binary(input_name: Text) -> bool:
616633
"""Whether data should be sent as binary."""

Diff for: tfx_bsl/beam/run_inference_test.py

+37 -4
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
# Standard __future__ imports
1919
from __future__ import print_function
2020

21+
import base64
2122
import json
2223
import os
2324
try:
@@ -603,9 +604,17 @@ def test_request_body_with_binary_data(self):
603604
feature { key: "z" value { float_list { value: [4.5, 5, 5.5] }}}
604605
}
605606
""", tf.train.Example())
606-
result = list(
607-
run_inference._RemotePredictDoFn._prepare_instances([example]))
608-
self.assertEqual([
607+
inference_spec_type = model_spec_pb2.InferenceSpecType(
608+
ai_platform_prediction_model_spec=model_spec_pb2
609+
.AIPlatformPredictionModelSpec(
610+
project_id='test_project',
611+
model_name='test_model',
612+
version_name='test_version'))
613+
remote_predict = run_inference._RemotePredictDoFn(
614+
None, fixed_inference_spec_type=inference_spec_type)
615+
remote_predict._setup_model(remote_predict._fixed_inference_spec_type)
616+
result = list(remote_predict._prepare_instances([example]))
617+
self.assertEqual(result, [
609618
{
610619
'x_bytes': {
611620
'b64': 'QVNhOGFzZGY='
@@ -614,7 +623,31 @@ def test_request_body_with_binary_data(self):
614623
'y': [1, 2],
615624
'z': [4.5, 5, 5.5]
616625
},
617-
], result)
626+
])
627+
628+
def test_request_serialized_example(self):
629+
example = text_format.Parse(
630+
"""
631+
features {
632+
feature { key: "x_bytes" value { bytes_list { value: ["ASa8asdf"] }}}
633+
feature { key: "x" value { bytes_list { value: "JLK7ljk3" }}}
634+
feature { key: "y" value { int64_list { value: [1, 2] }}}
635+
}
636+
""", tf.train.Example())
637+
inference_spec_type = model_spec_pb2.InferenceSpecType(
638+
ai_platform_prediction_model_spec=model_spec_pb2
639+
.AIPlatformPredictionModelSpec(
640+
project_id='test_project',
641+
model_name='test_model',
642+
version_name='test_version',
643+
use_serialization_config=True))
644+
remote_predict = run_inference._RemotePredictDoFn(
645+
None, fixed_inference_spec_type=inference_spec_type)
646+
remote_predict._setup_model(remote_predict._fixed_inference_spec_type)
647+
result = list(remote_predict._prepare_instances([example]))
648+
self.assertEqual(result, [{
649+
'b64': base64.b64encode(example.SerializeToString()).decode()
650+
}])
618651

619652

620653
class RunInferenceCoreTest(RunInferenceFixture):

Diff for: tfx_bsl/cc/sketches/BUILD

+21 -20
Original file line number | Diff line number | Diff line change
@@ -40,26 +40,6 @@ cc_test(
4040
],
4141
)
4242

43-
cc_library(
44-
name = "sketches_submodule",
45-
srcs = ["sketches_submodule.cc"],
46-
hdrs = ["sketches_submodule.h"],
47-
copts = [
48-
"-fexceptions",
49-
],
50-
features = ["-use_header_modules"],
51-
visibility = [
52-
"//tfx_bsl/cc:__pkg__",
53-
],
54-
deps = [
55-
":kmv_sketch",
56-
"//tfx_bsl/cc/pybind11:absl_casters",
57-
"//tfx_bsl/cc/pybind11:arrow_casters",
58-
"@com_google_absl//absl/strings",
59-
"@pybind11",
60-
],
61-
)
62-
6343
cc_library(
6444
name = "misragries_sketch",
6545
srcs = ["misragries_sketch.cc"],
@@ -88,3 +68,24 @@ cc_test(
8868
"@com_google_googletest//:gtest_main",
8969
],
9070
)
71+
72+
cc_library(
73+
name = "sketches_submodule",
74+
srcs = ["sketches_submodule.cc"],
75+
hdrs = ["sketches_submodule.h"],
76+
copts = [
77+
"-fexceptions",
78+
],
79+
features = ["-use_header_modules"],
80+
visibility = [
81+
"//tfx_bsl/cc:__pkg__",
82+
],
83+
deps = [
84+
":kmv_sketch",
85+
":misragries_sketch",
86+
"//tfx_bsl/cc/pybind11:absl_casters",
87+
"//tfx_bsl/cc/pybind11:arrow_casters",
88+
"@com_google_absl//absl/strings",
89+
"@pybind11",
90+
],
91+
)

Diff for: tfx_bsl/cc/sketches/kmv_sketch.h

+4 -5
Original file line number | Diff line number | Diff line change
@@ -46,17 +46,16 @@ class KmvSketch {
4646
// KmvSketch is copyable and movable.
4747
KmvSketch(int num_buckets);
4848
~KmvSketch() = default;
49-
// Updates the sketch with an Arrow array of values. Supports numeric arrays
50-
// of all integral types, binary arrays, and string arrays.
49+
// Updates the sketch with an Arrow array of values.
5150
Status AddValues(const arrow::Array& array);
52-
// Merges the sketch from another object into this sketch. Returns error if
53-
// the other sketch has a different number of buckets than this sketch.
51+
// Merges another KMV sketch into this sketch. Returns error if the other
52+
// sketch has a different number of buckets than this sketch.
5453
Status Merge(KmvSketch& other);
5554
// Estimates the number of distinct elements.
5655
uint64_t Estimate() const;
5756
// Serializes the sketch into a string.
5857
std::string Serialize() const;
59-
// Converts the encoded sketch into a KmvSketch object.
58+
// Deserializes the string to a KmvSketch object.
6059
static KmvSketch Deserialize(absl::string_view encoded);
6160

6261
private:

Diff for: tfx_bsl/cc/sketches/misragries_sketch.cc

+2 -1
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,8 @@ class UpdateItemCountsVisitor : public arrow::ArrayVisitor {
131131
if (array.IsNull(i)) {
132132
continue;
133133
}
134-
const auto item = array.GetView(i);
134+
const auto value = array.GetView(i);
135+
const auto item = absl::string_view(value.data(), value.size());
135136
const float weight = weights_->Value(i);
136137
InsertItem(item, weight);
137138
}

Diff for: tfx_bsl/cc/sketches/misragries_sketch.h

+5 -5
Original file line number | Diff line number | Diff line change
@@ -65,10 +65,10 @@ class MisraGriesSketch {
6565
MisraGriesSketch(int num_buckets);
6666
// This class is copyable and movable.
6767
~MisraGriesSketch() = default;
68-
// Add an array of items.
68+
// Adds an array of items.
6969
Status AddValues(const arrow::Array& items);
70-
// Add an array of items with their associated weights. Raises an error if the
71-
// weights are not a FloatArray.
70+
// Adds an array of items with their associated weights. Raises an error if
71+
// the weights are not a FloatArray.
7272
Status AddValues(const arrow::Array& items, const arrow::Array& weights);
7373
// Merges another MisraGriesSketch into this sketch. Raises an error if the
7474
// sketches do not have the same number of buckets.
@@ -77,9 +77,9 @@ class MisraGriesSketch {
7777
std::vector<std::pair<std::string, double>> GetCounts() const;
7878
// Creates a struct array <values, counts> of the top-k items.
7979
Status Estimate(std::shared_ptr<arrow::Array>* values_and_counts_array) const;
80-
// Serializes proto of sketch.
80+
// Serializes the sketch into a string.
8181
std::string Serialize() const;
82-
// Parses encoded MisraGries proto.
82+
// Deserializes the string to a MisraGries object.
8383
static MisraGriesSketch Deserialize(absl::string_view encoded);
8484
// Gets delta_.
8585
double GetDelta() const;

Diff for: tfx_bsl/cc/sketches/misragries_sketch_test.cc

+1 -1
Original file line number | Diff line number | Diff line change
@@ -387,7 +387,7 @@ TEST(MisraGriesSketchTest, AddWeightedZipfDistribution) {
387387
for (auto const& item : counts) {
388388
double true_count = true_counts[std::string(item.first)] * trial;
389389
double estimated_count = item.second;
390-
EXPECT_LE(true_count, estimated_count + 1e8);
390+
EXPECT_LE(true_count, estimated_count + 1e-8);
391391
EXPECT_LE(estimated_count - mg.GetDelta(), true_count);
392392
}
393393
}

0 commit comments

Comments
 (0)