diff --git a/lite/examples/object_detection/android/app/src/main/res/layout/tfe_od_layout_bottom_sheet.xml b/lite/examples/object_detection/android/app/src/main/res/layout/tfe_od_layout_bottom_sheet.xml index 8f589375833..a939f15e825 100644 --- a/lite/examples/object_detection/android/app/src/main/res/layout/tfe_od_layout_bottom_sheet.xml +++ b/lite/examples/object_detection/android/app/src/main/res/layout/tfe_od_layout_bottom_sheet.xml @@ -163,7 +163,7 @@ android:layout_width="match_parent" android:layout_height="wrap_content" android:orientation="horizontal" - android:visibility="gone"> + > + ## Overview This is an end-to-end example of movie review sentiment classification built -with TensorFlow 2.0 (Keras API), and trained on IMDB dataset. The demo app +with TensorFlow 2.0 (Keras API), and trained on [IMDB dataset](http://ai.stanford.edu/%7Eamaas/data/sentiment/) version 1.0. The demo app processes input movie review texts, and classifies its sentiment into negative (0) or positive (1). @@ -14,19 +16,23 @@ mobile app. ## Model See -[Text Classification with Movie Reviews](https://www.tensorflow.org/tutorials/keras/basic_text_classification) +[Text Classification with Movie Reviews](https://www.tensorflow.org/tutorials/keras/text_classification) for a step-by-step instruction of building a simple text classification model. -## Android app - -Follow the steps below to build and run the sample Android app. +## Build the demo using Android Studio + +### Prerequisites -### Requirements +* If you don't have already, install + [Android Studio](https://developer.android.com/studio/index.html), following + the instructions on the website. -* Android Studio 3.2 or later. Install instructions can be found on - [Android Studio](https://developer.android.com/studio/index.html) website. +* Android Studio 3.2 or later. + - Gradle 4.6 or higher. + - SDK Build Tools 29.0.2 or higher. -* An Android device or an Android emulator and with API level higher than 15. +* You need an Android device or Android emulator and Android development + environment with minimum API 21. ### Building @@ -53,34 +59,26 @@ Follow the steps below to build and run the sample Android app. * Click `Run` to run the demo app on your Android device. -#### Switch between inference solutions (Task library vs TFLite Interpreter) +## Build the demo using gradle (command line) + +### Building and Installing -This Text Classification Android reference app demonstrates two implementation -solutions: +* Use the following command to build a demo apk: -(1) -[`lib_task_api`](https://github.com/tensorflow/examples/tree/master/lite/examples/nl_classification/android/lib_task_api) -that leverages the out-of-box API from the -[TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/text_classifier); +``` +cd lite/examples/bert_qa/android # Folder for Android app. -(2) -[`lib_interpreter`](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android/lib_interpreter) -that creates the custom inference pipleline using the -[TensorFlow Lite Interpreter Java API](https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_java). +./gradlew build +``` -The [`build.gradle`](app/build.gradle) inside `app` folder shows how to change -`flavorDimensions "tfliteInference"` to switch between the two solutions. +* Use the following command to install the apk onto your connected device: -Inside **Android Studio**, you can change the build variant to whichever one you -want to build and run—just go to `Build > Select Build Variant` and select one -from the drop-down menu. See -[configure product flavors in Android Studio](https://developer.android.com/studio/build/build-variants#product-flavors) -for more details. +``` +adb install app/build/outputs/apk/debug/app-debug.apk +``` -For gradle CLI, running `./gradlew build` can create APKs for both solutions -under `app/build/outputs/apk`. +## Assets folder -*Note: If you simply want the out-of-box API to run the app, we recommend -`lib_task_api`for inference. If you want to customize your own models and -control the detail of inputs and outputs, it might be easier to adapt your model -inputs and outputs by using `lib_interpreter`.* +_Do not delete the assets folder content_. If you explicitly deleted the files, +choose `Build -> Rebuild` to re-download the deleted model files into the assets +folder. diff --git a/lite/examples/text_classification/android/app/build.gradle b/lite/examples/text_classification/android/app/build.gradle index c2b8ff8c6dd..4f5e43a852d 100644 --- a/lite/examples/text_classification/android/app/build.gradle +++ b/lite/examples/text_classification/android/app/build.gradle @@ -1,12 +1,13 @@ apply plugin: 'com.android.application' +apply plugin: 'de.undercouch.download' android { - compileSdkVersion 28 - buildToolsVersion "29.0.0" + compileSdkVersion 29 + buildToolsVersion "29.0.2" defaultConfig { applicationId "org.tensorflow.lite.examples.textclassification" minSdkVersion 21 - targetSdkVersion 28 + targetSdkVersion 29 versionCode 1 versionName "1.0" testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" @@ -29,27 +30,22 @@ android { includeAndroidResources = true } } - - flavorDimensions "tfliteInference" - productFlavors { - // The TFLite inference is built using the TFLite Java interpreter. - interpreter { - dimension "tfliteInference" - } - // Default: The TFLite inference is built using the TFLite Task library (high-level API). - taskApi { - getIsDefault().set(true) - dimension "tfliteInference" - } - } } +// Download the pre-trained model from the internet +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +apply from:'download_model.gradle' + dependencies { - interpreterImplementation project(":lib_interpreter") - taskApiImplementation project(":lib_task_api") + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation 'androidx.appcompat:appcompat:1.1.0' implementation 'androidx.constraintlayout:constraintlayout:1.1.3' implementation 'org.jetbrains:annotations:15.0' + implementation 'com.google.android.material:material:1.0.0' + + //Task Text Library dependency + implementation 'org.tensorflow:tensorflow-lite-task-text:0.2.0' testImplementation 'androidx.test:core:1.2.0' testImplementation 'junit:junit:4.12' diff --git a/lite/examples/text_classification/android/lib_interpreter/download_model.gradle b/lite/examples/text_classification/android/app/download_model.gradle similarity index 100% rename from lite/examples/text_classification/android/lib_interpreter/download_model.gradle rename to lite/examples/text_classification/android/app/download_model.gradle diff --git a/lite/examples/text_classification/android/app/src/androidTest/java/org/tensorflow/lite/examples/textclassification/UnitTest.java b/lite/examples/text_classification/android/app/src/androidTest/java/org/tensorflow/lite/examples/textclassification/UnitTest.java new file mode 100644 index 00000000000..9f7f3d34a4a --- /dev/null +++ b/lite/examples/text_classification/android/app/src/androidTest/java/org/tensorflow/lite/examples/textclassification/UnitTest.java @@ -0,0 +1,49 @@ +package org.tensorflow.lite.examples.textclassification; + + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.examples.textclassification.ml.Result; +import org.tensorflow.lite.examples.textclassification.ml.TextClassificationClient; + +import static org.junit.Assert.*; + +import android.content.Context; +import android.support.test.InstrumentationRegistry; +import android.support.test.runner.AndroidJUnit4; + +/** Tests of {@link TextClassificationClient} */ +@RunWith(AndroidJUnit4.class) +public final class UnitTest { + private TextClassificationClient client; + private String api = "NLCLASSIFIER"; + + @Before + public void setUp() { + // Context of the app under test. + Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); + + client = new TextClassificationClient(appContext, api); + client.load(); + } +// +// @Test +// public void loadModelTest() { +// assertNotNull(client.classifier); +// } + + @Test + public void predictTest() { + Result positiveText = + client + .classify("This is an interesting film. My family and I all liked it very much.") + .get(0); + assertEquals("Positive", positiveText.getTitle()); + assertTrue(positiveText.getConfidence() > 0.55); + Result negativeText = + client.classify("This film cannot be worse. It is way too boring.").get(0); + assertEquals("Negative", negativeText.getTitle()); + assertTrue(negativeText.getConfidence() > 0.6); + } +} \ No newline at end of file diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/MainActivity.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/MainActivity.java index c30c5ad7e26..74c0675022b 100644 --- a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/MainActivity.java +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/MainActivity.java @@ -16,100 +16,224 @@ package org.tensorflow.lite.examples.textclassification; +import android.os.Build; import android.os.Bundle; import android.os.Handler; + +import androidx.annotation.NonNull; import androidx.appcompat.app.AppCompatActivity; + import android.util.Log; import android.view.View; +import android.view.ViewTreeObserver; +import android.widget.AdapterView; import android.widget.Button; import android.widget.EditText; +import android.widget.ImageView; +import android.widget.LinearLayout; import android.widget.ScrollView; +import android.widget.Spinner; import android.widget.TextView; +import android.widget.Toast; + +import com.google.android.material.bottomsheet.BottomSheetBehavior; + +import org.tensorflow.lite.examples.textclassification.ml.Result; +import org.tensorflow.lite.examples.textclassification.ml.TextClassificationClient; + import java.util.List; -import org.tensorflow.lite.examples.textclassification.client.Result; -import org.tensorflow.lite.examples.textclassification.client.TextClassificationClient; - -/** The main activity to provide interactions with users. */ -public class MainActivity extends AppCompatActivity { - private static final String TAG = "TextClassificationDemo"; - - private TextClassificationClient client; - - private TextView resultTextView; - private EditText inputEditText; - private Handler handler; - private ScrollView scrollView; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.tfe_tc_activity_main); - Log.v(TAG, "onCreate"); - - client = new TextClassificationClient(getApplicationContext()); - handler = new Handler(); - Button classifyButton = findViewById(R.id.button); - classifyButton.setOnClickListener( - (View v) -> { - classify(inputEditText.getText().toString()); - }); - resultTextView = findViewById(R.id.result_text_view); - inputEditText = findViewById(R.id.input_text); - scrollView = findViewById(R.id.scroll_view); - } - - @Override - protected void onStart() { - super.onStart(); - Log.v(TAG, "onStart"); - handler.post( - () -> { - client.load(); - }); - } - - @Override - protected void onStop() { - super.onStop(); - Log.v(TAG, "onStop"); - handler.post( - () -> { - client.unload(); - }); - } - - /** Send input text to TextClassificationClient and get the classify messages. */ - private void classify(final String text) { - handler.post( - () -> { - // Run text classification with TF Lite. - List results = client.classify(text); - - // Show classification result on screen - showResult(text, results); - }); - } - - /** Show classification result on the screen. */ - private void showResult(final String inputText, final List results) { - // Run on UI thread as we'll updating our app UI - runOnUiThread( - () -> { - String textToShow = "Input: " + inputText + "\nOutput:\n"; - for (int i = 0; i < results.size(); i++) { - Result result = results.get(i); - textToShow += String.format(" %s: %s\n", result.getTitle(), result.getConfidence()); - } - textToShow += "---------\n"; - - // Append the result to the UI. - resultTextView.append(textToShow); - - // Clear the input text. - inputEditText.getText().clear(); - - // Scroll to the bottom to show latest entry's classification result. - scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN)); - }); - } + +/** + * The main activity to provide interactions with users. + */ +public class MainActivity extends AppCompatActivity implements AdapterView.OnItemSelectedListener { + private static final String TAG = "TextClassificationDemo"; + + private TextClassificationClient client; + + private TextView resultTextView; + private EditText inputEditText; + private Handler handler; + private ScrollView scrollView; + private LinearLayout bottomSheetLayout; + private BottomSheetBehavior sheetBehavior; + private ImageView bottomSheetArrowImageView; + private Spinner apiSpinner; + private String api = "NLCLASSIFIER"; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.tfe_tc_activity_main); + Log.v(TAG, "onCreate"); + + client = new TextClassificationClient(getApplicationContext(), api); + handler = new Handler(); + Button classifyButton = findViewById(R.id.button); + classifyButton.setOnClickListener( + (View v) -> { + if (inputEditText.getText().toString().isEmpty()) { + Toast.makeText(this, "review is null", Toast.LENGTH_SHORT).show(); + }else { + classify(inputEditText.getText().toString()); + } + }); + resultTextView = findViewById(R.id.result_text_view); + inputEditText = findViewById(R.id.input_text); + scrollView = findViewById(R.id.scroll_view); + bottomSheetLayout = findViewById(R.id.bottom_sheet_layout); + sheetBehavior = BottomSheetBehavior.from(bottomSheetLayout); + bottomSheetArrowImageView = findViewById(R.id.bottom_sheet_arrow); + apiSpinner = findViewById(R.id.api_spinner); + + apiSpinner.setOnItemSelectedListener(this); + + api = apiSpinner.getSelectedItem().toString().toUpperCase(); + + setupBottomSheet(); + + } + + /** + * Setup the Bottom Sheet + */ + private void setupBottomSheet() { + ViewTreeObserver vto = bottomSheetArrowImageView.getViewTreeObserver(); + vto.addOnGlobalLayoutListener( + new ViewTreeObserver.OnGlobalLayoutListener() { + @Override + public void onGlobalLayout() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.JELLY_BEAN) { + bottomSheetArrowImageView.getViewTreeObserver().removeGlobalOnLayoutListener(this); + } else { + bottomSheetArrowImageView.getViewTreeObserver().removeOnGlobalLayoutListener(this); + } + // int width = bottomSheetLayout.getMeasuredWidth(); + int height = bottomSheetArrowImageView.getMeasuredHeight(); + + sheetBehavior.setPeekHeight(height); + } + }); + + sheetBehavior.setHideable(false); + + sheetBehavior.setBottomSheetCallback( + new BottomSheetBehavior.BottomSheetCallback() { + @Override + public void onStateChanged(@NonNull View bottomSheet, int newState) { + switch (newState) { + case BottomSheetBehavior.STATE_HIDDEN: + break; + case BottomSheetBehavior.STATE_EXPANDED: { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_down); + inputEditText.setEnabled(false); + } + break; + case BottomSheetBehavior.STATE_COLLAPSED: { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + inputEditText.setEnabled(true); + } + break; + case BottomSheetBehavior.STATE_DRAGGING: + break; + case BottomSheetBehavior.STATE_SETTLING: + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + inputEditText.setEnabled(true); + break; + } + } + + @Override + public void onSlide(@NonNull View bottomSheet, float slideOffset) { + } + }); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.load(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unload(); + }); + } + + /** + * Send input text to TextClassificationClient and get the classify messages. + */ + private void classify(final String text) { + handler.post( + () -> { + // Run text classification with TF Lite. + List results = client.classify(text); + + // Show classification result on screen + showResult(text, results); + }); + } + + /** + * Show classification result on the screen. + */ + private void showResult(final String inputText, final List results) { + // Run on UI thread as we'll updating our app UI + runOnUiThread( + () -> { + String textToShow = "Input: " + inputText + "\nOutput:\n"; + for (int i = 0; i < results.size(); i++) { + Result result = results.get(i); + textToShow += String.format(" %s: %s\n", result.getTitle(), result.getConfidence()); + } + textToShow += "---------\n"; + + // Append the result to the UI. + resultTextView.append(textToShow); + + // Clear the input text. + inputEditText.getText().clear(); + + // Scroll to the bottom to show latest entry's classification result. + scrollView.post(() -> scrollView.fullScroll(View.FOCUS_DOWN)); + }); + } + + @Override + public void onItemSelected(AdapterView parent, View view, int position, long id) { + if (parent == apiSpinner) { + setApi(parent.getItemAtPosition(position).toString().toUpperCase()); + Toast.makeText(this, api, Toast.LENGTH_SHORT).show(); + } + } + + private void setApi(String api) { + if (this.api != api) { + this.api = api; + + recreateClassifier(); + } + + } + + private void recreateClassifier() { + client.unload(); + client = new TextClassificationClient(this, api); + client.load(); + + } + + @Override + public void onNothingSelected(AdapterView parent) { + + } } diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/BertNLClassifierClient.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/BertNLClassifierClient.java new file mode 100644 index 00000000000..ff7f76bd47b --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/BertNLClassifierClient.java @@ -0,0 +1,62 @@ +package org.tensorflow.lite.examples.textclassification.ml; + +import android.content.Context; +import android.util.Log; + +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.text.nlclassifier.BertNLClassifier; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class BertNLClassifierClient { + private static final String TAG = "BertNLClassifierTaskApi"; + private static final String MODEL_PATH = "text_classification.tflite"; + + private final Context context; + + BertNLClassifier classifier; + + public BertNLClassifierClient(Context context) { + this.context = context; + } + + /** + * Load TF Lite model. + */ + public void load() { + try { + classifier = BertNLClassifier.createFromFile(context, MODEL_PATH); + Log.d(TAG, "load"); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + } + } + + /** + * Free up resources as the client is no longer needed. + */ + public void unload() { + classifier.close(); + classifier = null; + Log.d(TAG, "unload"); + } + + /** + * Classify an input string and returns the classification results. + */ + public List classify(String text) { + List apiResults = classifier.classify(text); + List results = new ArrayList<>(apiResults.size()); + for (int i = 0; i < apiResults.size(); i++) { + Category category = apiResults.get(i); + results.add(new Result("" + i, category.getLabel(), category.getScore())); + } + Log.d(TAG, "classify"); + Collections.sort(results); + Log.d(TAG, results.toString()); + return results; + } +} diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/NLClassifierClient.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/NLClassifierClient.java new file mode 100644 index 00000000000..b2dca97f4b0 --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/NLClassifierClient.java @@ -0,0 +1,62 @@ +package org.tensorflow.lite.examples.textclassification.ml; + +import android.content.Context; +import android.util.Log; + +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.text.nlclassifier.NLClassifier; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class NLClassifierClient { + private static final String TAG = "NLClassifier_TaskApi"; + private static final String MODEL_PATH = "text_classification.tflite"; + + private final Context context; + + NLClassifier classifier; + + public NLClassifierClient(Context context) { + this.context = context; + } + + /** + * Load TF Lite model. + */ + public void load() { + try { + classifier = NLClassifier.createFromFile(context, MODEL_PATH); + Log.d(TAG, "load"); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + } + } + + /** + * Free up resources as the client is no longer needed. + */ + public void unload() { + classifier.close(); + classifier = null; + Log.d(TAG, "unload"); + } + + /** + * Classify an input string and returns the classification results. + */ + public List classify(String text) { + List apiResults = classifier.classify(text); + List results = new ArrayList<>(apiResults.size()); + for (int i = 0; i < apiResults.size(); i++) { + Category category = apiResults.get(i); + results.add(new Result("" + i, category.getLabel(), category.getScore())); + } + Log.d(TAG, "classify"); + Collections.sort(results); + Log.d(TAG, results.toString()); + return results; + } +} diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/Result.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/Result.java new file mode 100644 index 00000000000..a681bc21f6a --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/Result.java @@ -0,0 +1,79 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.textclassification.ml; + +/** + * An immutable result returned by a TextClassifier describing what was classified. + */ +public class Result implements Comparable { + /** + * A unique identifier for what has been classified. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** + * Display name for the result. + */ + private final String title; + + /** + * A sortable score for how good the result is relative to others. Higher should be better. + */ + private final Float confidence; + + public Result(final String id, final String title, final Float confidence) { + this.id = id; + this.title = title; + this.confidence = confidence; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + return resultString.trim(); + } + + @Override + public int compareTo(Result o) { + return o.confidence.compareTo(confidence); + } +} diff --git a/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/TextClassificationClient.java b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/TextClassificationClient.java new file mode 100644 index 00000000000..99af61306f8 --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/java/org/tensorflow/lite/examples/textclassification/ml/TextClassificationClient.java @@ -0,0 +1,86 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.textclassification.ml; + +import android.content.Context; + +import java.util.ArrayList; +import java.util.List; + +/** + * Load TfLite model and provide predictions with task api. + */ +public class TextClassificationClient { + private static final String NLCLASSIFIER = "NLCLASSIFIER"; + private static final String BertNLCLASSIFIER = "BERTNLCLASSIFIER"; + + private final Context context; + private String api; + + NLClassifierClient nlClassifierClient; + BertNLClassifierClient bertNLClassifierClient; + + public TextClassificationClient(Context context, String api) { + this.context = context; + this.api = api; + + nlClassifierClient = new NLClassifierClient(context); + bertNLClassifierClient = new BertNLClassifierClient(context); + } + + /** + * Load TF Lite model. + */ + public void load() { + if (api.equals(NLCLASSIFIER)) { + + nlClassifierClient.load(); + } else if (api.equals(BertNLCLASSIFIER)) { + + bertNLClassifierClient.load(); + } + + } + + /** + * Free up resources as the client is no longer needed. + */ + public void unload() { + if (api.equals(NLCLASSIFIER)) { + + nlClassifierClient.unload(); + } else if (api.equals(BertNLCLASSIFIER)) { + + bertNLClassifierClient.unload(); + } + } + + /** + * Classify an input string and returns the classification results. + */ + public List classify(String text) { + List results = new ArrayList<>(); + if (api.equals(NLCLASSIFIER)) { + + results = nlClassifierClient.classify(text); + } else if (api.equals(BertNLCLASSIFIER)) { + + results = bertNLClassifierClient.classify(text); + } + return results; + } +} diff --git a/lite/examples/text_classification/android/app/src/main/res/drawable/bottom_sheet_bg.xml b/lite/examples/text_classification/android/app/src/main/res/drawable/bottom_sheet_bg.xml new file mode 100644 index 00000000000..70f4b24e350 --- /dev/null +++ b/lite/examples/text_classification/android/app/src/main/res/drawable/bottom_sheet_bg.xml @@ -0,0 +1,9 @@ + + + + + + \ No newline at end of file diff --git a/lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_activity_main.xml b/lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_activity_main.xml index fab8c4fc115..47daa31bd64 100644 --- a/lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_activity_main.xml +++ b/lite/examples/text_classification/android/app/src/main/res/layout/tfe_tc_activity_main.xml @@ -1,45 +1,54 @@ - - - - - - - - -