Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
package com.google_mlkit_commons;

import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.Tasks;
import com.google.mlkit.common.model.DownloadConditions;
import com.google.mlkit.common.model.RemoteModel;
import com.google.mlkit.common.model.RemoteModelManager;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;

Expand All @@ -20,13 +12,22 @@ public class GenericModelManager {
private static final String DELETE = "delete";
private static final String CHECK = "check";

public RemoteModelManager remoteModelManager = RemoteModelManager.getInstance();
public interface CheckModelIsDownloadedCallback {
void onCheckResult(Boolean isDownloaded);

void onError(Exception e);
}

//To avoid downloading models in the main thread as they are around 20MB and may crash the app.
private final ExecutorService executorService = Executors.newCachedThreadPool();
public RemoteModelManager remoteModelManager = RemoteModelManager.getInstance();

public void manageModel(final RemoteModel model, final MethodCall call, final MethodChannel.Result result) {
String task = call.argument("task");

if (task == null) {
result.notImplemented();
return;
}

switch (task) {
case DOWNLOAD:
boolean isWifiReqRequired = call.argument("wifi");
Expand All @@ -41,52 +42,77 @@ public void manageModel(final RemoteModel model, final MethodCall call, final Me
deleteModel(model, result);
break;
case CHECK:
Boolean downloaded = isModelDownloaded(model);
if (downloaded != null) result.success(downloaded);
else result.error("error", null, null);
isModelDownloaded(
model,
new CheckModelIsDownloadedCallback() {
@Override
public void onCheckResult(Boolean isDownloaded) {
result.success(isDownloaded);
}

@Override
public void onError(Exception e) {
result.error("error", e.toString(), null);
}
}
);
break;
default:
result.notImplemented();
}
}

public void downloadModel(RemoteModel remoteModel, DownloadConditions downloadConditions, final MethodChannel.Result result) {
if (isModelDownloaded(remoteModel)) {
result.success("success");
return;
}
remoteModelManager.download(remoteModel, downloadConditions).addOnSuccessListener(aVoid -> result.success("success")).addOnFailureListener(e -> result.error("error", e.toString(), null));
}
isModelDownloaded(
remoteModel,
new CheckModelIsDownloadedCallback() {
@Override
public void onCheckResult(Boolean isDownloaded) {
if (isDownloaded) {
result.success("success");
return;
}

public void deleteModel(RemoteModel remoteModel, final MethodChannel.Result result) {
if (!isModelDownloaded(remoteModel)) {
result.success("success");
return;
}
remoteModelManager.deleteDownloadedModel(remoteModel).addOnSuccessListener(aVoid -> result.success("success")).addOnFailureListener(e -> result.error("error", e.toString(), null));
}
remoteModelManager.download(remoteModel, downloadConditions)
.addOnSuccessListener(aVoid -> result.success("success"))
.addOnFailureListener(e -> result.error("error", e.toString(), null));
}

public Boolean isModelDownloaded(RemoteModel model) {
IsModelDownloaded myCallable = new IsModelDownloaded(remoteModelManager.isModelDownloaded(model));
Future<Boolean> taskResult = executorService.submit(myCallable);
try {
return taskResult.get();
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
}
return null;
@Override
public void onError(Exception e) {
result.error("error", e.toString(), null);
}
}
);
}
}

class IsModelDownloaded implements Callable<Boolean> {
final Task<Boolean> booleanTask;
public void deleteModel(RemoteModel remoteModel, final MethodChannel.Result result) {
isModelDownloaded(remoteModel, new CheckModelIsDownloadedCallback() {
@Override
public void onCheckResult(Boolean isDownloaded) {
if (!isDownloaded) {
result.success("success");
return;
}
remoteModelManager.deleteDownloadedModel(remoteModel)
.addOnSuccessListener(aVoid -> result.success("success"))
.addOnFailureListener(e -> result.error("error", e.toString(), null));
}

public IsModelDownloaded(Task<Boolean> booleanTask) {
this.booleanTask = booleanTask;
@Override
public void onError(Exception e) {
result.error("error", e.toString(), null);
}
});
}

@Override
public Boolean call() throws Exception {
return Tasks.await(booleanTask);
public void isModelDownloaded(RemoteModel model, CheckModelIsDownloadedCallback callback) {
try {
remoteModelManager.isModelDownloaded(model)
.addOnFailureListener(callback::onError)
.addOnSuccessListener(callback::onCheckResult);
} catch (Exception e) {
callback.onError(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,33 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result)
DigitalInkRecognitionModel model = getModel(tag, result);
if (model == null)
return;
if (!genericModelManager.isModelDownloaded(model)) {
result.error("Model Error", "Model has not been downloaded yet ", null);
return;
}

genericModelManager.isModelDownloaded(
model,
new GenericModelManager.CheckModelIsDownloadedCallback() {
@Override
public void onCheckResult(Boolean isDownloaded) {
if (!isDownloaded) {
result.error("Model Error", "Model has not been downloaded yet ", null);
return;
}

handleInkDetectionIfModelDownloaded(call, result, model);
}

@Override
public void onError(Exception e) {
result.error("Model download check failed", e.toString(), e);
}
}
);
}

private void handleInkDetectionIfModelDownloaded(
MethodCall call,
final MethodChannel.Result result,
DigitalInkRecognitionModel model
) {
String id = call.argument("id");
com.google.mlkit.vision.digitalink.DigitalInkRecognizer recognizer = instances.get(id);
if (recognizer == null) {
Expand Down Expand Up @@ -164,4 +186,4 @@ private DigitalInkRecognitionModel getModel(String tag, final MethodChannel.Resu
}
return DigitalInkRecognitionModel.builder(modelIdentifier).build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,43 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result)
CustomImageLabelerOptions labelerOptions = getLocalOptions(options);
imageLabeler = ImageLabeling.getClient(labelerOptions);
} else if (type.equals("remote")) {
CustomImageLabelerOptions labelerOptions = getRemoteOptions(options);
if (labelerOptions == null) {
result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet");
return;
}
imageLabeler = ImageLabeling.getClient(labelerOptions);
float confidenceThreshold = (float) (double) options.get("confidenceThreshold");
int maxCount = (int) options.get("maxCount");
String name = (String) options.get("modelName");

FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name).build();
CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource).build();

genericModelManager.isModelDownloaded(
remoteModel,
new GenericModelManager.CheckModelIsDownloadedCallback() {
@Override
public void onCheckResult(Boolean isDownloaded) {
if (!isDownloaded) {
result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet");
return;
}

startImageLabelDetector(
ImageLabeling.getClient(
new CustomImageLabelerOptions.Builder(remoteModel)
.setConfidenceThreshold(confidenceThreshold)
.setMaxResultCount(maxCount)
.build()
),
inputImage,
result
);
}

@Override
public void onError(Exception e) {
result.error("Model download check failed", e.getMessage(), e);
}
}
);

return;
} else {
String error = "Invalid model type: " + type;
result.error(type, error, error);
Expand All @@ -93,6 +124,10 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result)
instances.put(id, imageLabeler);
}

startImageLabelDetector(imageLabeler, inputImage, result);
}

private void startImageLabelDetector(ImageLabeler imageLabeler, InputImage inputImage, MethodChannel.Result result) {
imageLabeler.process(inputImage)
.addOnSuccessListener(imageLabels -> {
List<Map<String, Object>> labels = new ArrayList<>(imageLabels.size());
Expand Down Expand Up @@ -131,24 +166,6 @@ private CustomImageLabelerOptions getLocalOptions(Map<String, Object> labelerOpt
.build();
}

//Options for labeler to work with custom model.
private CustomImageLabelerOptions getRemoteOptions(Map<String, Object> labelerOptions) {
float confidenceThreshold = (float) (double) labelerOptions.get("confidenceThreshold");
int maxCount = (int) labelerOptions.get("maxCount");
String name = (String) labelerOptions.get("modelName");

FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name).build();
CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource).build();
if (!genericModelManager.isModelDownloaded(remoteModel)) {
return null;
}

return new CustomImageLabelerOptions.Builder(remoteModel)
.setConfidenceThreshold(confidenceThreshold)
.setMaxResultCount(maxCount)
.build();
}

private void closeDetector(MethodCall call) {
String id = call.argument("id");
ImageLabeler imageLabeler = instances.get(id);
Expand Down
Loading