From 7ca6ab1139c497c9a6cf4797425a8486ad00d9d1 Mon Sep 17 00:00:00 2001 From: Sunit Roy <2703iamsry@gmail.com> Date: Wed, 30 Jun 2021 18:26:12 +0530 Subject: [PATCH 1/4] Updated Text Classification Android to use Task Library --- .../text_classification/android/README.md | 2 +- .../android/app/build.gradle | 31 ++- .../download_model.gradle | 0 .../examples/textclassification/UnitTest.java | 46 ++++ .../textclassification/MainActivity.java | 175 ++++++++------- .../examples/textclassification/Result.java | 79 +++++++ .../TextClassificationClient.java | 77 +++++++ .../android/lib_interpreter/build.gradle | 54 ----- .../lib_interpreter/proguard-rules.pro | 21 -- .../src/main/AndroidManifest.xml | 4 - .../textclassification/client/Result.java | 73 ------- .../client/TextClassificationClient.java | 205 ------------------ .../textclassification/client/UnitTest.java | 87 -------- .../android/lib_task_api/build.gradle | 51 ----- .../lib_task_api/download_model.gradle | 7 - .../android/lib_task_api/proguard-rules.pro | 21 -- .../lib_task_api/src/main/AndroidManifest.xml | 4 - .../textclassification/client/Result.java | 73 ------- .../client/TextClassificationClient.java | 64 ------ .../textclassification/client/UnitTest.java | 58 ----- .../android/settings.gradle | 3 +- 21 files changed, 308 insertions(+), 827 deletions(-) rename lite/examples/text_classification/android/{lib_interpreter => app}/download_model.gradle (100%) create mode 100644 lite/examples/text_classification/android/app/src/androidTest/java/org/tensorflow/lite/examples/textclassification/UnitTest.java create mode 100644 lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/Result.java create mode 100644 lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java delete mode 100644 lite/examples/text_classification/android/lib_interpreter/build.gradle delete mode 100644 lite/examples/text_classification/android/lib_interpreter/proguard-rules.pro delete mode 100644 lite/examples/text_classification/android/lib_interpreter/src/main/AndroidManifest.xml delete mode 100644 lite/examples/text_classification/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/textclassification/client/Result.java delete mode 100644 lite/examples/text_classification/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java delete mode 100644 lite/examples/text_classification/android/lib_interpreter/src/test/java/org/tensorflow/lite/examples/textclassification/client/UnitTest.java delete mode 100644 lite/examples/text_classification/android/lib_task_api/build.gradle delete mode 100644 lite/examples/text_classification/android/lib_task_api/download_model.gradle delete mode 100644 lite/examples/text_classification/android/lib_task_api/proguard-rules.pro delete mode 100644 lite/examples/text_classification/android/lib_task_api/src/main/AndroidManifest.xml delete mode 100644 lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/Result.java delete mode 100644 lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java delete mode 100644 lite/examples/text_classification/android/lib_task_api/src/test/java/org/tensorflow/lite/examples/textclassification/client/UnitTest.java diff --git a/lite/examples/text_classification/android/README.md b/lite/examples/text_classification/android/README.md index f03066a83c2..bd1597646c6 100644 --- a/lite/examples/text_classification/android/README.md +++ b/lite/examples/text_classification/android/README.md @@ -26,7 +26,7 @@ Follow the steps below to build and run the sample Android app. * Android Studio 3.2 or later. Install instructions can be found on [Android Studio](https://developer.android.com/studio/index.html) website. -* An Android device or an Android emulator and with API level higher than 15. +* An Android device or an Android emulator and with API level higher than 21. ### Building diff --git a/lite/examples/text_classification/android/app/build.gradle b/lite/examples/text_classification/android/app/build.gradle index c2b8ff8c6dd..f7f831ceddc 100644 --- a/lite/examples/text_classification/android/app/build.gradle +++ b/lite/examples/text_classification/android/app/build.gradle @@ -1,12 +1,13 @@ apply plugin: 'com.android.application' +apply plugin: 'de.undercouch.download' android { - compileSdkVersion 28 - buildToolsVersion "29.0.0" + compileSdkVersion 29 + buildToolsVersion "29.0.2" defaultConfig { applicationId "org.tensorflow.lite.examples.textclassification" minSdkVersion 21 - targetSdkVersion 28 + targetSdkVersion 29 versionCode 1 versionName "1.0" testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" @@ -29,28 +30,22 @@ android { includeAndroidResources = true } } - - flavorDimensions "tfliteInference" - productFlavors { - // The TFLite inference is built using the TFLite Java interpreter. - interpreter { - dimension "tfliteInference" - } - // Default: The TFLite inference is built using the TFLite Task library (high-level API). - taskApi { - getIsDefault().set(true) - dimension "tfliteInference" - } - } } +// Download the pre-trained model from the internet +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +apply from:'download_model.gradle' + dependencies { - interpreterImplementation project(":lib_interpreter") - taskApiImplementation project(":lib_task_api") + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation 'androidx.appcompat:appcompat:1.1.0' implementation 'androidx.constraintlayout:constraintlayout:1.1.3' implementation 'org.jetbrains:annotations:15.0' + //Task Text Library dependency + implementation 'org.tensorflow:tensorflow-lite-task-text:0.2.0' + testImplementation 'androidx.test:core:1.2.0' testImplementation 'junit:junit:4.12' testImplementation 'org.robolectric:robolectric:4.3' diff --git a/lite/examples/text_classification/android/lib_interpreter/download_model.gradle b/lite/examples/text_classification/android/app/download_model.gradle similarity index 100% rename from lite/examples/text_classification/android/lib_interpreter/download_model.gradle rename to lite/examples/text_classification/android/app/download_model.gradle diff --git a/lite/examples/text_classification/android/app/src/androidTest/java/org/tensorflow/lite/examples/textclassification/UnitTest.java b/lite/examples/text_classification/android/app/src/androidTest/java/org/tensorflow/lite/examples/textclassification/UnitTest.java new file mode 100644 index 00000000000..f1444a9c268 --- /dev/null +++ b/lite/examples/text_classification/android/app/src/androidTest/java/org/tensorflow/lite/examples/textclassification/UnitTest.java @@ -0,0 +1,46 @@ +package org.tensorflow.lite.examples.textclassification; + + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.junit.Assert.*; + +import android.content.Context; +import android.support.test.InstrumentationRegistry; +import android.support.test.runner.AndroidJUnit4; + +/** Tests of {@link TextClassificationClient} */ +@RunWith(AndroidJUnit4.class) +public final class UnitTest { + private TextClassificationClient client; + + @Before + public void setUp() { + // Context of the app under test. + Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); + + client = new TextClassificationClient(appContext); + client.load(); + } + + @Test + public void loadModelTest() { + assertNotNull(client.classifier); + } + + @Test + public void predictTest() { + Result positiveText = + client + .classify("This is an interesting film. My family and I all liked it very much.") + .get(0); + assertEquals("Positive", positiveText.getTitle()); + assertTrue(positiveText.getConfidence() > 0.55); + Result negativeText = + client.classify("This film cannot be worse. It is way too boring.").get(0); + assertEquals("Negative", negativeText.getTitle()); + assertTrue(negativeText.getConfidence() > 0.6); + } +} \ No newline at end of file diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/MainActivity.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/MainActivity.java index c30c5ad7e26..b46b9636406 100644 --- a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/MainActivity.java +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/MainActivity.java @@ -18,98 +18,105 @@ import android.os.Bundle; import android.os.Handler; + import androidx.appcompat.app.AppCompatActivity; + import android.util.Log; import android.view.View; import android.widget.Button; import android.widget.EditText; import android.widget.ScrollView; import android.widget.TextView; + import java.util.List; -import org.tensorflow.lite.examples.textclassification.client.Result; -import org.tensorflow.lite.examples.textclassification.client.TextClassificationClient; -/** The main activity to provide interactions with users. */ +/** + * The main activity to provide interactions with users. + */ public class MainActivity extends AppCompatActivity { - private static final String TAG = "TextClassificationDemo"; - - private TextClassificationClient client; - - private TextView resultTextView; - private EditText inputEditText; - private Handler handler; - private ScrollView scrollView; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.tfe_tc_activity_main); - Log.v(TAG, "onCreate"); - - client = new TextClassificationClient(getApplicationContext()); - handler = new Handler(); - Button classifyButton = findViewById(R.id.button); - classifyButton.setOnClickListener( - (View v) -> { - classify(inputEditText.getText().toString()); - }); - resultTextView = findViewById(R.id.result_text_view); - inputEditText = findViewById(R.id.input_text); - scrollView = findViewById(R.id.scroll_view); - } - - @Override - protected void onStart() { - super.onStart(); - Log.v(TAG, "onStart"); - handler.post( - () -> { - client.load(); - }); - } - - @Override - protected void onStop() { - super.onStop(); - Log.v(TAG, "onStop"); - handler.post( - () -> { - client.unload(); - }); - } - - /** Send input text to TextClassificationClient and get the classify messages. */ - private void classify(final String text) { - handler.post( - () -> { - // Run text classification with TF Lite. - List results = client.classify(text); - - // Show classification result on screen - showResult(text, results); - }); - } - - /** Show classification result on the screen. */ - private void showResult(final String inputText, final List results) { - // Run on UI thread as we'll updating our app UI - runOnUiThread( - () -> { - String textToShow = "Input: " + inputText + "\nOutput:\n"; - for (int i = 0; i < results.size(); i++) { - Result result = results.get(i); - textToShow += String.format(" %s: %s\n", result.getTitle(), result.getConfidence()); - } - textToShow += "---------\n"; - - // Append the result to the UI. - resultTextView.append(textToShow); - - // Clear the input text. - inputEditText.getText().clear(); - - // Scroll to the bottom to show latest entry's classification result. - scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN)); - }); - } + private static final String TAG = "TextClassificationDemo"; + + private TextClassificationClient client; + + private TextView resultTextView; + private EditText inputEditText; + private Handler handler; + private ScrollView scrollView; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.tfe_tc_activity_main); + Log.v(TAG, "onCreate"); + + client = new TextClassificationClient(getApplicationContext()); + handler = new Handler(); + Button classifyButton = findViewById(R.id.button); + classifyButton.setOnClickListener( + (View v) -> { + classify(inputEditText.getText().toString()); + }); + resultTextView = findViewById(R.id.result_text_view); + inputEditText = findViewById(R.id.input_text); + scrollView = findViewById(R.id.scroll_view); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.load(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unload(); + }); + } + + /** + * Send input text to TextClassificationClient and get the classify messages. + */ + private void classify(final String text) { + handler.post( + () -> { + // Run text classification with TF Lite. + List results = client.classify(text); + + // Show classification result on screen + showResult(text, results); + }); + } + + /** + * Show classification result on the screen. + */ + private void showResult(final String inputText, final List results) { + // Run on UI thread as we'll updating our app UI + runOnUiThread( + () -> { + String textToShow = "Input: " + inputText + "\nOutput:\n"; + for (int i = 0; i < results.size(); i++) { + Result result = results.get(i); + textToShow += String.format(" %s: %s\n", result.getTitle(), result.getConfidence()); + } + textToShow += "---------\n"; + + // Append the result to the UI. + resultTextView.append(textToShow); + + // Clear the input text. + inputEditText.getText().clear(); + + // Scroll to the bottom to show latest entry's classification result. + scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN)); + }); + } } diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/Result.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/Result.java new file mode 100644 index 00000000000..6b902b2cc52 --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/Result.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.textclassification; + +/** + * An immutable result returned by a TextClassifier describing what was classified. + */ +public class Result implements Comparable { + /** + * A unique identifier for what has been classified. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** + * Display name for the result. + */ + private final String title; + + /** + * A sortable score for how good the result is relative to others. Higher should be better. + */ + private final Float confidence; + + public Result(final String id, final String title, final Float confidence) { + this.id = id; + this.title = title; + this.confidence = confidence; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + return resultString.trim(); + } + + @Override + public int compareTo(Result o) { + return o.confidence.compareTo(confidence); + } +} diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java new file mode 100644 index 00000000000..14a056728b1 --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.textclassification; + +import android.content.Context; +import android.util.Log; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.text.nlclassifier.NLClassifier; + +/** + * Load TfLite model and provide predictions with task api. + */ +public class TextClassificationClient { + private static final String TAG = "TaskApi"; + private static final String MODEL_PATH = "text_classification.tflite"; + + private final Context context; + + NLClassifier classifier; + + public TextClassificationClient(Context context) { + this.context = context; + } + + /** + * Load TF Lite model. + */ + public void load() { + try { + classifier = NLClassifier.createFromFile(context, MODEL_PATH); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + } + } + + /** + * Free up resources as the client is no longer needed. + */ + public void unload() { + classifier.close(); + classifier = null; + } + + /** + * Classify an input string and returns the classification results. + */ + public List classify(String text) { + List apiResults = classifier.classify(text); + List results = new ArrayList<>(apiResults.size()); + for (int i = 0; i < apiResults.size(); i++) { + Category category = apiResults.get(i); + results.add(new Result("" + i, category.getLabel(), category.getScore())); + } + Collections.sort(results); + return results; + } +} diff --git a/lite/examples/text_classification/android/lib_interpreter/build.gradle b/lite/examples/text_classification/android/lib_interpreter/build.gradle deleted file mode 100644 index 34c44b0ac58..00000000000 --- a/lite/examples/text_classification/android/lib_interpreter/build.gradle +++ /dev/null @@ -1,54 +0,0 @@ -apply plugin: 'com.android.library' -apply plugin: 'de.undercouch.download' - -android { - compileSdkVersion 28 - buildToolsVersion "29.0.0" - defaultConfig { - minSdkVersion 21 - targetSdkVersion 28 - versionCode 1 - versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" - } - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' - } - } - compileOptions { - sourceCompatibility = '1.8' - targetCompatibility = '1.8' - } - aaptOptions { - noCompress "tflite" - } - testOptions { - unitTests { - includeAndroidResources = true - } - } -} - -// Download the pre-trained model from the internet -project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' -apply from:'download_model.gradle' - -dependencies { - implementation fileTree(dir: 'libs', include: ['*.jar']) - - implementation 'org.tensorflow:tensorflow-lite:2.2.0' - implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0-rc1' - implementation 'org.jetbrains:annotations:15.0' - - testImplementation 'androidx.test:core:1.2.0' - testImplementation 'junit:junit:4.12' - testImplementation 'org.robolectric:robolectric:4.3' - androidTestImplementation 'com.android.support.test:runner:1.0.2' - androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2' -} - -project(':lib_interpreter').tasks.withType(Test) { - enabled = false -} diff --git a/lite/examples/text_classification/android/lib_interpreter/proguard-rules.pro b/lite/examples/text_classification/android/lib_interpreter/proguard-rules.pro deleted file mode 100644 index f1b424510da..00000000000 --- a/lite/examples/text_classification/android/lib_interpreter/proguard-rules.pro +++ /dev/null @@ -1,21 +0,0 @@ -# Add project specific ProGuard rules here. -# You can control the set of applied configuration files using the -# proguardFiles setting in build.gradle. -# -# For more details, see -# http://developer.android.com/guide/developing/tools/proguard.html - -# If your project uses WebView with JS, uncomment the following -# and specify the fully qualified class name to the JavaScript interface -# class: -#-keepclassmembers class fqcn.of.javascript.interface.for.webview { -# public *; -#} - -# Uncomment this to preserve the line number information for -# debugging stack traces. -#-keepattributes SourceFile,LineNumberTable - -# If you keep the line number information, uncomment this to -# hide the original source file name. -#-renamesourcefileattribute SourceFile diff --git a/lite/examples/text_classification/android/lib_interpreter/src/main/AndroidManifest.xml b/lite/examples/text_classification/android/lib_interpreter/src/main/AndroidManifest.xml deleted file mode 100644 index c708063b918..00000000000 --- a/lite/examples/text_classification/android/lib_interpreter/src/main/AndroidManifest.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - diff --git a/lite/examples/text_classification/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/textclassification/client/Result.java b/lite/examples/text_classification/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/textclassification/client/Result.java deleted file mode 100644 index 3664615ce08..00000000000 --- a/lite/examples/text_classification/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/textclassification/client/Result.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2020 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.lite.examples.textclassification.client; - -/** An immutable result returned by a TextClassifier describing what was classified. */ -public class Result implements Comparable { - /** - * A unique identifier for what has been classified. Specific to the class, not the instance of - * the object. - */ - private final String id; - - /** Display name for the result. */ - private final String title; - - /** A sortable score for how good the result is relative to others. Higher should be better. */ - private final Float confidence; - - public Result(final String id, final String title, final Float confidence) { - this.id = id; - this.title = title; - this.confidence = confidence; - } - - public String getId() { - return id; - } - - public String getTitle() { - return title; - } - - public Float getConfidence() { - return confidence; - } - - @Override - public String toString() { - String resultString = ""; - if (id != null) { - resultString += "[" + id + "] "; - } - - if (title != null) { - resultString += title + " "; - } - - if (confidence != null) { - resultString += String.format("(%.1f%%) ", confidence * 100.0f); - } - - return resultString.trim(); - } - - @Override - public int compareTo(Result o) { - return o.confidence.compareTo(confidence); - } -} diff --git a/lite/examples/text_classification/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java b/lite/examples/text_classification/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java deleted file mode 100644 index a48da7f5c2e..00000000000 --- a/lite/examples/text_classification/android/lib_interpreter/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Copyright 2019 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.lite.examples.textclassification.client; - -import android.content.Context; -import android.content.res.AssetFileDescriptor; -import android.content.res.AssetManager; -import android.util.Log; -import java.io.BufferedReader; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.nio.ByteBuffer; -import java.nio.MappedByteBuffer; -import java.nio.channels.FileChannel; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.PriorityQueue; -import org.tensorflow.lite.Interpreter; -import org.tensorflow.lite.support.metadata.MetadataExtractor; - -/** Interface to load TfLite model and provide predictions. */ -public class TextClassificationClient { - private static final String TAG = "Interpreter"; - - private static final int SENTENCE_LEN = 256; // The maximum length of an input sentence. - // Simple delimiter to split words. - private static final String SIMPLE_SPACE_OR_PUNCTUATION = " |\\,|\\.|\\!|\\?|\n"; - private static final String MODEL_PATH = "text_classification.tflite"; - /* - * Reserved values in ImdbDataSet dic: - * dic[""] = 0 used for padding - * dic[""] = 1 mark for the start of a sentence - * dic[""] = 2 mark for unknown words (OOV) - */ - private static final String START = ""; - private static final String PAD = ""; - private static final String UNKNOWN = ""; - - /** Number of results to show in the UI. */ - private static final int MAX_RESULTS = 3; - - private final Context context; - private final Map dic = new HashMap<>(); - private final List labels = new ArrayList<>(); - private Interpreter tflite; - - public TextClassificationClient(Context context) { - this.context = context; - } - - /** Load the TF Lite model and dictionary so that the client can start classifying text. */ - public void load() { - loadModel(); - } - - /** Load TF Lite model. */ - private synchronized void loadModel() { - try { - // Load the TF Lite model - ByteBuffer buffer = loadModelFile(this.context.getAssets(), MODEL_PATH); - tflite = new Interpreter(buffer); - Log.v(TAG, "TFLite model loaded."); - - // Use metadata extractor to extract the dictionary and label files. - MetadataExtractor metadataExtractor = new MetadataExtractor(buffer); - - // Extract and load the dictionary file. - InputStream dictionaryFile = metadataExtractor.getAssociatedFile("vocab.txt"); - loadDictionaryFile(dictionaryFile); - Log.v(TAG, "Dictionary loaded."); - - // Extract and load the label file. - InputStream labelFile = metadataExtractor.getAssociatedFile("labels.txt"); - loadLabelFile(labelFile); - Log.v(TAG, "Labels loaded."); - - } catch (IOException ex) { - Log.e(TAG, "Error loading TF Lite model.\n", ex); - } - } - - /** Free up resources as the client is no longer needed. */ - public synchronized void unload() { - tflite.close(); - dic.clear(); - labels.clear(); - } - - /** Classify an input string and returns the classification results. */ - public synchronized List classify(String text) { - // Pre-prosessing. - int[][] input = tokenizeInputText(text); - - // Run inference. - Log.v(TAG, "Classifying text with TF Lite..."); - float[][] output = new float[1][labels.size()]; - tflite.run(input, output); - - // Find the best classifications. - PriorityQueue pq = - new PriorityQueue<>( - MAX_RESULTS, (lhs, rhs) -> Float.compare(rhs.getConfidence(), lhs.getConfidence())); - for (int i = 0; i < labels.size(); i++) { - pq.add(new Result("" + i, labels.get(i), output[0][i])); - } - final ArrayList results = new ArrayList<>(); - while (!pq.isEmpty()) { - results.add(pq.poll()); - } - - Collections.sort(results); - // Return the probability of each class. - return results; - } - - /** Load TF Lite model from assets. */ - private static MappedByteBuffer loadModelFile(AssetManager assetManager, String modelPath) - throws IOException { - try (AssetFileDescriptor fileDescriptor = assetManager.openFd(modelPath); - FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) { - FileChannel fileChannel = inputStream.getChannel(); - long startOffset = fileDescriptor.getStartOffset(); - long declaredLength = fileDescriptor.getDeclaredLength(); - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - } - } - - /** Load dictionary from model file. */ - private void loadLabelFile(InputStream ins) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(ins)); - // Each line in the label file is a label. - while (reader.ready()) { - labels.add(reader.readLine()); - } - } - - /** Load labels from model file. */ - private void loadDictionaryFile(InputStream ins) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(ins)); - // Each line in the dictionary has two columns. - // First column is a word, and the second is the index of this word. - while (reader.ready()) { - List line = Arrays.asList(reader.readLine().split(" ")); - if (line.size() < 2) { - continue; - } - dic.put(line.get(0), Integer.parseInt(line.get(1))); - } - } - - /** Pre-prosessing: tokenize and map the input words into a float array. */ - int[][] tokenizeInputText(String text) { - int[] tmp = new int[SENTENCE_LEN]; - List array = Arrays.asList(text.split(SIMPLE_SPACE_OR_PUNCTUATION)); - - int index = 0; - // Prepend if it is in vocabulary file. - if (dic.containsKey(START)) { - tmp[index++] = dic.get(START); - } - - for (String word : array) { - if (index >= SENTENCE_LEN) { - break; - } - tmp[index++] = dic.containsKey(word) ? dic.get(word) : (int) dic.get(UNKNOWN); - } - // Padding and wrapping. - Arrays.fill(tmp, index, SENTENCE_LEN - 1, (int) dic.get(PAD)); - int[][] ans = {tmp}; - return ans; - } - - Map getDic() { - return this.dic; - } - - Interpreter getTflite() { - return this.tflite; - } - - List getLabels() { - return this.labels; - } -} diff --git a/lite/examples/text_classification/android/lib_interpreter/src/test/java/org/tensorflow/lite/examples/textclassification/client/UnitTest.java b/lite/examples/text_classification/android/lib_interpreter/src/test/java/org/tensorflow/lite/examples/textclassification/client/UnitTest.java deleted file mode 100644 index 790f29e2d77..00000000000 --- a/lite/examples/text_classification/android/lib_interpreter/src/test/java/org/tensorflow/lite/examples/textclassification/client/UnitTest.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2019 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.lite.examples.textclassification.client; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -import androidx.test.core.app.ApplicationProvider; -import java.util.Arrays; -import java.util.List; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.robolectric.RobolectricTestRunner; - -/** Tests of {@link TextClassificationClient} */ -@RunWith(RobolectricTestRunner.class) -public final class UnitTest { - private TextClassificationClient client; - - @Before - public void setUp() { - client = new TextClassificationClient(ApplicationProvider.getApplicationContext()); - client.load(); - } - - @Test - public void loadModelTest() { - assertNotNull(client.getTflite()); - } - - @Test - public void loadDictinaryTest() { - assertEquals(0, (int) client.getDic().get("")); - assertEquals(1, (int) client.getDic().get("")); - assertEquals(2, (int) client.getDic().get("")); - assertEquals(3, (int) client.getDic().get("the")); - } - - @Test - public void loadLabelsTest() { - List labels = client.getLabels(); - assertEquals("Negative", labels.get(0)); - assertEquals("Positive", labels.get(1)); - } - - @Test - public void inputPreprocessingTest() { - int[][] clientOutput = client.tokenizeInputText("hello,world!"); - int[][] expectOutput = new int[1][256]; - Arrays.fill(expectOutput[0], 0, 255, 0); - expectOutput[0][0] = 1; // Index for - expectOutput[0][1] = 4845; // Index for "hello". - expectOutput[0][2] = 181; // Index for "world". - assertArrayEquals(expectOutput, clientOutput); - } - - @Test - public void predictTest() { - Result positiveText = - client - .classify("This is an interesting film. My family and I all liked it very much.") - .get(0); - assertEquals("Positive", positiveText.getTitle()); - assertTrue(positiveText.getConfidence() > 0.55); - Result negativeText = - client.classify("This film cannot be worse. It is way too boring.").get(0); - assertEquals("Negative", negativeText.getTitle()); - assertTrue(negativeText.getConfidence() > 0.6); - } -} diff --git a/lite/examples/text_classification/android/lib_task_api/build.gradle b/lite/examples/text_classification/android/lib_task_api/build.gradle deleted file mode 100644 index ac7bdea6dc6..00000000000 --- a/lite/examples/text_classification/android/lib_task_api/build.gradle +++ /dev/null @@ -1,51 +0,0 @@ -apply plugin: 'com.android.library' -apply plugin: 'de.undercouch.download' - -android { - compileSdkVersion 28 - buildToolsVersion "29.0.0" - defaultConfig { - minSdkVersion 21 - targetSdkVersion 28 - versionCode 1 - versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" - } - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' - } - } - compileOptions { - sourceCompatibility = '1.8' - targetCompatibility = '1.8' - } - aaptOptions { - noCompress "tflite" - } - testOptions { - unitTests { - includeAndroidResources = true - } - } -} - -// Download the pre-trained model from the internet -project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' -apply from:'download_model.gradle' - -dependencies { - implementation 'org.tensorflow:tensorflow-lite-task-text:0.1.0' - implementation 'org.jetbrains:annotations:15.0' - - testImplementation 'androidx.test:core:1.2.0' - testImplementation 'junit:junit:4.12' - testImplementation 'org.robolectric:robolectric:4.3' - androidTestImplementation 'com.android.support.test:runner:1.0.2' - androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2' -} - -project(':lib_task_api').tasks.withType(Test) { - enabled = false -} diff --git a/lite/examples/text_classification/android/lib_task_api/download_model.gradle b/lite/examples/text_classification/android/lib_task_api/download_model.gradle deleted file mode 100644 index b09fe763e48..00000000000 --- a/lite/examples/text_classification/android/lib_task_api/download_model.gradle +++ /dev/null @@ -1,7 +0,0 @@ -task downloadModelFile(type: Download) { - src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite' - dest project.ext.ASSET_DIR + '/text_classification.tflite' - overwrite true -} - -preBuild.dependsOn downloadModelFile \ No newline at end of file diff --git a/lite/examples/text_classification/android/lib_task_api/proguard-rules.pro b/lite/examples/text_classification/android/lib_task_api/proguard-rules.pro deleted file mode 100644 index f1b424510da..00000000000 --- a/lite/examples/text_classification/android/lib_task_api/proguard-rules.pro +++ /dev/null @@ -1,21 +0,0 @@ -# Add project specific ProGuard rules here. -# You can control the set of applied configuration files using the -# proguardFiles setting in build.gradle. -# -# For more details, see -# http://developer.android.com/guide/developing/tools/proguard.html - -# If your project uses WebView with JS, uncomment the following -# and specify the fully qualified class name to the JavaScript interface -# class: -#-keepclassmembers class fqcn.of.javascript.interface.for.webview { -# public *; -#} - -# Uncomment this to preserve the line number information for -# debugging stack traces. -#-keepattributes SourceFile,LineNumberTable - -# If you keep the line number information, uncomment this to -# hide the original source file name. -#-renamesourcefileattribute SourceFile diff --git a/lite/examples/text_classification/android/lib_task_api/src/main/AndroidManifest.xml b/lite/examples/text_classification/android/lib_task_api/src/main/AndroidManifest.xml deleted file mode 100644 index c708063b918..00000000000 --- a/lite/examples/text_classification/android/lib_task_api/src/main/AndroidManifest.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - diff --git a/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/Result.java b/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/Result.java deleted file mode 100644 index 3664615ce08..00000000000 --- a/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/Result.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2020 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.lite.examples.textclassification.client; - -/** An immutable result returned by a TextClassifier describing what was classified. */ -public class Result implements Comparable { - /** - * A unique identifier for what has been classified. Specific to the class, not the instance of - * the object. - */ - private final String id; - - /** Display name for the result. */ - private final String title; - - /** A sortable score for how good the result is relative to others. Higher should be better. */ - private final Float confidence; - - public Result(final String id, final String title, final Float confidence) { - this.id = id; - this.title = title; - this.confidence = confidence; - } - - public String getId() { - return id; - } - - public String getTitle() { - return title; - } - - public Float getConfidence() { - return confidence; - } - - @Override - public String toString() { - String resultString = ""; - if (id != null) { - resultString += "[" + id + "] "; - } - - if (title != null) { - resultString += title + " "; - } - - if (confidence != null) { - resultString += String.format("(%.1f%%) ", confidence * 100.0f); - } - - return resultString.trim(); - } - - @Override - public int compareTo(Result o) { - return o.confidence.compareTo(confidence); - } -} diff --git a/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java b/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java deleted file mode 100644 index 550acbb6a10..00000000000 --- a/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2020 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.lite.examples.textclassification.client; - -import android.content.Context; -import android.util.Log; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import org.tensorflow.lite.support.label.Category; -import org.tensorflow.lite.task.text.nlclassifier.NLClassifier; - -/** Load TfLite model and provide predictions with task api. */ -public class TextClassificationClient { - private static final String TAG = "TaskApi"; - private static final String MODEL_PATH = "text_classification.tflite"; - - private final Context context; - - NLClassifier classifier; - - public TextClassificationClient(Context context) { - this.context = context; - } - - public void load() { - try { - classifier = NLClassifier.createFromFile(context, MODEL_PATH); - } catch (IOException e) { - Log.e(TAG, e.getMessage()); - } - } - - public void unload() { - classifier.close(); - classifier = null; - } - - public List classify(String text) { - List apiResults = classifier.classify(text); - List results = new ArrayList<>(apiResults.size()); - for (int i = 0; i < apiResults.size(); i++) { - Category category = apiResults.get(i); - results.add(new Result("" + i, category.getLabel(), category.getScore())); - } - Collections.sort(results); - return results; - } -} diff --git a/lite/examples/text_classification/android/lib_task_api/src/test/java/org/tensorflow/lite/examples/textclassification/client/UnitTest.java b/lite/examples/text_classification/android/lib_task_api/src/test/java/org/tensorflow/lite/examples/textclassification/client/UnitTest.java deleted file mode 100644 index c7bc61d032d..00000000000 --- a/lite/examples/text_classification/android/lib_task_api/src/test/java/org/tensorflow/lite/examples/textclassification/client/UnitTest.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2019 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.lite.examples.textclassification.client; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -import androidx.test.core.app.ApplicationProvider; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.robolectric.RobolectricTestRunner; - -/** Tests of {@link TextClassificationClient} */ -@RunWith(RobolectricTestRunner.class) -public final class UnitTest { - private TextClassificationClient client; - - @Before - public void setUp() { - client = new TextClassificationClient(ApplicationProvider.getApplicationContext()); - client.load(); - } - - @Test - public void loadModelTest() { - assertNotNull(client.classifier); - } - - @Test - public void predictTest() { - Result positiveText = - client - .classify("This is an interesting film. My family and I all liked it very much.") - .get(0); - assertEquals("Positive", positiveText.getTitle()); - assertTrue(positiveText.getConfidence() > 0.55); - Result negativeText = - client.classify("This film cannot be worse. It is way too boring.").get(0); - assertEquals("Negative", negativeText.getTitle()); - assertTrue(negativeText.getConfidence() > 0.6); - } -} diff --git a/lite/examples/text_classification/android/settings.gradle b/lite/examples/text_classification/android/settings.gradle index 9fd8456d711..c309508c131 100644 --- a/lite/examples/text_classification/android/settings.gradle +++ b/lite/examples/text_classification/android/settings.gradle @@ -1,4 +1,3 @@ rootProject.name = 'TFLite Text Classification Demo App' include ':app' -include ':lib_task_api' -include ':lib_interpreter' + From c996b58b1c0d8e5803f5cc025fd3d56898deebb5 Mon Sep 17 00:00:00 2001 From: Sunit Roy <2703iamsry@gmail.com> Date: Fri, 9 Jul 2021 22:16:27 +0530 Subject: [PATCH 2/4] Added spinner to switch between NLClassifier and BERT NL Classifier --- .../res/layout/tfe_od_layout_bottom_sheet.xml | 2 +- .../android/app/build.gradle | 1 + .../examples/textclassification/UnitTest.java | 15 ++- .../textclassification/MainActivity.java | 117 +++++++++++++++++- .../ml/BertNLClassifierClient.java | 62 ++++++++++ .../NLClassifierClient.java} | 37 ++---- .../textclassification/{ => ml}/Result.java | 2 +- .../ml/TextClassificationClient.java | 86 +++++++++++++ .../src/main/res/drawable/bottom_sheet_bg.xml | 9 ++ .../main/res/layout/tfe_tc_activity_main.xml | 83 +++++++------ .../res/layout/tfe_tc_layout_bottom_sheet.xml | 47 +++++++ .../app/src/main/res/values/dimens.xml | 2 + .../app/src/main/res/values/strings.xml | 12 +- 13 files changed, 399 insertions(+), 76 deletions(-) create mode 100644 lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/BertNLClassifierClient.java rename lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/{TextClassificationClient.java => ml/NLClassifierClient.java} (61%) rename lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/{ => ml}/Result.java (97%) create mode 100644 lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/TextClassificationClient.java create mode 100644 lite/examples/text_classification/android/app/src/main/res/drawable/bottom_sheet_bg.xml create mode 100644 lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_layout_bottom_sheet.xml diff --git a/lite/examples/object_detection/android/app/src/main/res/layout/tfe_od_layout_bottom_sheet.xml b/lite/examples/object_detection/android/app/src/main/res/layout/tfe_od_layout_bottom_sheet.xml index 8f589375833..a939f15e825 100644 --- a/lite/examples/object_detection/android/app/src/main/res/layout/tfe_od_layout_bottom_sheet.xml +++ b/lite/examples/object_detection/android/app/src/main/res/layout/tfe_od_layout_bottom_sheet.xml @@ -163,7 +163,7 @@ android:layout_width="match_parent" android:layout_height="wrap_content" android:orientation="horizontal" - android:visibility="gone"> + > sheetBehavior; + private ImageView bottomSheetArrowImageView; + private Spinner apiSpinner; + private String api = "NLCLASSIFIER"; @Override protected void onCreate(Bundle savedInstanceState) { @@ -49,7 +67,7 @@ protected void onCreate(Bundle savedInstanceState) { setContentView(R.layout.tfe_tc_activity_main); Log.v(TAG, "onCreate"); - client = new TextClassificationClient(getApplicationContext()); + client = new TextClassificationClient(getApplicationContext(), api); handler = new Handler(); Button classifyButton = findViewById(R.id.button); classifyButton.setOnClickListener( @@ -59,6 +77,72 @@ protected void onCreate(Bundle savedInstanceState) { resultTextView = findViewById(R.id.result_text_view); inputEditText = findViewById(R.id.input_text); scrollView = findViewById(R.id.scroll_view); + bottomSheetLayout = findViewById(R.id.bottom_sheet_layout); + sheetBehavior = BottomSheetBehavior.from(bottomSheetLayout); + bottomSheetArrowImageView = findViewById(R.id.bottom_sheet_arrow); + apiSpinner = findViewById(R.id.api_spinner); + + apiSpinner.setOnItemSelectedListener(this); + + api = apiSpinner.getSelectedItem().toString().toUpperCase(); + + setupBottomSheet(); + + } + + /** + * Setup the Bottom Sheet + */ + private void setupBottomSheet() { + ViewTreeObserver vto = bottomSheetArrowImageView.getViewTreeObserver(); + vto.addOnGlobalLayoutListener( + new ViewTreeObserver.OnGlobalLayoutListener() { + @Override + public void onGlobalLayout() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.JELLY_BEAN) { + bottomSheetArrowImageView.getViewTreeObserver().removeGlobalOnLayoutListener(this); + } else { + bottomSheetArrowImageView.getViewTreeObserver().removeOnGlobalLayoutListener(this); + } + // int width = bottomSheetLayout.getMeasuredWidth(); + int height = bottomSheetArrowImageView.getMeasuredHeight(); + + sheetBehavior.setPeekHeight(height); + } + }); + + sheetBehavior.setHideable(false); + + sheetBehavior.setBottomSheetCallback( + new BottomSheetBehavior.BottomSheetCallback() { + @Override + public void onStateChanged(@NonNull View bottomSheet, int newState) { + switch (newState) { + case BottomSheetBehavior.STATE_HIDDEN: + break; + case BottomSheetBehavior.STATE_EXPANDED: { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_down); + inputEditText.setEnabled(false); + } + break; + case BottomSheetBehavior.STATE_COLLAPSED: { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + inputEditText.setEnabled(true); + } + break; + case BottomSheetBehavior.STATE_DRAGGING: + break; + case BottomSheetBehavior.STATE_SETTLING: + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + inputEditText.setEnabled(true); + break; + } + } + + @Override + public void onSlide(@NonNull View bottomSheet, float slideOffset) { + } + }); } @Override @@ -119,4 +203,33 @@ private void showResult(final String inputText, final List results) { scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN)); }); } + + @Override + public void onItemSelected(AdapterView parent, View view, int position, long id) { + if (parent == apiSpinner) { + setApi(parent.getItemAtPosition(position).toString().toUpperCase()); + Toast.makeText(this, api, Toast.LENGTH_SHORT).show(); + } + } + + private void setApi(String api) { + if (this.api != api) { + this.api = api; + + recreateClassifier(); + } + + } + + private void recreateClassifier() { + client.unload(); + client = new TextClassificationClient(this, api); + client.load(); + + } + + @Override + public void onNothingSelected(AdapterView parent) { + + } } diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/BertNLClassifierClient.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/BertNLClassifierClient.java new file mode 100644 index 00000000000..ff7f76bd47b --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/BertNLClassifierClient.java @@ -0,0 +1,62 @@ +package org.tensorflow.lite.examples.textclassification.ml; + +import android.content.Context; +import android.util.Log; + +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.text.nlclassifier.BertNLClassifier; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class BertNLClassifierClient { + private static final String TAG = "BertNLClassifierTaskApi"; + private static final String MODEL_PATH = "text_classification.tflite"; + + private final Context context; + + BertNLClassifier classifier; + + public BertNLClassifierClient(Context context) { + this.context = context; + } + + /** + * Load TF Lite model. + */ + public void load() { + try { + classifier = BertNLClassifier.createFromFile(context, MODEL_PATH); + Log.d(TAG, "load"); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + } + } + + /** + * Free up resources as the client is no longer needed. + */ + public void unload() { + classifier.close(); + classifier = null; + Log.d(TAG, "unload"); + } + + /** + * Classify an input string and returns the classification results. + */ + public List classify(String text) { + List apiResults = classifier.classify(text); + List results = new ArrayList<>(apiResults.size()); + for (int i = 0; i < apiResults.size(); i++) { + Category category = apiResults.get(i); + results.add(new Result("" + i, category.getLabel(), category.getScore())); + } + Log.d(TAG, "classify"); + Collections.sort(results); + Log.d(TAG, results.toString()); + return results; + } +} diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/NLClassifierClient.java similarity index 61% rename from lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java rename to lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/NLClassifierClient.java index 14a056728b1..b2dca97f4b0 100644 --- a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/NLClassifierClient.java @@ -1,44 +1,25 @@ -/* - * Copyright 2020 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tensorflow.lite.examples.textclassification; +package org.tensorflow.lite.examples.textclassification.ml; import android.content.Context; import android.util.Log; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.text.nlclassifier.NLClassifier; + import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import org.tensorflow.lite.support.label.Category; -import org.tensorflow.lite.task.text.nlclassifier.NLClassifier; - -/** - * Load TfLite model and provide predictions with task api. - */ -public class TextClassificationClient { - private static final String TAG = "TaskApi"; +public class NLClassifierClient { + private static final String TAG = "NLClassifier_TaskApi"; private static final String MODEL_PATH = "text_classification.tflite"; private final Context context; NLClassifier classifier; - public TextClassificationClient(Context context) { + public NLClassifierClient(Context context) { this.context = context; } @@ -48,6 +29,7 @@ public TextClassificationClient(Context context) { public void load() { try { classifier = NLClassifier.createFromFile(context, MODEL_PATH); + Log.d(TAG, "load"); } catch (IOException e) { Log.e(TAG, e.getMessage()); } @@ -59,6 +41,7 @@ public void load() { public void unload() { classifier.close(); classifier = null; + Log.d(TAG, "unload"); } /** @@ -71,7 +54,9 @@ public List classify(String text) { Category category = apiResults.get(i); results.add(new Result("" + i, category.getLabel(), category.getScore())); } + Log.d(TAG, "classify"); Collections.sort(results); + Log.d(TAG, results.toString()); return results; } } diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/Result.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/Result.java similarity index 97% rename from lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/Result.java rename to lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/Result.java index 6b902b2cc52..a681bc21f6a 100644 --- a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/Result.java +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/Result.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.tensorflow.lite.examples.textclassification; +package org.tensorflow.lite.examples.textclassification.ml; /** * An immutable result returned by a TextClassifier describing what was classified. diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/TextClassificationClient.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/TextClassificationClient.java new file mode 100644 index 00000000000..99af61306f8 --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/TextClassificationClient.java @@ -0,0 +1,86 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.textclassification.ml; + +import android.content.Context; + +import java.util.ArrayList; +import java.util.List; + +/** + * Load TfLite model and provide predictions with task api. + */ +public class TextClassificationClient { + private static final String NLCLASSIFIER = "NLCLASSIFIER"; + private static final String BertNLCLASSIFIER = "BERTNLCLASSIFIER"; + + private final Context context; + private String api; + + NLClassifierClient nlClassifierClient; + BertNLClassifierClient bertNLClassifierClient; + + public TextClassificationClient(Context context, String api) { + this.context = context; + this.api = api; + + nlClassifierClient = new NLClassifierClient(context); + bertNLClassifierClient = new BertNLClassifierClient(context); + } + + /** + * Load TF Lite model. + */ + public void load() { + if (api.equals(NLCLASSIFIER)) { + + nlClassifierClient.load(); + } else if (api.equals(BertNLCLASSIFIER)) { + + bertNLClassifierClient.load(); + } + + } + + /** + * Free up resources as the client is no longer needed. + */ + public void unload() { + if (api.equals(NLCLASSIFIER)) { + + nlClassifierClient.unload(); + } else if (api.equals(BertNLCLASSIFIER)) { + + bertNLClassifierClient.unload(); + } + } + + /** + * Classify an input string and returns the classification results. + */ + public List classify(String text) { + List results = new ArrayList<>(); + if (api.equals(NLCLASSIFIER)) { + + results = nlClassifierClient.classify(text); + } else if (api.equals(BertNLCLASSIFIER)) { + + results = bertNLClassifierClient.classify(text); + } + return results; + } +} diff --git a/lite/examples/text_classification/android/app/src/main/res/drawable/bottom_sheet_bg.xml b/lite/examples/text_classification/android/app/src/main/res/drawable/bottom_sheet_bg.xml new file mode 100644 index 00000000000..70f4b24e350 --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/res/drawable/bottom_sheet_bg.xml @@ -0,0 +1,9 @@ + + + + + + \ No newline at end of file diff --git a/lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_activity_main.xml b/lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_activity_main.xml index fab8c4fc115..47daa31bd64 100644 --- a/lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_activity_main.xml +++ b/lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_activity_main.xml @@ -1,45 +1,54 @@ - - - - - - - - -