Skip to content

Commit

Permalink
Publish sample app for acceleration service.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 537066927
  • Loading branch information
hehhjiang authored and copybara-github committed Jun 1, 2023
1 parent aa6bb6a commit 6c25dea
Show file tree
Hide file tree
Showing 30 changed files with 1,696 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<h2>Most important links!</h2>

* [Community examples](https://devlibrary.withgoogle.com)
* [Community examples](./community)
* [Course materials](./courses/udacity_deep_learning) for the [Deep Learning](https://www.udacity.com/course/deep-learning--ud730) class on Udacity

If you are looking to learn TensorFlow, don't miss the
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
apply plugin: 'com.android.application'
apply plugin: 'com.google.android.gms.strict-version-matcher-plugin'


android {
compileSdk 31

defaultConfig {
applicationId "org.tensorflow.lite.examples.accelerationservice"
minSdk 21
targetSdk 31
versionCode 1
versionName "1.0"

testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}

buildTypes {
all {
proguardFiles 'proguard-rules.pro'
}
release {
minifyEnabled true
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}

aaptOptions {
noCompress "bin"
noCompress "tflite"
}
}


dependencies {
implementation "com.google.android.gms:play-services-tflite-acceleration-service:16.0.0-beta01"

implementation "com.google.android.gms:play-services-tflite-java:16.1.0"
implementation "com.google.android.gms:play-services-tflite-support:16.1.0"
implementation "com.google.android.gms:play-services-tflite-gpu:16.2.0"

implementation "com.google.android.gms:play-services-tasks:18.0.2"
implementation "androidx.appcompat:appcompat:1.4.1"

// https://mvnrepository.com/artifact/com.google.errorprone/error_prone_annotations
implementation group: 'com.google.errorprone', name: 'error_prone_annotations', version: '2.18.0'

androidTestImplementation "androidx.test:rules:1.1.0"
androidTestImplementation "androidx.test:runner:1.1.0"
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation "com.google.truth:truth:1.1.3"
}

// Download default models; if you wish to use your own models then
// place them in the "assets" directory and comment out this line.
project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'
apply from:'download.gradle'
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
def modelFloatDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz"
def localCacheFloat = "build/intermediates/mobilenet_v1_1.0_224.tgz"
def targetFolder = "src/main/assets"
def addDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/acceleration_service/add.tflite"
def addTargetFile = "src/main/assets/add.tflite"

task downloadAddModel(type: DownloadUrlTask) {
doFirst {
println "Downloading ${addDownloadUrl}"
}
sourceUrl = "${addDownloadUrl}"
target = file("${addTargetFile}")
}


task downloadModelFloat(type: DownloadUrlTask) {
doFirst {
println "Downloading ${modelFloatDownloadUrl}"
}
sourceUrl = "${modelFloatDownloadUrl}"
target = file("${localCacheFloat}")
}

task unzipModelFloat(type: Copy, dependsOn: 'downloadModelFloat') {
doFirst {
println "Unzipping ${localCacheFloat}"
}
from tarTree("${localCacheFloat}")
into "${targetFolder}"
}

task cleanUnusedFiles(type: Delete, dependsOn: ['unzipModelFloat']) {
delete fileTree("${targetFolder}").matching {
include "*.pb"
include "*.ckpt.*"
include "*.pbtxt.*"
include "*.meta"
}
}


// Ensure the model file is downloaded and extracted before every build
preBuild.dependsOn downloadAddModel
preBuild.dependsOn cleanUnusedFiles

class DownloadUrlTask extends DefaultTask {
@Input
String sourceUrl

@OutputFile
File target

@TaskAction
void download() {
ant.get(src: sourceUrl, dest: target)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

-keepclasseswithmembers class android.support.** { *; }
-keepclasseswithmembers class androidx.** { *; }
-keepclasseswithmembers class com.google.android.gms.tasks.** { *; }
-keepclasseswithmembers class com.google.android.gms.tflite.acceleration.** { *; }
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright 2023 The TensorFlow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tensorflow.lite.examples.accelerationservice;

import static com.google.common.truth.Truth.assertThat;

import android.content.Context;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.rules.ActivityScenarioRule;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.TaskCompletionSource;
import com.google.android.gms.tasks.Tasks;
import com.google.android.gms.tflite.acceleration.AccelerationConfig;
import com.google.android.gms.tflite.acceleration.CpuAccelerationConfig;
import com.google.android.gms.tflite.acceleration.CustomValidationConfig;
import com.google.android.gms.tflite.acceleration.CustomValidationConfig.AccuracyValidator;
import com.google.android.gms.tflite.acceleration.ValidationConfig;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.tensorflow.lite.examples.accelerationservice.logger.Logger;
import org.tensorflow.lite.examples.accelerationservice.model.AssetModel;
import org.tensorflow.lite.examples.accelerationservice.model.AssetModelFactory;
import org.tensorflow.lite.examples.accelerationservice.model.AssetModelFactory.ModelType;
import org.tensorflow.lite.examples.accelerationservice.validator.MeanSquaredErrorValidator;

/**
* Instrumented test, which will execute on an Android device. The test will run the interpreter
* using CPU config, and then check if the correct inference output is produced.
*
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
*/
@RunWith(AndroidJUnit4.class)
public class CustomValidationTest {

private final AccelerationConfig accelerationConfig = new CpuAccelerationConfig.Builder().build();
private final Executor executor = Executors.newSingleThreadExecutor();

private AssetModelFactory assetModelFactory;
private AccuracyValidator validator;

@Rule
public ActivityScenarioRule<MainActivity> scenarioRule =
new ActivityScenarioRule<>(MainActivity.class);

@Before
public void setUp() throws ExecutionException, InterruptedException {
Context context = ApplicationProvider.getApplicationContext();
Logger logger = new NoopLogger();
validator = new MeanSquaredErrorValidator(logger, MainActivity.MSE_THRESHOLD);
assetModelFactory =
Tasks.await(
getMainActivity()
.onSuccessTask(
activity -> Tasks.forResult(new AssetModelFactory(context, executor, logger))));
}

@Test
public void cpuCustomValidationOnPlainAdditionModel_succeeds()
throws ExecutionException, InterruptedException {
AssetModel assetModel = Tasks.await(assetModelFactory.load(ModelType.PLAIN_ADDITION));
assertThat(assetModel.getModel()).isNotNull();
assertThat(Tasks.await(runScenario(assetModel))).isTrue();
}

@Test
public void cpuCustomValidationOnMobileNetV1Model_succeeds()
throws ExecutionException, InterruptedException {
AssetModel assetModel = Tasks.await(assetModelFactory.load(ModelType.MOBILENET_V1));
assertThat(assetModel.getModel()).isNotNull();
assertThat(Tasks.await(runScenario(assetModel))).isTrue();
}

private Task<Boolean> runScenario(AssetModel assetModel) {
return getMainActivity()
.onSuccessTask(
activity -> {
ValidationConfig validationConfig =
new CustomValidationConfig.Builder()
.setGoldenInputs(assetModel.getInputs())
.setAccuracyValidator(validator)
.setBatchSize(assetModel.getBatchSize())
.build();
return activity.runScenario(
executor, assetModel, accelerationConfig, validationConfig);
});
}

private Task<MainActivity> getMainActivity() {
TaskCompletionSource<MainActivity> taskCompletionSource = new TaskCompletionSource<>();
scenarioRule.getScenario().onActivity(taskCompletionSource::setResult);
return taskCompletionSource.getTask();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2023 The TensorFlow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tensorflow.lite.examples.accelerationservice;

import org.tensorflow.lite.examples.accelerationservice.logger.Logger;

/** Logs no output. Used in instrumentation tests. */
class NoopLogger implements Logger {

@Override
public void error(String msg, Exception e) {}

@Override
public void info(String msg) {}

@Override
public void clear() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.examples.accelerationservice">

<application
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:supportsRtl="true"
android:theme="@style/AppTheme"
android:taskAffinity="">
<activity android:name="org.tensorflow.lite.examples.accelerationservice.MainActivity" android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />

<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>

<queries>
<package android:name="com.google.android.gms.policy_tflite_dynamite_dynamite" />
</queries>
</manifest>
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 6c25dea

Please sign in to comment.