diff --git a/lite/examples/bert_qa/android/.gitignore b/lite/examples/bert_qa/android/.gitignore new file mode 100644 index 00000000000..e29db5dadda --- /dev/null +++ b/lite/examples/bert_qa/android/.gitignore @@ -0,0 +1,3 @@ +/app/src/main/assets/ +/lib_interpreter/src/main/assets/ +/lib_task_api/src/main/assets/ \ No newline at end of file diff --git a/lite/examples/bert_qa/android/EXPLORE_THE_CODE.md b/lite/examples/bert_qa/android/EXPLORE_THE_CODE.md new file mode 100644 index 00000000000..21a6d3159ac --- /dev/null +++ b/lite/examples/bert_qa/android/EXPLORE_THE_CODE.md @@ -0,0 +1,169 @@ +# TensorFlow Lite BERT QA Android example + +This document walks through the code of a simple Android mobile application that +demonstrates +[BERT Question and Answer](https://www.tensorflow.org/lite/examples/bert_qa/overview). + +## Explore the code + +The app is written entirely in Java and uses the TensorFlow Lite +[Java library](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/java) +for performing BERT Question and Answer. + +We're now going to walk through the most important parts of the sample code. + +### Get the question and the context of the question + +This mobile application gets the question and the context of the question using the functions defined in the +file +[`QaActivity.java`](https://github.com/tensorflow/examples/blob/master/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/QaActivity.java). + + +### Answerer + +This BERT QA Android reference app demonstrates two implementation +solutions, +[`lib_task_api`](/lite/examples/bert_qa/android/lib_task_api) +that leverages the out-of-box API from the +[TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer), +and +[`lib_interpreter`](/lite/examples/bert_qa/android/lib_interpreter) +that creates the custom inference pipleline using the +[TensorFlow Lite Interpreter Java API](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_java). + +Both solutions implement the file `QaClient.java` (see +[the one in lib_task_api](/lite/examples/bert_qa/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java) +and +[the one in lib_interpreter](/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java) +that contains most of the complex logic for processing the text input and +running inference. + +#### Using the TensorFlow Lite Task Library + +Inference can be done using just a few lines of code with the +[`BertQuestionAnswerer`](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer) +in the TensorFlow Lite Task Library. + +##### Load model and create BertQuestionAnswerer + +`BertQuestionAnswerer` expects a model populated with the +[model metadata](https://www.tensorflow.org/lite/convert/metadata) and the label +file. See the +[model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer#model_compatibility_requirements) +for more details. + + +```java +/** + * Load TFLite model and create BertQuestionAnswerer instance. + */ + public void loadModel() { + try { + answerer = BertQuestionAnswerer.createFromFile(context, MODEL_PATH); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + } + } +``` + +`BertQuestionAnswerer` currently does not support configuring delegates and +multithread, but those are on our roadmap. Please stay tuned! + +##### Run inference + +The following code runs inference using `BertQuestionAnswerer` and predicts the possible answers + +```java + /** + * Run inference and predict the possible answers. + */ + List apiResult = answerer.answer(contextOfTheQuestion, questionToAsk); + +``` + +The output of `BertQuestionAnswerer` is a list of [`QaAnswer`](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java) instance, where +each `QaAnswer` element is a single head classification result. All the +demo models are single head models, therefore, `results` only contains one +`QaAnswer` object. + +To match the implementation of +[`lib_interpreter`](/lite/examples/bert_qa/android/lib_interpreter), +`results` is converted into List<[`Answer`](/lite/examples/bert_qa/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Answer.java)>. + +#### Using the TensorFlow Lite Interpreter + +##### Load model and create interpreter + +To perform inference, we need to load a model file and instantiate an +`Interpreter`. This happens in the `loadModel` method of the `QaClient` class. Information about number of threads is used to configure the `Interpreter` via the +`Interpreter.Options` instance passed into its constructor. + +```java +Interpreter.Options opt = new Interpreter.Options(); + opt.setNumThreads(NUM_LITE_THREADS); + tflite = new Interpreter(buffer, opt); +... +``` + +##### Pre-process query & content + +Next in the `predict` method of the `QaClient` class, we take the input of query & content, +convert it to a `Feature` format for efficient processing and pre-process +it. The steps are shown in the public 'FeatureConverter.convert()' method: + +```java + +public Feature convert(String query, String context) { + List queryTokens = tokenizer.tokenize(query); + if (queryTokens.size() > maxQueryLen) { + queryTokens = queryTokens.subList(0, maxQueryLen); + } + + List origTokens = Arrays.asList(context.trim().split("\\s+")); + List tokenToOrigIndex = new ArrayList<>(); + List allDocTokens = new ArrayList<>(); + for (int i = 0; i < origTokens.size(); i++) { + String token = origTokens.get(i); + List subTokens = tokenizer.tokenize(token); + for (String subToken : subTokens) { + tokenToOrigIndex.add(i); + allDocTokens.add(subToken); + } + } + +``` + +##### Run inference + +Inference is performed using the following in `QaClient` class: + +```java +tflite.runForMultipleInputsOutputs(inputs, output); +``` + +### Display results + +The QaClient is invoked and inference results are displayed by the +`presentAnswer()` function in +[`QaActivity.java`](/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/QaActivity.java). + +```java +private void presentAnswer(Answer answer) { + // Highlight answer. + Spannable spanText = new SpannableString(content); + int offset = content.indexOf(answer.text, 0); + if (offset >= 0) { + spanText.setSpan( + new BackgroundColorSpan(getColor(R.color.tfe_qa_color_highlight)), + offset, + offset + answer.text.length(), + Spannable.SPAN_EXCLUSIVE_EXCLUSIVE); + } + contentTextView.setText(spanText); + + // Use TTS to speak out the answer. + if (textToSpeech != null) { + textToSpeech.speak(answer.text, TextToSpeech.QUEUE_FLUSH, null, answer.text); + } + } +``` \ No newline at end of file diff --git a/lite/examples/bert_qa/android/README.md b/lite/examples/bert_qa/android/README.md index 271c7024ff5..8fe06715cd7 100644 --- a/lite/examples/bert_qa/android/README.md +++ b/lite/examples/bert_qa/android/README.md @@ -1,13 +1,29 @@ # TensorFlow Lite BERT QA Android Example Application + +Video + ## Overview -This is an end-to-end example of BERT Question & Answer application built with -TensorFlow 2.0, and tested on SQuAD dataset. The demo app provides 48 passages -from the dataset for users to choose from, and gives 5 most possible answers -corresponding to the input passage and query. +This is an end-to-end example of [BERT] Question & Answer application built with +TensorFlow 2.0, and tested on [SQuAD] dataset version 1.1. The demo app provides +48 passages from the dataset for users to choose from, and gives 5 most possible +answers corresponding to the input passage and query. + +These instructions walk you through running the demo on an Android device. For an explanation of the source, see +[TensorFlow Lite BERT QA Android example](EXPLORE_THE_CODE.md). + +### Model used + +[BERT], or Bidirectional Encoder Representations from Transformers, is a method +of pre-training language representations which obtains state-of-the-art results +on a wide array of Natural Language Processing tasks. + +This app uses [MobileBERT], a compressed version of [BERT] that runs 4x faster and +has 4x smaller model size. + +For more information, refer to the [BERT github page][BERT]. -These instructions walk you through running the demo on an Android device. ## Build the demo using Android Studio @@ -19,10 +35,10 @@ These instructions walk you through running the demo on an Android device. * Android Studio 3.2 or later. - Gradle 4.6 or higher. - - SDK Build Tools 28.0.3 or higher. + - SDK Build Tools 29.0.2 or higher. * You need an Android device or Android emulator and Android development - environment with minimum API 15. + environment with minimum API 21. ### Building @@ -41,7 +57,7 @@ These instructions walk you through running the demo on an Android device. ### Running * You need to have an Android device plugged in with developer options enabled - at this point. See [here](https://developer.android.com/studio/run/device) + at this point. See [here](https://developer.android.com/studio/run/device "Download Link") for more details on setting up developer devices. * If you already have Android emulator installed in Android Studio, select a @@ -49,6 +65,38 @@ These instructions walk you through running the demo on an Android device. * Click `Run` to run the demo app on your Android device. +#### Switch between inference solutions (Task library vs TFLite Interpreter) + +This BERT QA Android reference app demonstrates two implementation +solutions: + +(1) +[`lib_task_api`](/lite/examples/bert_qa/android/lib_task_api) +that leverages the out-of-box API from the +[TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer); + +(2) +[`lib_interpreter`](/lite/examples/bert_qa/android/lib_interpreter) +that creates the custom inference pipleline using the +[TensorFlow Lite Interpreter Java API](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_java). + +The [`build.gradle`](app/build.gradle) inside `app` folder shows how to change +`flavorDimensions "tfliteInference"` to switch between the two solutions. + +Inside **Android Studio**, you can change the build variant to whichever one you +want to build and run—just go to `Build > Select Build Variant` and select one +from the drop-down menu. See +[configure product flavors in Android Studio](https://developer.android.com/studio/build/build-variants#product-flavors) +for more details. + +For gradle CLI, running `./gradlew build` can create APKs for both solutions +under `app/build/outputs/apk`. + +*Note: If you simply want the out-of-box API to run the app, we recommend +`lib_task_api`for inference. If you want to customize your own models and +control the detail of inputs and outputs, it might be easier to adapt your model +inputs and outputs by using `lib_interpreter`.* + ## Build the demo using gradle (command line) ### Building and Installing @@ -66,3 +114,13 @@ cd lite/examples/bert_qa/android # Folder for Android app. ``` adb install app/build/outputs/apk/debug/app-debug.apk ``` + +## Assets folder + +_Do not delete the assets folder content_. If you explicitly deleted the files, +choose `Build -> Rebuild` to re-download the deleted model files into the assets +folder. + +[BERT]: https://github.com/google-research/bert "Bert" +[SQuAD]: https://rajpurkar.github.io/SQuAD-explorer/ "SQuAD" +[MobileBERT]:https://tfhub.dev/tensorflow/tfjs-model/mobilebert/1 "MobileBERT" diff --git a/lite/examples/bert_qa/android/app/build.gradle b/lite/examples/bert_qa/android/app/build.gradle index 81dda5a0d1e..1695318e544 100644 --- a/lite/examples/bert_qa/android/app/build.gradle +++ b/lite/examples/bert_qa/android/app/build.gradle @@ -2,10 +2,10 @@ apply plugin: 'com.android.application' android { compileSdkVersion 29 - buildToolsVersion "29.0.0" + buildToolsVersion "29.0.2" defaultConfig { applicationId "org.tensorflow.lite.examples.bertapp" - minSdkVersion 26 + minSdkVersion 21 targetSdkVersion 29 versionCode 1 versionName "1.0" @@ -34,32 +34,33 @@ android { lintOptions { abortOnError false } -} -// App assets folder: -project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets/' + flavorDimensions "tfliteInference" + productFlavors { + // The TFLite inference is built using the TFLite Java interpreter. + interpreter { + dimension "tfliteInference" + } + // Default: The TFLite inference is built using the TFLite Task library (high-level API). + taskApi { + getIsDefault().set(true) + dimension "tfliteInference" + } + } +} -// Download TF Lite model. -apply from: 'download.gradle' dependencies { - implementation fileTree(dir: 'libs', include: ['*.jar']) + interpreterImplementation project(":lib_interpreter") + taskApiImplementation project(":lib_task_api") implementation 'androidx.appcompat:appcompat:1.1.0' implementation 'androidx.constraintlayout:constraintlayout:1.1.3' implementation 'androidx.coordinatorlayout:coordinatorlayout:1.1.0' implementation 'androidx.recyclerview:recyclerview:1.1.0' implementation 'com.google.android.material:material:1.0.0' - implementation 'com.google.code.gson:gson:2.8.5' - implementation 'com.google.guava:guava:28.1-android' - implementation ('org.tensorflow:tensorflow-lite:2.4.0') - implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0' - testImplementation 'junit:junit:4.12' - testImplementation 'androidx.test:core:1.2.0' - testImplementation 'com.google.truth:truth:1.0' - testImplementation 'org.robolectric:robolectric:4.3.1' androidTestImplementation 'androidx.test:runner:1.2.0' androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0' } diff --git a/lite/examples/bert_qa/android/app/src/main/AndroidManifest.xml b/lite/examples/bert_qa/android/app/src/main/AndroidManifest.xml index 68b68d37d39..c32d769dbb2 100644 --- a/lite/examples/bert_qa/android/app/src/main/AndroidManifest.xml +++ b/lite/examples/bert_qa/android/app/src/main/AndroidManifest.xml @@ -28,16 +28,16 @@ tools:ignore="GoogleAppIndexingWarning"> + android:parentActivityName=".DatasetListActivity"> + android:value=".DatasetListActivity" /> diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/DatasetListActivity.java b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/DatasetListActivity.java similarity index 60% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/DatasetListActivity.java rename to lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/DatasetListActivity.java index 80441f87f02..c753e985692 100644 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/DatasetListActivity.java +++ b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/DatasetListActivity.java @@ -12,14 +12,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -package org.tensorflow.lite.examples.bertqa.ui; +package org.tensorflow.lite.examples.bertqa; import android.os.Bundle; -import androidx.appcompat.app.AppCompatActivity; import android.widget.ArrayAdapter; import android.widget.ListView; -import org.tensorflow.lite.examples.bertqa.R; -import org.tensorflow.lite.examples.bertqa.ml.LoadDatasetClient; + +import androidx.appcompat.app.AppCompatActivity; /** * An activity representing a list of Datasets. This activity has different presentations for @@ -29,23 +28,23 @@ */ public class DatasetListActivity extends AppCompatActivity { - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.tfe_qa_activity_dataset_list); - - ListView listView = findViewById(R.id.dataset_list); - assert listView != null; - - LoadDatasetClient datasetClient = new LoadDatasetClient(this); - ArrayAdapter datasetAdapter = - new ArrayAdapter<>( - this, android.R.layout.simple_selectable_list_item, datasetClient.getTitles()); - listView.setAdapter(datasetAdapter); - - listView.setOnItemClickListener( - (parent, view, position, id) -> { - startActivity(QaActivity.newInstance(this, position)); - }); - } + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.tfe_qa_activity_dataset_list); + + ListView listView = findViewById(R.id.dataset_list); + assert listView != null; + + LoadDatasetClient datasetClient = new LoadDatasetClient(this); + ArrayAdapter datasetAdapter = + new ArrayAdapter<>( + this, android.R.layout.simple_selectable_list_item, datasetClient.getTitles()); + listView.setAdapter(datasetAdapter); + + listView.setOnItemClickListener( + (parent, view, position, id) -> { + startActivity(QaActivity.newInstance(this, position)); + }); + } } diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/LoadDatasetClient.java b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/LoadDatasetClient.java new file mode 100644 index 00000000000..919951bda67 --- /dev/null +++ b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/LoadDatasetClient.java @@ -0,0 +1,90 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.examples.bertqa; + +import android.content.Context; +import android.util.Log; + +import com.google.gson.Gson; +import com.google.gson.stream.JsonReader; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.HashMap; +import java.util.List; + +/** + * Interface to load squad dataset. Provide passages for users to choose from & provide questions + * for autoCompleteTextView. + */ +public class LoadDatasetClient { + private static final String TAG = "BertAppDemo"; + private static final String JSON_DIR = "qa.json"; + private final Context context; + + private String[] contents; + private String[] titles; + private String[][] questions; + + public LoadDatasetClient(Context context) { + this.context = context; + loadJson(); + } + + private void loadJson() { + try { + InputStream is = context.getAssets().open(JSON_DIR); + JsonReader reader = new JsonReader(new InputStreamReader(is)); + HashMap>> map = new Gson().fromJson(reader, HashMap.class); + List> jsonTitles = map.get("titles"); + List> jsonContents = map.get("contents"); + List> jsonQuestions = map.get("questions"); + + titles = listToArray(jsonTitles); + contents = listToArray(jsonContents); + + questions = new String[jsonQuestions.size()][]; + int index = 0; + for (List item : jsonQuestions) { + questions[index++] = item.toArray(new String[item.size()]); + } + } catch (IOException ex) { + Log.e(TAG, ex.toString()); + } + } + + private static String[] listToArray(List> list) { + String[] answer = new String[list.size()]; + int index = 0; + for (List item : list) { + answer[index++] = item.get(0); + } + return answer; + } + + public String[] getTitles() { + return titles; + } + + public String getContent(int index) { + return contents[index]; + } + + public String[] getQuestions(int index) { + return questions[index]; + } + +} diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/QaActivity.java b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/QaActivity.java new file mode 100644 index 00000000000..ee394a9a5a1 --- /dev/null +++ b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/QaActivity.java @@ -0,0 +1,268 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.examples.bertqa; + +import android.content.Context; +import android.content.Intent; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.speech.tts.TextToSpeech; +import android.text.Editable; +import android.text.Spannable; +import android.text.SpannableString; +import android.text.TextWatcher; +import android.text.method.ScrollingMovementMethod; +import android.text.style.BackgroundColorSpan; +import android.util.Log; +import android.view.KeyEvent; +import android.view.View; +import android.view.inputmethod.InputMethodManager; +import android.widget.ImageButton; +import android.widget.TextView; + +import androidx.appcompat.app.AppCompatActivity; +import androidx.recyclerview.widget.LinearLayoutManager; +import androidx.recyclerview.widget.RecyclerView; + +import com.google.android.material.snackbar.Snackbar; +import com.google.android.material.textfield.TextInputEditText; + +import org.tensorflow.lite.examples.bertqa.ml.Answer; +import org.tensorflow.lite.examples.bertqa.ml.QaClient; + +import java.util.List; +import java.util.Locale; + +/** + * Activity for doing Q&A on a specific dataset + */ +public class QaActivity extends AppCompatActivity { + + private static final String DATASET_POSITION_KEY = "DATASET_POSITION"; + private static final String TAG = "QaActivity"; + private static final boolean DISPLAY_RUNNING_TIME = false; + + private TextInputEditText questionEditText; + private TextView contentTextView; + private TextToSpeech textToSpeech; + + private boolean questionAnswered = false; + private String content; + private Handler handler; + private QaClient qaClient; + + public static Intent newInstance(Context context, int datasetPosition) { + Intent intent = new Intent(context, QaActivity.class); + intent.putExtra(DATASET_POSITION_KEY, datasetPosition); + return intent; + } + + @Override + protected void onCreate(Bundle savedInstanceState) { + Log.v(TAG, "onCreate"); + super.onCreate(savedInstanceState); + setContentView(R.layout.tfe_qa_activity_qa); + + // Get content of the selected dataset. + int datasetPosition = getIntent().getIntExtra(DATASET_POSITION_KEY, -1); + LoadDatasetClient datasetClient = new LoadDatasetClient(this); + + // Show the dataset title. + TextView titleText = findViewById(R.id.title_text); + titleText.setText(datasetClient.getTitles()[datasetPosition]); + + // Show the text content of the selected dataset. + content = datasetClient.getContent(datasetPosition); + contentTextView = findViewById(R.id.content_text); + contentTextView.setText(content); + contentTextView.setMovementMethod(new ScrollingMovementMethod()); + + // Setup question suggestion list. + RecyclerView questionSuggestionsView = findViewById(R.id.suggestion_list); + QuestionAdapter adapter = + new QuestionAdapter(this, datasetClient.getQuestions(datasetPosition)); + adapter.setOnQuestionSelectListener(question -> answerQuestion(question)); + questionSuggestionsView.setAdapter(adapter); + LinearLayoutManager layoutManager = + new LinearLayoutManager(this, LinearLayoutManager.HORIZONTAL, false); + questionSuggestionsView.setLayoutManager(layoutManager); + + // Setup ask button. + ImageButton askButton = findViewById(R.id.ask_button); + askButton.setOnClickListener( + view -> answerQuestion(questionEditText.getText().toString())); + + // Setup text edit where users can input their question. + questionEditText = findViewById(R.id.question_edit_text); + questionEditText.setOnFocusChangeListener( + (view, hasFocus) -> { + // If we already answer current question, clear the question so that user can input a new + // one. + if (hasFocus && questionAnswered) { + questionEditText.setText(null); + } + }); + questionEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence charSequence, int i, int i1, int i2) { + } + + @Override + public void onTextChanged(CharSequence charSequence, int i, int i1, int i2) { + // Only allow clicking Ask button if there is a question. + boolean shouldAskButtonActive = !charSequence.toString().isEmpty(); + askButton.setClickable(shouldAskButtonActive); + askButton.setImageResource( + shouldAskButtonActive ? R.drawable.ic_ask_active : R.drawable.ic_ask_inactive); + } + + @Override + public void afterTextChanged(Editable editable) { + } + }); + questionEditText.setOnKeyListener( + (v, keyCode, event) -> { + if (event.getAction() == KeyEvent.ACTION_UP && keyCode == KeyEvent.KEYCODE_ENTER) { + answerQuestion(questionEditText.getText().toString()); + } + return false; + }); + + // Setup QA client to and background thread to run inference. + HandlerThread handlerThread = new HandlerThread("QAClient"); + handlerThread.start(); + handler = new Handler(handlerThread.getLooper()); + qaClient = new QaClient(this); + } + + @Override + protected void onStart() { + Log.v(TAG, "onStart"); + super.onStart(); + handler.post( + () -> { + qaClient.loadModel(); + }); + + textToSpeech = + new TextToSpeech( + this, + status -> { + if (status == TextToSpeech.SUCCESS) { + textToSpeech.setLanguage(Locale.US); + } else { + textToSpeech = null; + } + }); + } + + @Override + protected void onStop() { + Log.v(TAG, "onStop"); + super.onStop(); + handler.post(() -> qaClient.unload()); + + if (textToSpeech != null) { + textToSpeech.stop(); + textToSpeech.shutdown(); + } + } + + private void answerQuestion(String question) { + question = question.trim(); + if (question.isEmpty()) { + questionEditText.setText(question); + return; + } + + // Append question mark '?' if not ended with '?'. + // This aligns with question format that trains the model. + if (!question.endsWith("?")) { + question += '?'; + } + final String questionToAsk = question; + questionEditText.setText(questionToAsk); + + // Delete all pending tasks. + handler.removeCallbacksAndMessages(null); + + // Hide keyboard and dismiss focus on text edit. + InputMethodManager imm = + (InputMethodManager) getSystemService(AppCompatActivity.INPUT_METHOD_SERVICE); + imm.hideSoftInputFromWindow(getWindow().getDecorView().getWindowToken(), 0); + View focusView = getCurrentFocus(); + if (focusView != null) { + focusView.clearFocus(); + } + + // Reset content text view + contentTextView.setText(content); + + questionAnswered = false; + + // Start showing Looking up snackbar + Snackbar runningSnackbar = + Snackbar.make(contentTextView, "Looking up answer...", Snackbar.LENGTH_INDEFINITE); + runningSnackbar.show(); + + // Run TF Lite model to get the answer. + handler.post( + () -> { + long beforeTime = System.currentTimeMillis(); + final List answers = qaClient.predict(questionToAsk, content); + long afterTime = System.currentTimeMillis(); + double totalSeconds = (afterTime - beforeTime) / 1000.0; + + if (!answers.isEmpty()) { + // Get the top answer + Answer topAnswer = answers.get(0); + // Dismiss the snackbar and show the answer. + runOnUiThread( + () -> { + runningSnackbar.dismiss(); + presentAnswer(topAnswer); + + String displayMessage = "Top answer was successfully highlighted."; + if (DISPLAY_RUNNING_TIME) { + displayMessage = String.format("%s %.3fs.", displayMessage, totalSeconds); + } + Snackbar.make(contentTextView, displayMessage, Snackbar.LENGTH_LONG).show(); + questionAnswered = true; + }); + } + }); + } + + private void presentAnswer(Answer answer) { + // Highlight answer. + Spannable spanText = new SpannableString(content); + int offset = content.indexOf(answer.text, 0); + if (offset >= 0) { + spanText.setSpan( + new BackgroundColorSpan(getColor(R.color.tfe_qa_color_highlight)), + offset, + offset + answer.text.length(), + Spannable.SPAN_EXCLUSIVE_EXCLUSIVE); + } + contentTextView.setText(spanText); + + // Use TTS to speak out the answer. + if (textToSpeech != null) { + textToSpeech.speak(answer.text, TextToSpeech.QUEUE_FLUSH, null, answer.text); + } + } +} diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/QuestionAdapter.java b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/QuestionAdapter.java new file mode 100644 index 00000000000..e53ea9f7168 --- /dev/null +++ b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/QuestionAdapter.java @@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.examples.bertqa; + +import android.content.Context; + +import androidx.recyclerview.widget.RecyclerView; + +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; + +import androidx.annotation.NonNull; + +import com.google.android.material.chip.Chip; + +/** + * Adapter class to show question suggestion chips. + */ +public class QuestionAdapter extends RecyclerView.Adapter { + + private LayoutInflater inflater; + private String[] questions; + private OnQuestionSelectListener onQuestionSelectListener; + + public QuestionAdapter(Context context, String[] questions) { + inflater = LayoutInflater.from(context); + this.questions = questions; + } + + @Override + public QuestionAdapter.MyViewHolder onCreateViewHolder(@NonNull ViewGroup parent, int viewType) { + + View view = inflater.inflate(R.layout.tfe_qa_question_chip, parent, false); + MyViewHolder holder = new MyViewHolder(view); + + return holder; + } + + @Override + public void onBindViewHolder(QuestionAdapter.MyViewHolder holder, int position) { + holder.chip.setText(questions[position]); + holder.chip.setOnClickListener( + view -> onQuestionSelectListener.onQuestionSelect(questions[position])); + } + + @Override + public int getItemCount() { + return questions.length; + } + + public void setOnQuestionSelectListener(OnQuestionSelectListener onQuestionSelectListener) { + this.onQuestionSelectListener = onQuestionSelectListener; + } + + class MyViewHolder extends RecyclerView.ViewHolder { + + Chip chip; + + public MyViewHolder(View itemView) { + super(itemView); + chip = itemView.findViewById(R.id.chip); + } + } + + /** + * Interface for callback when a question is selected. + */ + public interface OnQuestionSelectListener { + void onQuestionSelect(String question); + } +} diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaAnswer.java b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaAnswer.java deleted file mode 100644 index 077527e0fde..00000000000 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaAnswer.java +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow.lite.examples.bertqa.ml; - -/** QA Answer class. */ -public class QaAnswer { - public Pos pos; - public String text; - - public QaAnswer(String text, Pos pos) { - this.text = text; - this.pos = pos; - } - - public QaAnswer(String text, int start, int end, float logit) { - this(text, new Pos(start, end, logit)); - } - - /** Position and related information from the model. */ - public static class Pos implements Comparable { - public int start; - public int end; - public float logit; - - public Pos(int start, int end, float logit) { - this.start = start; - this.end = end; - this.logit = logit; - } - - @Override - public int compareTo(Pos other) { - return Float.compare(other.logit, this.logit); - } - } -} diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/QaActivity.java b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/QaActivity.java deleted file mode 100644 index 09f251b7c30..00000000000 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/QaActivity.java +++ /dev/null @@ -1,261 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow.lite.examples.bertqa.ui; - -import android.content.Context; -import android.content.Intent; -import android.os.Bundle; -import android.os.Handler; -import android.os.HandlerThread; -import android.speech.tts.TextToSpeech; -import androidx.appcompat.app.AppCompatActivity; -import androidx.recyclerview.widget.LinearLayoutManager; -import androidx.recyclerview.widget.RecyclerView; -import android.text.Editable; -import android.text.Spannable; -import android.text.SpannableString; -import android.text.TextWatcher; -import android.text.method.ScrollingMovementMethod; -import android.text.style.BackgroundColorSpan; -import android.util.Log; -import android.view.KeyEvent; -import android.view.View; -import android.view.inputmethod.InputMethodManager; -import android.widget.ImageButton; -import android.widget.TextView; -import com.google.android.material.snackbar.Snackbar; -import com.google.android.material.textfield.TextInputEditText; -import java.util.List; -import java.util.Locale; -import org.tensorflow.lite.examples.bertqa.R; -import org.tensorflow.lite.examples.bertqa.ml.LoadDatasetClient; -import org.tensorflow.lite.examples.bertqa.ml.QaAnswer; -import org.tensorflow.lite.examples.bertqa.ml.QaClient; - -/** Activity for doing Q&A on a specific dataset */ -public class QaActivity extends AppCompatActivity { - - private static final String DATASET_POSITION_KEY = "DATASET_POSITION"; - private static final String TAG = "QaActivity"; - private static final boolean DISPLAY_RUNNING_TIME = false; - - private TextInputEditText questionEditText; - private TextView contentTextView; - private TextToSpeech textToSpeech; - - private boolean questionAnswered = false; - private String content; - private Handler handler; - private QaClient qaClient; - - public static Intent newInstance(Context context, int datasetPosition) { - Intent intent = new Intent(context, QaActivity.class); - intent.putExtra(DATASET_POSITION_KEY, datasetPosition); - return intent; - } - - @Override - protected void onCreate(Bundle savedInstanceState) { - Log.v(TAG, "onCreate"); - super.onCreate(savedInstanceState); - setContentView(R.layout.tfe_qa_activity_qa); - - // Get content of the selected dataset. - int datasetPosition = getIntent().getIntExtra(DATASET_POSITION_KEY, -1); - LoadDatasetClient datasetClient = new LoadDatasetClient(this); - - // Show the dataset title. - TextView titleText = findViewById(R.id.title_text); - titleText.setText(datasetClient.getTitles()[datasetPosition]); - - // Show the text content of the selected dataset. - content = datasetClient.getContent(datasetPosition); - contentTextView = findViewById(R.id.content_text); - contentTextView.setText(content); - contentTextView.setMovementMethod(new ScrollingMovementMethod()); - - // Setup question suggestion list. - RecyclerView questionSuggestionsView = findViewById(R.id.suggestion_list); - QuestionAdapter adapter = - new QuestionAdapter(this, datasetClient.getQuestions(datasetPosition)); - adapter.setOnQuestionSelectListener(question -> answerQuestion(question)); - questionSuggestionsView.setAdapter(adapter); - LinearLayoutManager layoutManager = - new LinearLayoutManager(this, LinearLayoutManager.HORIZONTAL, false); - questionSuggestionsView.setLayoutManager(layoutManager); - - // Setup ask button. - ImageButton askButton = findViewById(R.id.ask_button); - askButton.setOnClickListener( - view -> answerQuestion(questionEditText.getText().toString())); - - // Setup text edit where users can input their question. - questionEditText = findViewById(R.id.question_edit_text); - questionEditText.setOnFocusChangeListener( - (view, hasFocus) -> { - // If we already answer current question, clear the question so that user can input a new - // one. - if (hasFocus && questionAnswered) { - questionEditText.setText(null); - } - }); - questionEditText.addTextChangedListener( - new TextWatcher() { - @Override - public void beforeTextChanged(CharSequence charSequence, int i, int i1, int i2) {} - - @Override - public void onTextChanged(CharSequence charSequence, int i, int i1, int i2) { - // Only allow clicking Ask button if there is a question. - boolean shouldAskButtonActive = !charSequence.toString().isEmpty(); - askButton.setClickable(shouldAskButtonActive); - askButton.setImageResource( - shouldAskButtonActive ? R.drawable.ic_ask_active : R.drawable.ic_ask_inactive); - } - - @Override - public void afterTextChanged(Editable editable) {} - }); - questionEditText.setOnKeyListener( - (v, keyCode, event) -> { - if (event.getAction() == KeyEvent.ACTION_UP && keyCode == KeyEvent.KEYCODE_ENTER) { - answerQuestion(questionEditText.getText().toString()); - } - return false; - }); - - // Setup QA client to and background thread to run inference. - HandlerThread handlerThread = new HandlerThread("QAClient"); - handlerThread.start(); - handler = new Handler(handlerThread.getLooper()); - qaClient = new QaClient(this); - } - - @Override - protected void onStart() { - Log.v(TAG, "onStart"); - super.onStart(); - handler.post( - () -> { - qaClient.loadModel(); - }); - - textToSpeech = - new TextToSpeech( - this, - status -> { - if (status == TextToSpeech.SUCCESS) { - textToSpeech.setLanguage(Locale.US); - } else { - textToSpeech = null; - } - }); - } - - @Override - protected void onStop() { - Log.v(TAG, "onStop"); - super.onStop(); - handler.post(() -> qaClient.unload()); - - if (textToSpeech != null) { - textToSpeech.stop(); - textToSpeech.shutdown(); - } - } - - private void answerQuestion(String question) { - question = question.trim(); - if (question.isEmpty()) { - questionEditText.setText(question); - return; - } - - // Append question mark '?' if not ended with '?'. - // This aligns with question format that trains the model. - if (!question.endsWith("?")) { - question += '?'; - } - final String questionToAsk = question; - questionEditText.setText(questionToAsk); - - // Delete all pending tasks. - handler.removeCallbacksAndMessages(null); - - // Hide keyboard and dismiss focus on text edit. - InputMethodManager imm = - (InputMethodManager) getSystemService(AppCompatActivity.INPUT_METHOD_SERVICE); - imm.hideSoftInputFromWindow(getWindow().getDecorView().getWindowToken(), 0); - View focusView = getCurrentFocus(); - if (focusView != null) { - focusView.clearFocus(); - } - - // Reset content text view - contentTextView.setText(content); - - questionAnswered = false; - - Snackbar runningSnackbar = - Snackbar.make(contentTextView, "Looking up answer...", Integer.MAX_VALUE); - runningSnackbar.show(); - - // Run TF Lite model to get the answer. - handler.post( - () -> { - long beforeTime = System.currentTimeMillis(); - final List answers = qaClient.predict(questionToAsk, content); - long afterTime = System.currentTimeMillis(); - double totalSeconds = (afterTime - beforeTime) / 1000.0; - - if (!answers.isEmpty()) { - // Get the top answer - QaAnswer topAnswer = answers.get(0); - // Show the answer. - runOnUiThread( - () -> { - runningSnackbar.dismiss(); - presentAnswer(topAnswer); - - String displayMessage = "Top answer was successfully highlighted."; - if (DISPLAY_RUNNING_TIME) { - displayMessage = String.format("%s %.3fs.", displayMessage, totalSeconds); - } - Snackbar.make(contentTextView, displayMessage, Snackbar.LENGTH_LONG).show(); - questionAnswered = true; - }); - } - }); - } - - private void presentAnswer(QaAnswer answer) { - // Highlight answer. - Spannable spanText = new SpannableString(content); - int offset = content.indexOf(answer.text, 0); - if (offset >= 0) { - spanText.setSpan( - new BackgroundColorSpan(getColor(R.color.tfe_qa_color_highlight)), - offset, - offset + answer.text.length(), - Spannable.SPAN_EXCLUSIVE_EXCLUSIVE); - } - contentTextView.setText(spanText); - - // Use TTS to speak out the answer. - if (textToSpeech != null) { - textToSpeech.speak(answer.text, TextToSpeech.QUEUE_FLUSH, null, answer.text); - } - } -} diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/QuestionAdapter.java b/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/QuestionAdapter.java deleted file mode 100644 index c4f69ac22d3..00000000000 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ui/QuestionAdapter.java +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow.lite.examples.bertqa.ui; - -import android.content.Context; -import androidx.recyclerview.widget.RecyclerView; -import android.view.LayoutInflater; -import android.view.View; -import android.view.ViewGroup; -import androidx.annotation.NonNull; -import com.google.android.material.chip.Chip; -import org.tensorflow.lite.examples.bertqa.R; - -/** Adapter class to show question suggestion chips. */ -public class QuestionAdapter extends RecyclerView.Adapter { - - private LayoutInflater inflater; - private String[] questions; - private OnQuestionSelectListener onQuestionSelectListener; - - public QuestionAdapter(Context context, String[] questions) { - inflater = LayoutInflater.from(context); - this.questions = questions; - } - - @Override - public QuestionAdapter.MyViewHolder onCreateViewHolder(@NonNull ViewGroup parent, int viewType) { - - View view = inflater.inflate(R.layout.tfe_qa_question_chip, parent, false); - MyViewHolder holder = new MyViewHolder(view); - - return holder; - } - - @Override - public void onBindViewHolder(QuestionAdapter.MyViewHolder holder, int position) { - holder.chip.setText(questions[position]); - holder.chip.setOnClickListener( - view -> onQuestionSelectListener.onQuestionSelect(questions[position])); - } - - @Override - public int getItemCount() { - return questions.length; - } - - public void setOnQuestionSelectListener(OnQuestionSelectListener onQuestionSelectListener) { - this.onQuestionSelectListener = onQuestionSelectListener; - } - - class MyViewHolder extends RecyclerView.ViewHolder { - - Chip chip; - - public MyViewHolder(View itemView) { - super(itemView); - chip = itemView.findViewById(R.id.chip); - } - } - - /** Interface for callback when a question is selected. */ - public interface OnQuestionSelectListener { - void onQuestionSelect(String question); - } -} diff --git a/lite/examples/bert_qa/android/app/src/main/res/layout/tfe_qa_activity_dataset_list.xml b/lite/examples/bert_qa/android/app/src/main/res/layout/tfe_qa_activity_dataset_list.xml index e188739e97d..e8415011186 100644 --- a/lite/examples/bert_qa/android/app/src/main/res/layout/tfe_qa_activity_dataset_list.xml +++ b/lite/examples/bert_qa/android/app/src/main/res/layout/tfe_qa_activity_dataset_list.xml @@ -56,7 +56,7 @@ app:layout_constraintEnd_toEndOf="parent" app:layout_constraintStart_toStartOf="parent" app:layout_constraintTop_toTopOf="parent" - tools:context="org.tensorflow.lite.examples.bertqa.ui.DatasetListActivity"> + tools:context="org.tensorflow.lite.examples.bertqa.DatasetListActivity"> + tools:context="org.tensorflow.lite.examples.bertqa.QaActivity"> predict0 = client.predict("What is Tesla's home country?", content); + assertThat(getTexts(predict0)).contains("Serbian"); + List predict1 = client.predict("What was Nikola Tesla's ethnicity?", content); + assertThat(getTexts(predict1)).contains("Serbian"); + List predict2 = client.predict("What does AC stand for?", content); + assertThat(getTexts(predict2)).contains("alternating current"); + List predict3 = client.predict("When was Tesla born?", content); + assertThat(getTexts(predict3)).contains("10 July 1856"); + } + + private static List getTexts(List answers) { + List texts = new ArrayList<>(); + for (Answer ans : answers) { + texts.add(ans.text); + } + return texts; + } +} diff --git a/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizerTest.java b/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizerTest.java new file mode 100644 index 00000000000..8648a919562 --- /dev/null +++ b/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizerTest.java @@ -0,0 +1,95 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.examples.bertqa.tokenization; + +import static com.google.common.truth.Truth.assertThat; + +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Tests of {@link BasicTokenizer} */ +@RunWith(AndroidJUnit4.class) +public final class BasicTokenizerTest { + @Test + public void cleanTextTest() throws Exception { + String testExample = "This is an\rexample.\n"; + char testChar = 0; + char unknownChar = 0xfffd; + assertThat(BasicTokenizer.cleanText(testExample)).isEqualTo("This is an example. "); + assertThat(BasicTokenizer.cleanText(testExample + testChar)).isEqualTo("This is an example. "); + assertThat(BasicTokenizer.cleanText(testExample + unknownChar)) + .isEqualTo("This is an example. "); + + String nullString = null; + Assert.assertThrows(NullPointerException.class, () -> BasicTokenizer.cleanText(nullString)); + } + + @Test + public void whitespaceTokenizeTest() throws Exception { + assertThat(BasicTokenizer.whitespaceTokenize("Hi , This is an example. ")) + .containsExactly("Hi", ",", "This", "is", "an", "example.") + .inOrder(); + assertThat(BasicTokenizer.whitespaceTokenize(" ")).isEmpty(); + + String nullString = null; + Assert.assertThrows( + NullPointerException.class, () -> BasicTokenizer.whitespaceTokenize(nullString)); + } + + @Test + public void runSplitOnPuncTest() throws Exception { + assertThat(BasicTokenizer.runSplitOnPunc("Hi,there.")) + .containsExactly("Hi", ",", "there", ".") + .inOrder(); + assertThat(BasicTokenizer.runSplitOnPunc("I'm \"Spider-Man\"")) // Input: I'm "Spider-Man" + .containsExactly("I", "\'", "m ", "\"", "Spider", "-", "Man", "\"") + .inOrder(); + + String nullString = null; + Assert.assertThrows( + NullPointerException.class, () -> BasicTokenizer.runSplitOnPunc(nullString)); + } + + @Test + public void tokenizeWithLowerCaseTest() throws Exception { + BasicTokenizer tokenizer = new BasicTokenizer(/*doLowerCase=*/ true); + assertThat(tokenizer.tokenize(" Hi, This\tis an example.\n")) + .containsExactly("hi", ",", "this", "is", "an", "example", ".") + .inOrder(); + assertThat(tokenizer.tokenize("Hello,How are you?")) + .containsExactly("hello", ",", "how", "are", "you", "?") + .inOrder(); + + String nullString = null; + Assert.assertThrows(NullPointerException.class, () -> tokenizer.tokenize(nullString)); + } + + @Test + public void tokenizeTest() throws Exception { + BasicTokenizer tokenizer = new BasicTokenizer(/*doLowerCase=*/ false); + assertThat(tokenizer.tokenize(" Hi, This\tis an example.\n")) + .containsExactly("Hi", ",", "This", "is", "an", "example", ".") + .inOrder(); + assertThat(tokenizer.tokenize("Hello,How are you?")) + .containsExactly("Hello", ",", "How", "are", "you", "?") + .inOrder(); + + String nullString = null; + Assert.assertThrows(NullPointerException.class, () -> tokenizer.tokenize(nullString)); + } +} diff --git a/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/FullTokenizerTest.java b/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/FullTokenizerTest.java new file mode 100644 index 00000000000..ff0b5bb1ca6 --- /dev/null +++ b/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/FullTokenizerTest.java @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.examples.bertqa.tokenization; + +import static com.google.common.truth.Truth.assertThat; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.examples.bertqa.ml.ModelHelper; +import org.tensorflow.lite.support.metadata.MetadataExtractor; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +/** Tests of {@link FullTokenizer} */ +@RunWith(AndroidJUnit4.class) +public final class FullTokenizerTest { + private Map dic; + + @Before + public void setUp() throws IOException { + ByteBuffer buffer = ModelHelper.loadModelFile(ApplicationProvider.getApplicationContext()); + MetadataExtractor metadataExtractor = new MetadataExtractor(buffer); + dic = ModelHelper.extractDictionary(metadataExtractor); + assertThat(dic).isNotNull(); + assertThat(dic).isNotEmpty(); + } + + @Test + public void tokenizeTest() throws Exception { + FullTokenizer tokenizer = new FullTokenizer(dic, /* doLowerCase= */ true); + assertThat(tokenizer.tokenize("Good morning, I'm your teacher.\n")) + .containsExactly("good", "morning", ",", "i", "'", "m", "your", "teacher", ".") + .inOrder(); + assertThat(tokenizer.tokenize("")).isEmpty(); + + String nullString = null; + Assert.assertThrows(NullPointerException.class, () -> tokenizer.tokenize(nullString)); + } + + @Test + public void convertTokensToIdsTest() throws Exception { + FullTokenizer tokenizer = new FullTokenizer(dic, /* doLowerCase= */ true); + List testExample = + Arrays.asList("good", "morning", ",", "i", "'", "m", "your", "teacher", "."); + assertThat(tokenizer.convertTokensToIds(testExample)) + .containsExactly(2204, 2851, 1010, 1045, 1005, 1049, 2115, 3836, 1012) + .inOrder(); + } +} diff --git a/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/WordpieceTokenizerTest.java b/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/WordpieceTokenizerTest.java new file mode 100644 index 00000000000..c05ea116c9c --- /dev/null +++ b/lite/examples/bert_qa/android/lib_interpreter/src/androidTest/java/org/tensorflow/lite/examples/bertqa/tokenization/WordpieceTokenizerTest.java @@ -0,0 +1,56 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.examples.bertqa.tokenization; + +import static com.google.common.truth.Truth.assertThat; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.examples.bertqa.ml.ModelHelper; +import org.tensorflow.lite.support.metadata.MetadataExtractor; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Map; + +/** Tests of {@link WordpieceTokenizer} */ +@RunWith(AndroidJUnit4.class) +public final class WordpieceTokenizerTest { + private Map dic; + + @Before + public void setUp() throws IOException { + ByteBuffer buffer = ModelHelper.loadModelFile(ApplicationProvider.getApplicationContext()); + MetadataExtractor metadataExtractor = new MetadataExtractor(buffer); + dic = ModelHelper.extractDictionary(metadataExtractor); + assertThat(dic).isNotNull(); + assertThat(dic).isNotEmpty(); + } + + @Test + public void tokenizeTest() throws Exception { + WordpieceTokenizer tokenizer = new WordpieceTokenizer(dic); + assertThat(tokenizer.tokenize("meaningfully")).containsExactly("meaningful", "##ly").inOrder(); + assertThat(tokenizer.tokenize("teacher")).containsExactly("teacher").inOrder(); + + String nullString = null; + Assert.assertThrows(NullPointerException.class, () -> tokenizer.tokenize(nullString)); + } +} diff --git a/lite/examples/bert_qa/android/lib_interpreter/src/main/AndroidManifest.xml b/lite/examples/bert_qa/android/lib_interpreter/src/main/AndroidManifest.xml new file mode 100644 index 00000000000..22c3a8fc635 --- /dev/null +++ b/lite/examples/bert_qa/android/lib_interpreter/src/main/AndroidManifest.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Answer.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Answer.java new file mode 100644 index 00000000000..2d5f072276e --- /dev/null +++ b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Answer.java @@ -0,0 +1,39 @@ +package org.tensorflow.lite.examples.bertqa.ml; + +/** QA Answer class. */ +public class Answer implements Comparable{ + public Pos pos; + public String text; + + public Answer(String text, Pos pos) { + this.text = text; + this.pos = pos; + } + + public Answer(String text, int start, int end, float logit) { + this(text, new Pos(start, end, logit)); + } + + @Override + public int compareTo(Answer other) { + return Float.compare(other.pos.logit, this.pos.logit); + } + + /** Position and related information from the model. */ + public static class Pos implements Comparable { + public int start; + public int end; + public float logit; + + public Pos(int start, int end, float logit) { + this.start = start; + this.end = end; + this.logit = logit; + } + + @Override + public int compareTo(Pos other) { + return Float.compare(other.logit, this.logit); + } + } +} diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Feature.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Feature.java similarity index 99% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Feature.java rename to lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Feature.java index 2232eb5d3b1..c26880edc09 100644 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Feature.java +++ b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Feature.java @@ -15,6 +15,7 @@ package org.tensorflow.lite.examples.bertqa.ml; import com.google.common.primitives.Ints; + import java.util.List; import java.util.Map; diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/FeatureConverter.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/FeatureConverter.java similarity index 99% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/FeatureConverter.java rename to lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/FeatureConverter.java index 32fe4240358..f337b016315 100644 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/FeatureConverter.java +++ b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/FeatureConverter.java @@ -14,13 +14,14 @@ ==============================================================================*/ package org.tensorflow.lite.examples.bertqa.ml; +import org.tensorflow.lite.examples.bertqa.tokenization.FullTokenizer; + import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import org.tensorflow.lite.examples.bertqa.tokenization.FullTokenizer; /** Convert String to features that can be fed into BERT model. */ public final class FeatureConverter { diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/LoadDatasetClient.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/LoadDatasetClient.java similarity index 82% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/LoadDatasetClient.java rename to lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/LoadDatasetClient.java index 585052c8966..b0662dc0acd 100644 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/LoadDatasetClient.java +++ b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/LoadDatasetClient.java @@ -16,15 +16,15 @@ import android.content.Context; import android.util.Log; + import com.google.gson.Gson; import com.google.gson.stream.JsonReader; -import java.io.BufferedReader; + import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.util.HashMap; import java.util.List; -import java.util.Map; /** * Interface to load squad dataset. Provide passages for users to choose from & provide questions @@ -33,7 +33,6 @@ public class LoadDatasetClient { private static final String TAG = "BertAppDemo"; private static final String JSON_DIR = "qa.json"; - private static final String DIC_DIR = "vocab.txt"; private final Context context; private String[] contents; @@ -87,19 +86,4 @@ public String getContent(int index) { public String[] getQuestions(int index) { return questions[index]; } - - public Map loadDictionary() { - Map dic = new HashMap<>(); - try (InputStream ins = context.getAssets().open(DIC_DIR); - BufferedReader reader = new BufferedReader(new InputStreamReader(ins))) { - int index = 0; - while (reader.ready()) { - String key = reader.readLine(); - dic.put(key, index++); - } - } catch (IOException ex) { - Log.e(TAG, ex.getMessage()); - } - return dic; - } } diff --git a/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/ModelHelper.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/ModelHelper.java new file mode 100644 index 00000000000..8b448af9d3a --- /dev/null +++ b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/ModelHelper.java @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.examples.bertqa.ml; + +import static com.google.common.base.Verify.verify; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.content.res.AssetManager; +import android.util.Log; + +import org.tensorflow.lite.support.metadata.MetadataExtractor; + +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.HashMap; +import java.util.Map; + +/** Helper to load TfLite model and dictionary. */ +public class ModelHelper { + private static final String TAG = "BertDemo"; + public static final String MODEL_PATH = "model.tflite"; + public static final String DIC_PATH = "vocab.txt"; + + private ModelHelper() {} + + /** Load tflite model from context. */ + public static MappedByteBuffer loadModelFile(Context context) throws IOException { + return loadModelFile(context.getAssets()); + } + + /** Load tflite model from assets. */ + public static MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException { + try (AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + } + + /** Extract dictionary from metadata. */ + public static Map extractDictionary(MetadataExtractor metadataExtractor) { + Map dic = null; + try { + verify(metadataExtractor != null, "metadataExtractor can't be null."); + dic = loadDictionaryFile(metadataExtractor.getAssociatedFile(DIC_PATH)); + Log.v(TAG, "Dictionary loaded."); + } catch (IOException ex) { + Log.e(TAG, ex.getMessage()); + } + return dic; + } + + /** Load dictionary from assets. */ + public static Map loadDictionaryFile(InputStream inputStream) + throws IOException { + Map dic = new HashMap<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { + int index = 0; + while (reader.ready()) { + String key = reader.readLine(); + dic.put(key, index++); + } + } + return dic; + } +} diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java similarity index 78% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java rename to lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java index a0d951446fe..c53e26535bb 100644 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java +++ b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java @@ -17,33 +17,28 @@ import static com.google.common.base.Verify.verify; import android.content.Context; -import android.content.res.AssetFileDescriptor; -import android.content.res.AssetManager; import android.util.Log; + + import androidx.annotation.WorkerThread; + import com.google.common.base.Joiner; -import java.io.BufferedReader; -import java.io.FileInputStream; + +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.support.metadata.MetadataExtractor; +import org.tensorflow.lite.support.metadata.schema.TensorMetadata; + import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; import java.nio.ByteBuffer; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import org.tensorflow.lite.Interpreter; -import org.tensorflow.lite.support.metadata.MetadataExtractor; -import org.tensorflow.lite.support.metadata.schema.TensorMetadata; /** Interface to load TfLite model and provide predictions. */ public class QaClient implements AutoCloseable { private static final String TAG = "BertDemo"; - private static final String MODEL_PATH = "model.tflite"; - private static final String DIC_PATH = "vocab.txt"; private static final int MAX_ANS_LEN = 32; private static final int MAX_QUERY_LEN = 64; @@ -77,11 +72,14 @@ public QaClient(Context context) { @WorkerThread public synchronized void loadModel() { try { - ByteBuffer buffer = loadModelFile(this.context.getAssets()); + ByteBuffer buffer = ModelHelper.loadModelFile(context); + metadataExtractor = new MetadataExtractor(buffer); + Map loadedDic = ModelHelper.extractDictionary(metadataExtractor); + verify(loadedDic != null, "dic can't be null."); + dic.putAll(loadedDic); + Interpreter.Options opt = new Interpreter.Options(); opt.setNumThreads(NUM_LITE_THREADS); - metadataExtractor = new MetadataExtractor(buffer); - loadDictionary(); tflite = new Interpreter(buffer, opt); Log.v(TAG, "TFLite model loaded."); } catch (IOException ex) { @@ -89,17 +87,6 @@ public synchronized void loadModel() { } } - @WorkerThread - public synchronized void loadDictionary() { - try { - verify(metadataExtractor != null, "metadataExtractor can't be null."); - loadDictionaryFile(metadataExtractor.getAssociatedFile(DIC_PATH)); - Log.v(TAG, "Dictionary loaded."); - } catch (IOException ex) { - Log.e(TAG, ex.getMessage()); - } - } - @WorkerThread public synchronized void unload() { close(); @@ -114,27 +101,6 @@ public void close() { dic.clear(); } - /** Load tflite model from assets. */ - public MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException { - try (AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH); - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) { - FileChannel fileChannel = inputStream.getChannel(); - long startOffset = fileDescriptor.getStartOffset(); - long declaredLength = fileDescriptor.getDeclaredLength(); - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - } - } - - /** Load dictionary from assets. */ - public void loadDictionaryFile(InputStream inputStream) throws IOException { - try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) { - int index = 0; - while (reader.ready()) { - String key = reader.readLine(); - dic.put(key, index++); - } - } - } /** * Input: Original content and query for the QA task. Later converted to Feature by @@ -142,8 +108,8 @@ public void loadDictionaryFile(InputStream inputStream) throws IOException { * logits. */ @WorkerThread - public synchronized List predict(String query, String content) { - Log.v(TAG, "TFLite model: " + MODEL_PATH + " running..."); + public synchronized List predict(String query, String content) { + Log.v(TAG, "TFLite model: " + ModelHelper.MODEL_PATH + " running..."); Log.v(TAG, "Convert Feature..."); Feature feature = featureConverter.convert(query, content); @@ -230,19 +196,19 @@ public synchronized List predict(String query, String content) { tflite.runForMultipleInputsOutputs(inputs, output); Log.v(TAG, "Convert answers..."); - List answers = getBestAnswers(startLogits[0], endLogits[0], feature); + List answers = getBestAnswers(startLogits[0], endLogits[0], feature); Log.v(TAG, "Finish."); return answers; } /** Find the Best N answers & logits from the logits array and input feature. */ - private synchronized List getBestAnswers( + private synchronized List getBestAnswers( float[] startLogits, float[] endLogits, Feature feature) { // Model uses the closed interval [start, end] for indices. int[] startIndexes = getBestIndex(startLogits); int[] endIndexes = getBestIndex(endLogits); - List origResults = new ArrayList<>(); + List origResults = new ArrayList<>(); for (int start : startIndexes) { for (int end : endIndexes) { if (!feature.tokenToOrigMap.containsKey(start + OUTPUT_OFFSET)) { @@ -258,13 +224,13 @@ private synchronized List getBestAnswers( if (length > MAX_ANS_LEN) { continue; } - origResults.add(new QaAnswer.Pos(start, end, startLogits[start] + endLogits[end])); + origResults.add(new Answer.Pos(start, end, startLogits[start] + endLogits[end])); } } Collections.sort(origResults); - List answers = new ArrayList<>(); + List answers = new ArrayList<>(); for (int i = 0; i < origResults.size(); i++) { if (i >= PREDICT_ANS_NUM) { break; @@ -276,7 +242,7 @@ private synchronized List getBestAnswers( } else { convertedText = ""; } - QaAnswer ans = new QaAnswer(convertedText, origResults.get(i)); + Answer ans = new Answer(convertedText, origResults.get(i)); answers.add(ans); } return answers; @@ -285,9 +251,9 @@ private synchronized List getBestAnswers( /** Get the n-best logits from a list of all the logits. */ @WorkerThread private synchronized int[] getBestIndex(float[] logits) { - List tmpList = new ArrayList<>(); + List tmpList = new ArrayList<>(); for (int i = 0; i < MAX_SEQ_LEN; i++) { - tmpList.add(new QaAnswer.Pos(i, i, logits[i])); + tmpList.add(new Answer.Pos(i, i, logits[i])); } Collections.sort(tmpList); diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizer.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizer.java similarity index 99% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizer.java rename to lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizer.java index 991f0604fab..70b22de3fe7 100644 --- a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizer.java +++ b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/BasicTokenizer.java @@ -16,6 +16,7 @@ import com.google.common.base.Ascii; import com.google.common.collect.Iterables; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/CharChecker.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/CharChecker.java similarity index 100% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/CharChecker.java rename to lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/CharChecker.java diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/FullTokenizer.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/FullTokenizer.java similarity index 100% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/FullTokenizer.java rename to lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/FullTokenizer.java diff --git a/lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/WordpieceTokenizer.java b/lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/WordpieceTokenizer.java similarity index 100% rename from lite/examples/bert_qa/android/app/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/WordpieceTokenizer.java rename to lite/examples/bert_qa/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/bertqa/tokenization/WordpieceTokenizer.java diff --git a/lite/examples/bert_qa/android/lib_task_api/.gitignore b/lite/examples/bert_qa/android/lib_task_api/.gitignore new file mode 100644 index 00000000000..42afabfd2ab --- /dev/null +++ b/lite/examples/bert_qa/android/lib_task_api/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/lite/examples/bert_qa/android/lib_task_api/build.gradle b/lite/examples/bert_qa/android/lib_task_api/build.gradle new file mode 100644 index 00000000000..024004d37dc --- /dev/null +++ b/lite/examples/bert_qa/android/lib_task_api/build.gradle @@ -0,0 +1,52 @@ +plugins { + id 'com.android.library' +} + +android { + compileSdkVersion 30 + buildToolsVersion "30.0.3" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 30 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + consumerProguardFiles "consumer-rules.pro" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } + aaptOptions { + noCompress "tflite" + } + testOptions { + unitTests { + includeAndroidResources = true + } + } +} + +// App assets folder: +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets/' + +// Download TF Lite model. +apply from: 'download.gradle' + +dependencies { + implementation 'org.tensorflow:tensorflow-lite-task-text:0.2.0' + + androidTestImplementation 'junit:junit:4.13.2' + androidTestImplementation 'androidx.test.ext:junit:1.1.1' + androidTestImplementation 'com.google.truth:truth:1.0' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' +} \ No newline at end of file diff --git a/lite/examples/bert_qa/android/lib_task_api/consumer-rules.pro b/lite/examples/bert_qa/android/lib_task_api/consumer-rules.pro new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lite/examples/bert_qa/android/lib_task_api/download.gradle b/lite/examples/bert_qa/android/lib_task_api/download.gradle new file mode 100644 index 00000000000..5a6b670004f --- /dev/null +++ b/lite/examples/bert_qa/android/lib_task_api/download.gradle @@ -0,0 +1,15 @@ +apply plugin: 'de.undercouch.download' + +task downloadLiteModel { + def downloadFiles = [ + 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/bert_qa/contents_from_squad.json': 'qa.json', + 'https://tfhub.dev/tensorflow/lite-model/mobilebert/1/metadata/1?lite-format=tflite': 'model.tflite', + ] + downloadFiles.each { key, value -> + download { + src key + dest project.ext.ASSET_DIR + value + overwrite true + } + } +} \ No newline at end of file diff --git a/lite/examples/bert_qa/android/lib_task_api/proguard-rules.pro b/lite/examples/bert_qa/android/lib_task_api/proguard-rules.pro new file mode 100644 index 00000000000..481bb434814 --- /dev/null +++ b/lite/examples/bert_qa/android/lib_task_api/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/lite/examples/bert_qa/android/lib_task_api/src/androidTest/java/org/tensorflow/lite/examples/bertqa/QaClientTest.java b/lite/examples/bert_qa/android/lib_task_api/src/androidTest/java/org/tensorflow/lite/examples/bertqa/QaClientTest.java new file mode 100644 index 00000000000..c97518f2229 --- /dev/null +++ b/lite/examples/bert_qa/android/lib_task_api/src/androidTest/java/org/tensorflow/lite/examples/bertqa/QaClientTest.java @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite.examples.bertqa; + +import static com.google.common.truth.Truth.assertThat; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.examples.bertqa.ml.Answer; +import org.tensorflow.lite.examples.bertqa.ml.QaClient; + +import java.util.ArrayList; +import java.util.List; + +/** Tests of {@link QaClient} */ +@RunWith(AndroidJUnit4.class) +public final class QaClientTest { + private QaClient client; + + @Before + public void setUp() { + client = new QaClient(ApplicationProvider.getApplicationContext()); + client.loadModel(); + } + + @Test + public void testGoldSet() { + final String content = + "Nikola Tesla (Serbian Cyrillic: \u041d\u0438\u043a\u043e\u043b\u0430" + + " \u0422\u0435\u0441\u043b\u0430; 10 July 1856 \u2013 7 January 1943) was a Serbian" + + " American inventor, electrical engineer, mechanical engineer, physicist, and" + + " futurist best known for his contributions to the design of the modern alternating" + + " current (AC) electricity supply system."; + + List predict0 = client.predict("What is Tesla's home country?", content); + assertThat(getTexts(predict0)).contains("Serbian"); + List predict1 = client.predict("What was Nikola Tesla's ethnicity?", content); + assertThat(getTexts(predict1)).contains("Serbian"); + List predict2 = client.predict("What does AC stand for?", content); + assertThat(getTexts(predict2)).contains("alternating current"); + List predict3 = client.predict("When was Tesla born?", content); + assertThat(getTexts(predict3)).contains("10 July 1856"); + } + + private static List getTexts(List answers) { + List texts = new ArrayList<>(); + for (Answer ans : answers) { + texts.add(ans.text); + } + return texts; + } +} diff --git a/lite/examples/bert_qa/android/lib_task_api/src/main/AndroidManifest.xml b/lite/examples/bert_qa/android/lib_task_api/src/main/AndroidManifest.xml new file mode 100644 index 00000000000..d765e70b40c --- /dev/null +++ b/lite/examples/bert_qa/android/lib_task_api/src/main/AndroidManifest.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/lite/examples/bert_qa/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Answer.java b/lite/examples/bert_qa/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Answer.java new file mode 100644 index 00000000000..2d5f072276e --- /dev/null +++ b/lite/examples/bert_qa/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/bertqa/ml/Answer.java @@ -0,0 +1,39 @@ +package org.tensorflow.lite.examples.bertqa.ml; + +/** QA Answer class. */ +public class Answer implements Comparable{ + public Pos pos; + public String text; + + public Answer(String text, Pos pos) { + this.text = text; + this.pos = pos; + } + + public Answer(String text, int start, int end, float logit) { + this(text, new Pos(start, end, logit)); + } + + @Override + public int compareTo(Answer other) { + return Float.compare(other.pos.logit, this.pos.logit); + } + + /** Position and related information from the model. */ + public static class Pos implements Comparable { + public int start; + public int end; + public float logit; + + public Pos(int start, int end, float logit) { + this.start = start; + this.end = end; + this.logit = logit; + } + + @Override + public int compareTo(Pos other) { + return Float.compare(other.logit, this.logit); + } + } +} diff --git a/lite/examples/bert_qa/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java b/lite/examples/bert_qa/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java new file mode 100644 index 00000000000..67df0779fae --- /dev/null +++ b/lite/examples/bert_qa/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/bertqa/ml/QaClient.java @@ -0,0 +1,64 @@ +package org.tensorflow.lite.examples.bertqa.ml; + +import android.content.Context; +import android.util.Log; + +import org.tensorflow.lite.task.text.qa.BertQuestionAnswerer; +import org.tensorflow.lite.task.text.qa.QaAnswer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Load TFLite model and create BertQuestionAnswerer instance. + */ +public class QaClient { + + private static final String TAG = "TaskApi"; + private static final String MODEL_PATH = "model.tflite"; + + private final Context context; + + public BertQuestionAnswerer answerer; + + public QaClient(Context context) { + this.context = context; + } + + /** + * Load TF Lite model. + */ + public void loadModel() { + try { + answerer = BertQuestionAnswerer.createFromFile(context, MODEL_PATH); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + } + } + + /** + * Free up resources as the client is no longer needed. + */ + public void unload() { + answerer.close(); + answerer = null; + } + + + /** + * Run inference and predict the possible answers. + */ + public List predict(String questionToAsk, String contextOfTheQuestion) { + + List apiResult = answerer.answer(contextOfTheQuestion, questionToAsk); + List answers = new ArrayList<>(apiResult.size()); + for (int i = 0; i < apiResult.size(); i++){ + QaAnswer qaAnswer = apiResult.get(i); + answers.add(new Answer(qaAnswer.text, qaAnswer.pos.start, qaAnswer.pos.end, qaAnswer.pos.logit)); + } + Collections.sort(answers); + return answers; + } +} diff --git a/lite/examples/bert_qa/android/settings.gradle b/lite/examples/bert_qa/android/settings.gradle index 0745a61f07e..2cc90a3ad1b 100644 --- a/lite/examples/bert_qa/android/settings.gradle +++ b/lite/examples/bert_qa/android/settings.gradle @@ -1,2 +1,4 @@ rootProject.name = 'TFLite BERT Q&A Demo App' include ':app' +include ':lib_task_api' +include ':lib_interpreter'