diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD
index eda400af7..e5ef9baad 100644
--- a/tensorflow_text/core/kernels/BUILD
+++ b/tensorflow_text/core/kernels/BUILD
@@ -976,13 +976,15 @@ tf_cc_library(
     ],
     deps = [
         "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/meta:type_traits",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@com_google_sentencepiece//:sentencepiece_cc_proto",
         "@com_google_sentencepiece//:sentencepiece_model_cc_proto",
         "@com_google_sentencepiece//:sentencepiece_processor",
+        # tf:core_cpu_base tensorflow dep,
     ],
 )
 
@@ -1015,7 +1017,7 @@ tf_cc_library(
         # tf:lib tensorflow dep,
     ],
     deps = [
-        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@icu//:common",
     ],
@@ -1071,7 +1073,7 @@ tf_cc_library(
         # tf:lib tensorflow dep,
     ],
     deps = [
-        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@icu//:common",
     ],
@@ -1376,7 +1378,7 @@ tf_cc_library(
     ],
     deps = [
         ":wordpiece_tokenizer",
-        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/status",
     ],
 )
 
diff --git a/tensorflow_text/core/kernels/sentencepiece_kernels.cc b/tensorflow_text/core/kernels/sentencepiece_kernels.cc
index 7b5d2b5fe..75bde4053 100644
--- a/tensorflow_text/core/kernels/sentencepiece_kernels.cc
+++ b/tensorflow_text/core/kernels/sentencepiece_kernels.cc
@@ -17,11 +17,12 @@
 #include "absl/base/thread_annotations.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/meta/type_traits.h"
+#include "absl/status/status.h"
 #include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
-#include "src/sentencepiece_model.pb.h"
 #include "src/sentencepiece.pb.h"
+#include "src/sentencepiece_model.pb.h"
 #include "src/sentencepiece_processor.h"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
@@ -33,6 +34,7 @@
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/refcount.h"
@@ -67,7 +69,7 @@ struct SentencepieceResource : public ResourceBase {
            (reverse == this->reverse);
   }
 
-  Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
+  absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override {
     absl::ReaderMutexLock l(&mu);
     // We set use_node_name_sharing with a unique node name so that the resource
     // can outlive the kernel. This means that the lifetime of the re-created
@@ -93,10 +95,10 @@ struct SentencepieceResource : public ResourceBase {
 // TODO(broken) Determine a medium cost of a call to the SentencePiece processor
 constexpr int64 kCostPerUnit = 10000;
 
-::tensorflow::Status ToTFStatus(const sentencepiece::util::Status& s) {
-  if (s.ok()) return ::tensorflow::Status();
-  return ::tensorflow::Status(static_cast<::tensorflow::errors::Code>(s.code()),
-                              ::tensorflow::string(s.message()));
+absl::Status ToTFStatus(const sentencepiece::util::Status& s) {
+  if (s.ok()) return absl::Status();
+  return absl::Status(static_cast<absl::StatusCode>(s.code()),
+                      ::tensorflow::string(s.message()));
 }
 
 template <typename T>
@@ -114,8 +116,8 @@ int32 GetPieceOrId(
   return sp.id();
 }
 
-tensorflow::Status HandleExtraOptions(OpKernelContext* ctx,
-                                      SentencepieceResource* sp) {
+absl::Status HandleExtraOptions(OpKernelContext* ctx,
+                                SentencepieceResource* sp) {
   const Tensor* add_bos_tensor = nullptr;
   TF_RETURN_IF_ERROR(ctx->input("add_bos", &add_bos_tensor));
   const bool add_bos = add_bos_tensor->scalar<bool>()();
@@ -209,7 +211,7 @@ class SentencepieceOp : public OpKernel {
     TF_RETURN_IF_ERROR(
         GetNodeAttr(this->def(), "model", &model_proto_attr));
     if (TF_PREDICT_FALSE(model_proto_attr.empty())) {
-      return Status(tensorflow::errors::InvalidArgument(
+      return absl::Status(tensorflow::errors::InvalidArgument(
           "Model argument must be specified."));
     }
     // Loads serialized sentencepiece model proto to enable embedding
diff --git a/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc b/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc
index fefac977f..fe622b8c8 100644
--- a/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc
+++ b/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc
@@ -17,6 +17,7 @@
 #include <memory>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "icu4c/source/common/unicode/uchar.h"
 #include "icu4c/source/common/unicode/umachine.h"
@@ -64,24 +65,24 @@ bool IsBreakChar(absl::string_view text) {
   return u_isUWhiteSpace(c);
 }
 
-Status TokenizeByLabel(const absl::string_view& text,
-                       const Tensor& labels_tensor,
-                       bool force_split_at_break_character,
-                       std::vector<std::string>* tokens,
-                       std::vector<int>* begin_offset,
-                       std::vector<int>* end_offset, int* num_tokens) {
+absl::Status TokenizeByLabel(const absl::string_view& text,
+                             const Tensor& labels_tensor,
+                             bool force_split_at_break_character,
+                             std::vector<std::string>* tokens,
+                             std::vector<int>* begin_offset,
+                             std::vector<int>* end_offset, int* num_tokens) {
   std::vector<absl::string_view> chars;
   if (!GetUTF8Chars(text, &chars)) {
-    return Status(static_cast<tensorflow::errors::Code>(
-                      absl::StatusCode::kInvalidArgument),
-                  absl::StrCat("Input string is not utf8 valid: ", text));
+    return absl::Status(
+        static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
+        absl::StrCat("Input string is not utf8 valid: ", text));
   }
 
   if (chars.size() > labels_tensor.dim_size(0)) {
-    return Status(static_cast<tensorflow::errors::Code>(
-                      absl::StatusCode::kInvalidArgument),
-                  absl::StrCat("Number of labels ", labels_tensor.dim_size(0),
-                               " is insufficient for text ", text));
+    return absl::Status(
+        static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
+        absl::StrCat("Number of labels ", labels_tensor.dim_size(0),
+                     " is insufficient for text ", text));
   }
 
   const int split_label = 0;
diff --git a/tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc b/tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc
index 004dd9442..008a52af1 100644
--- a/tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc
+++ b/tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc
@@ -17,6 +17,7 @@
 #include <memory>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "icu4c/source/common/unicode/uchar.h"
 #include "icu4c/source/common/unicode/umachine.h"
@@ -24,6 +25,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
@@ -68,22 +70,22 @@ bool IsBreakChar(absl::string_view text) {
 // allows us to retrieve the corresponding data from logits. I.e., the logits
 // for the i-th character from text are logits(batch_index, i, 0) (for the
 // "split" action) and logits(batch_index, i, 1) (for the "merge" action).
-Status TokenizeByLogits(const absl::string_view& text,
-                        const TTypes<float, 3>::Tensor& logits,
-                        int batch_index,
-                        bool force_split_at_break_character,
-                        std::vector<std::string>* tokens,
-                        std::vector<int>* begin_offset,
-                        std::vector<int>* end_offset, int* num_tokens) {
+absl::Status TokenizeByLogits(const absl::string_view& text,
+                              const TTypes<float, 3>::Tensor& logits,
+                              int batch_index,
+                              bool force_split_at_break_character,
+                              std::vector<std::string>* tokens,
+                              std::vector<int>* begin_offset,
+                              std::vector<int>* end_offset, int* num_tokens) {
   std::vector<absl::string_view> chars;
   if (!GetUTF8Chars(text, &chars)) {
-    return Status(
+    return absl::Status(
         static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
         absl::StrCat("Input string is not utf8 valid: ", text));
   }
 
   if (chars.size() > logits.dimension(1)) {
-    return Status(
+    return absl::Status(
         static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
         absl::StrCat("Number of logits, ", logits.dimension(1),
                      ", is insufficient for text \"", text, "\""));
diff --git a/tensorflow_text/core/kernels/wordpiece_kernel.cc b/tensorflow_text/core/kernels/wordpiece_kernel.cc
index c8c1faa95..e91bf4703 100644
--- a/tensorflow_text/core/kernels/wordpiece_kernel.cc
+++ b/tensorflow_text/core/kernels/wordpiece_kernel.cc
@@ -17,6 +17,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/status/status.h"
 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
 #include "tensorflow/core/framework/lookup_interface.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -82,8 +83,8 @@ bool GetSplitUnknownCharacters(OpKernelConstruction* ctx) {
   return split_unknown_characters;
 }
 
-Status GetTableHandle(const string& input_name, OpKernelContext* ctx,
-                      string* container, string* table_handle) {
+absl::Status GetTableHandle(const string& input_name, OpKernelContext* ctx,
+                            string* container, string* table_handle) {
   {
     mutex* mu;
     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
@@ -105,8 +106,8 @@ Status GetTableHandle(const string& input_name, OpKernelContext* ctx,
 // Gets the LookupTable stored in the ctx->resource_manager() with key
 // passed by attribute with name input_name, returns null if the table
 // doesn't exist.
-Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
-                      lookup::LookupInterface** table) {
+absl::Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
+                            lookup::LookupInterface** table) {
   string container;
   string table_handle;
   DataType handle_dtype;
@@ -135,7 +136,7 @@ class LookupTableVocab : public WordpieceVocab {
   Tensor default_value_;
 };
 
-Status ToStatus(const LookupStatus& status) {
+absl::Status ToStatus(const LookupStatus& status) {
   if (status.success) {
     return absl::OkStatus();
   }