Skip to content

Commit 9fa7ec0

Browse files
author
Kalyan Reddy
committed
Update custom model codelab to use a locally bundled model as well as cloud model
1 parent fdd6ec4 commit 9fa7ec0

File tree

5 files changed

+31
-21
lines changed

5 files changed

+31
-21
lines changed

mlkit-android/custom-model/final/app/build.gradle

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ android {
1616
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
1717
}
1818
}
19+
aaptOptions {
20+
noCompress "tflite"
21+
}
1922
}
2023

2124
dependencies {
@@ -25,6 +28,8 @@ dependencies {
2528
testImplementation 'junit:junit:4.12'
2629
androidTestImplementation 'com.android.support.test:runner:1.0.2'
2730
androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
31+
32+
implementation 'com.google.firebase:firebase-core:15.0.1'
2833
implementation 'com.google.firebase:firebase-ml-model-interpreter:15.0.0'
2934
}
3035
apply plugin: 'com.google.gms.google-services'

mlkit-android/custom-model/final/app/src/main/java/com/google/firebase/codelab/mlkit_custommodel/MainActivity.java

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import android.content.res.AssetManager;
2020
import android.graphics.Bitmap;
2121
import android.graphics.BitmapFactory;
22+
import android.support.annotation.NonNull;
2223
import android.support.v7.app.AppCompatActivity;
2324
import android.os.Bundle;
2425
import android.util.Log;
@@ -32,6 +33,7 @@
3233
import android.widget.Toast;
3334

3435
import com.google.android.gms.tasks.Continuation;
36+
import com.google.android.gms.tasks.OnFailureListener;
3537
import com.google.android.gms.tasks.Task;
3638
import com.google.firebase.ml.common.FirebaseMLException;
3739
import com.google.firebase.ml.custom.FirebaseModelDataType;
@@ -42,6 +44,7 @@
4244
import com.google.firebase.ml.custom.FirebaseModelOptions;
4345
import com.google.firebase.ml.custom.FirebaseModelOutputs;
4446
import com.google.firebase.ml.custom.model.FirebaseCloudModelSource;
47+
import com.google.firebase.ml.custom.model.FirebaseLocalModelSource;
4548
import com.google.firebase.ml.custom.model.FirebaseModelDownloadConditions;
4649

4750
import java.io.BufferedReader;
@@ -68,13 +71,13 @@ public class MainActivity extends AppCompatActivity implements AdapterView.OnIte
6871
private Integer mImageMaxWidth;
6972
// Max height (portrait mode)
7073
private Integer mImageMaxHeight;
71-
private boolean mIsLandScape;
7274
private final String[] mFilePaths =
7375
new String[]{"mountain.jpg", "tennis.jpg"};
7476
/**
7577
* Name of the model file hosted with Firebase.
7678
*/
7779
private static final String HOSTED_MODEL_NAME = "mobilenet_v1_224_quant";
80+
private static final String LOCAL_MODEL_ASSET = "mobilenet_v1.0_224_quant.tflite";
7881
/**
7982
* Name of the label file stored in Assets.
8083
*/
@@ -144,7 +147,7 @@ public void onClick(View v) {
144147
});
145148

146149
int[] inputDims = {DIM_BATCH_SIZE, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y, DIM_PIXEL_SIZE};
147-
int[] outputDims = {1, mLabelList.size()};
150+
int[] outputDims = {DIM_BATCH_SIZE, mLabelList.size()};
148151
try {
149152
mDataOptions =
150153
new FirebaseModelInputOutputOptions.Builder()
@@ -155,6 +158,9 @@ public void onClick(View v) {
155158
.Builder()
156159
.requireWifi()
157160
.build();
161+
FirebaseLocalModelSource localModelSource =
162+
new FirebaseLocalModelSource.Builder("asset")
163+
.setAssetFilePath(LOCAL_MODEL_ASSET).build();
158164

159165
FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder
160166
(HOSTED_MODEL_NAME)
@@ -165,10 +171,12 @@ public void onClick(View v) {
165171
// for updates
166172
.build();
167173
FirebaseModelManager manager = FirebaseModelManager.getInstance();
174+
manager.registerLocalModelSource(localModelSource);
168175
manager.registerCloudModelSource(cloudSource);
169176
FirebaseModelOptions modelOptions =
170177
new FirebaseModelOptions.Builder()
171178
.setCloudModelName(HOSTED_MODEL_NAME)
179+
.setLocalModelName("asset")
172180
.build();
173181
mInterpreter = FirebaseModelInterpreter.getInstance(modelOptions);
174182
} catch (FirebaseMLException e) {
@@ -191,6 +199,13 @@ private void runModelInference() {
191199
// Here's where the magic happens!!
192200
mInterpreter
193201
.run(inputs, mDataOptions)
202+
.addOnFailureListener(new OnFailureListener() {
203+
@Override
204+
public void onFailure(@NonNull Exception e) {
205+
e.printStackTrace();
206+
showToast("Error running model inference");
207+
}
208+
})
194209
.continueWith(
195210
new Continuation<FirebaseModelOutputs, List<String>>() {
196211
@Override
@@ -337,15 +352,9 @@ public static Bitmap getBitmapFromAsset(Context context, String filePath) {
337352
private Integer getImageMaxWidth() {
338353
if (mImageMaxWidth == null) {
339354
// Calculate the max width in portrait mode. This is done lazily since we need to
340-
// wait for
341-
// a UI layout pass to get the right values. So delay it to first time image
355+
// wait for a UI layout pass to get the right values. So delay it to first time image
342356
// rendering time.
343-
if (mIsLandScape) {
344-
mImageMaxWidth =
345-
mImageView.getHeight();
346-
} else {
347-
mImageMaxWidth = mImageView.getWidth();
348-
}
357+
mImageMaxWidth = mImageView.getWidth();
349358
}
350359

351360
return mImageMaxWidth;
@@ -356,15 +365,10 @@ private Integer getImageMaxWidth() {
356365
private Integer getImageMaxHeight() {
357366
if (mImageMaxHeight == null) {
358367
// Calculate the max width in portrait mode. This is done lazily since we need to
359-
// wait for
360-
// a UI layout pass to get the right values. So delay it to first time image
368+
// wait for a UI layout pass to get the right values. So delay it to first time image
361369
// rendering time.
362-
if (mIsLandScape) {
363-
mImageMaxHeight = mImageView.getWidth();
364-
} else {
365-
mImageMaxHeight =
366-
mImageView.getHeight();
367-
}
370+
mImageMaxHeight =
371+
mImageView.getHeight();
368372
}
369373

370374
return mImageMaxHeight;
@@ -376,8 +380,8 @@ private Pair getTargetedWidthHeight() {
376380
int targetHeight;
377381
int maxWidthForPortraitMode = getImageMaxWidth();
378382
int maxHeightForPortraitMode = getImageMaxHeight();
379-
targetWidth = mIsLandScape ? maxHeightForPortraitMode : maxWidthForPortraitMode;
380-
targetHeight = mIsLandScape ? maxWidthForPortraitMode : maxHeightForPortraitMode;
383+
targetWidth = maxWidthForPortraitMode;
384+
targetHeight = maxHeightForPortraitMode;
381385
return new Pair<>(targetWidth, targetHeight);
382386
}
383387
}

mlkit-android/custom-model/final/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Top-level build file where you can add configuration options common to all sub-projects/modules.
22

33
buildscript {
4-
4+
55
repositories {
66
google()
77
jcenter()

mlkit-android/text-recognition/final/app/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies {
2626
androidTestImplementation 'com.android.support.test:runner:1.0.2'
2727
androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
2828

29+
implementation 'com.google.firebase:firebase-core:15.0.1'
2930
implementation 'com.google.firebase:firebase-ml-vision:15.0.0'
3031
}
3132
apply plugin: 'com.google.gms.google-services'

0 commit comments

Comments
 (0)