Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
package com.google_mlkit_commons;

import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.Tasks;
import com.google.mlkit.common.model.DownloadConditions;
import com.google.mlkit.common.model.RemoteModel;
import com.google.mlkit.common.model.RemoteModelManager;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;

Expand All @@ -20,16 +12,25 @@ public class GenericModelManager {
private static final String DELETE = "delete";
private static final String CHECK = "check";

public RemoteModelManager remoteModelManager = RemoteModelManager.getInstance();
public interface CheckModelIsDownloadedCallback {
void onModelDownloaded(Boolean isDownloaded);

void onError(Exception e);
}

//To avoid downloading models in the main thread as they are around 20MB and may crash the app.
private final ExecutorService executorService = Executors.newCachedThreadPool();
public RemoteModelManager remoteModelManager = RemoteModelManager.getInstance();

public void manageModel(final RemoteModel model, final MethodCall call, final MethodChannel.Result result) {
String task = call.argument("task");

if (task == null) {
result.notImplemented();
return;
}

switch (task) {
case DOWNLOAD:
boolean isWifiReqRequired = call.argument("wifi");
boolean isWifiReqRequired = Boolean.TRUE.equals(call.argument("wifi"));
DownloadConditions downloadConditions;
if (isWifiReqRequired)
downloadConditions = new DownloadConditions.Builder().requireWifi().build();
Expand All @@ -41,52 +42,77 @@ public void manageModel(final RemoteModel model, final MethodCall call, final Me
deleteModel(model, result);
break;
case CHECK:
Boolean downloaded = isModelDownloaded(model);
if (downloaded != null) result.success(downloaded);
else result.error("error", null, null);
isModelDownloaded(
model,
new CheckModelIsDownloadedCallback() {
@Override
public void onModelDownloaded(Boolean isDownloaded) {
result.success(isDownloaded);
}

@Override
public void onError(Exception e) {
result.error("error", e.toString(), null);
}
}
);
break;
default:
result.notImplemented();
}
}

public void downloadModel(RemoteModel remoteModel, DownloadConditions downloadConditions, final MethodChannel.Result result) {
if (isModelDownloaded(remoteModel)) {
result.success("success");
return;
}
remoteModelManager.download(remoteModel, downloadConditions).addOnSuccessListener(aVoid -> result.success("success")).addOnFailureListener(e -> result.error("error", e.toString(), null));
}
isModelDownloaded(
remoteModel,
new CheckModelIsDownloadedCallback() {
@Override
public void onModelDownloaded(Boolean isDownloaded) {
if (isDownloaded) {
result.success("success");
return;
}

public void deleteModel(RemoteModel remoteModel, final MethodChannel.Result result) {
if (!isModelDownloaded(remoteModel)) {
result.success("success");
return;
}
remoteModelManager.deleteDownloadedModel(remoteModel).addOnSuccessListener(aVoid -> result.success("success")).addOnFailureListener(e -> result.error("error", e.toString(), null));
}
remoteModelManager.download(remoteModel, downloadConditions)
.addOnSuccessListener(aVoid -> result.success("success"))
.addOnFailureListener(e -> result.error("error", e.toString(), null));
}

public Boolean isModelDownloaded(RemoteModel model) {
IsModelDownloaded myCallable = new IsModelDownloaded(remoteModelManager.isModelDownloaded(model));
Future<Boolean> taskResult = executorService.submit(myCallable);
try {
return taskResult.get();
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
}
return null;
@Override
public void onError(Exception e) {
result.error("error", e.toString(), null);
}
}
);
}
}

class IsModelDownloaded implements Callable<Boolean> {
final Task<Boolean> booleanTask;
public void deleteModel(RemoteModel remoteModel, final MethodChannel.Result result) {
isModelDownloaded(remoteModel, new CheckModelIsDownloadedCallback() {
@Override
public void onModelDownloaded(Boolean isDownloaded) {
if (!isDownloaded) {
result.success("success");
return;
}
remoteModelManager.deleteDownloadedModel(remoteModel)
.addOnSuccessListener(aVoid -> result.success("success"))
.addOnFailureListener(e -> result.error("error", e.toString(), null));
}

public IsModelDownloaded(Task<Boolean> booleanTask) {
this.booleanTask = booleanTask;
@Override
public void onError(Exception e) {
result.error("error", e.toString(), null);
}
});
}

@Override
public Boolean call() throws Exception {
return Tasks.await(booleanTask);
public void isModelDownloaded(RemoteModel model, CheckModelIsDownloadedCallback callback) {
try {
remoteModelManager.isModelDownloaded(model)
.addOnFailureListener(e -> callback.onError(e))
.addOnSuccessListener(isDownloaded -> callback.onModelDownloaded(isDownloaded));
} catch (Exception e) {
callback.onError(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,71 +54,85 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result)
DigitalInkRecognitionModel model = getModel(tag, result);
if (model == null)
return;
if (!genericModelManager.isModelDownloaded(model)) {
result.error("Model Error", "Model has not been downloaded yet ", null);
return;
}

String id = call.argument("id");
com.google.mlkit.vision.digitalink.DigitalInkRecognizer recognizer = instances.get(id);
if (recognizer == null) {
recognizer = DigitalInkRecognition.getClient(DigitalInkRecognizerOptions.builder(model).build());
instances.put(id, recognizer);
}

Map<String, Object> inkMap = call.argument("ink");
List<Map<String, Object>> strokeList = (List<Map<String, Object>>) inkMap.get("strokes");
Ink.Builder inkBuilder = Ink.builder();
for (final Map<String, Object> strokeMap : strokeList) {
Ink.Stroke.Builder strokeBuilder = Ink.Stroke.builder();
List<Map<String, Object>> pointsList = (List<Map<String, Object>>) strokeMap.get("points");
for (final Map<String, Object> point : pointsList) {
float x = (float) (double) point.get("x");
float y = (float) (double) point.get("y");
Object t0 = point.get("t");
long t;
if (t0 instanceof Integer) {
t = (int) t0;
} else {
t = (long) t0;
genericModelManager.isModelDownloaded(
model,
new GenericModelManager.CheckModelIsDownloadedCallback() {
@Override
public void onModelDownloaded(Boolean isDownloaded) {
if (!isDownloaded) {
result.error("Model Error", "Model has not been downloaded yet ", null);
return;
}

String id = call.argument("id");
com.google.mlkit.vision.digitalink.DigitalInkRecognizer recognizer = instances.get(id);
if (recognizer == null) {
recognizer = DigitalInkRecognition.getClient(DigitalInkRecognizerOptions.builder(model).build());
instances.put(id, recognizer);
}

Map<String, Object> inkMap = call.argument("ink");
List<Map<String, Object>> strokeList = (List<Map<String, Object>>) inkMap.get("strokes");
Ink.Builder inkBuilder = Ink.builder();
for (final Map<String, Object> strokeMap : strokeList) {
Ink.Stroke.Builder strokeBuilder = Ink.Stroke.builder();
List<Map<String, Object>> pointsList = (List<Map<String, Object>>) strokeMap.get("points");
for (final Map<String, Object> point : pointsList) {
float x = (float) (double) point.get("x");
float y = (float) (double) point.get("y");
Object t0 = point.get("t");
long t;
if (t0 instanceof Integer) {
t = (int) t0;
} else {
t = (long) t0;
}
Ink.Point strokePoint = Ink.Point.create(x, y, t);
strokeBuilder.addPoint(strokePoint);
}
inkBuilder.addStroke(strokeBuilder.build());
}
Ink ink = inkBuilder.build();

RecognitionContext context = null;
Map<String, Object> contextMap = call.argument("context");
if (contextMap != null) {
RecognitionContext.Builder builder = RecognitionContext.builder();
String preContext = (String) contextMap.get("preContext");
if (preContext != null) {
builder.setPreContext(preContext);
} else {
builder.setPreContext("");
}

Map<String, Object> writingAreaMap = (Map<String, Object>) contextMap.get("writingArea");
if (writingAreaMap != null) {
float width = (float) (double) writingAreaMap.get("width");
float height = (float) (double) writingAreaMap.get("height");
builder.setWritingArea(new WritingArea(width, height));
}

context = builder.build();
}

if (context != null) {
recognizer.recognize(ink, context)
.addOnSuccessListener(recognitionResult -> process(recognitionResult, result))
.addOnFailureListener(e -> result.error("recognition Error", e.toString(), null));
} else {
recognizer.recognize(ink)
.addOnSuccessListener(recognitionResult -> process(recognitionResult, result))
.addOnFailureListener(e -> result.error("recognition Error", e.toString(), null));
}
}

@Override
public void onError(Exception e) {
result.error("error", e.toString(), null);
}
}
Ink.Point strokePoint = Ink.Point.create(x, y, t);
strokeBuilder.addPoint(strokePoint);
}
inkBuilder.addStroke(strokeBuilder.build());
}
Ink ink = inkBuilder.build();

RecognitionContext context = null;
Map<String, Object> contextMap = call.argument("context");
if (contextMap != null) {
RecognitionContext.Builder builder = RecognitionContext.builder();
String preContext = (String) contextMap.get("preContext");
if (preContext != null) {
builder.setPreContext(preContext);
} else {
builder.setPreContext("");
}

Map<String, Object> writingAreaMap = (Map<String, Object>) contextMap.get("writingArea");
if (writingAreaMap != null) {
float width = (float) (double) writingAreaMap.get("width");
float height = (float) (double) writingAreaMap.get("height");
builder.setWritingArea(new WritingArea(width, height));
}

context = builder.build();
}

if (context != null) {
recognizer.recognize(ink, context)
.addOnSuccessListener(recognitionResult -> process(recognitionResult, result))
.addOnFailureListener(e -> result.error("recognition Error", e.toString(), null));
} else {
recognizer.recognize(ink)
.addOnSuccessListener(recognitionResult -> process(recognitionResult, result))
.addOnFailureListener(e -> result.error("recognition Error", e.toString(), null));
}
);
}

private void process(RecognitionResult recognitionResult, final MethodChannel.Result result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,53 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result)
CustomImageLabelerOptions labelerOptions = getLocalOptions(options);
imageLabeler = ImageLabeling.getClient(labelerOptions);
} else if (type.equals("remote")) {
CustomImageLabelerOptions labelerOptions = getRemoteOptions(options);
if (labelerOptions == null) {
result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet");
return;
}
imageLabeler = ImageLabeling.getClient(labelerOptions);
float confidenceThreshold = (float) (double) options.get("confidenceThreshold");
int maxCount = (int) options.get("maxCount");
String name = (String) options.get("modelName");

FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name).build();
CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource).build();

genericModelManager.isModelDownloaded(
remoteModel,
new GenericModelManager.CheckModelIsDownloadedCallback() {
@Override
public void onModelDownloaded(Boolean isDownloaded) {
if (!isDownloaded) {
result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet");
return;
}

CustomImageLabelerOptions labelerOptions = new CustomImageLabelerOptions.Builder(remoteModel)
.setConfidenceThreshold(confidenceThreshold)
.setMaxResultCount(maxCount)
.build();

ImageLabeling.getClient(labelerOptions).process(inputImage)
.addOnSuccessListener(imageLabels -> {
List<Map<String, Object>> labels = new ArrayList<>(imageLabels.size());
for (ImageLabel label : imageLabels) {
Map<String, Object> labelData = new HashMap<>();
labelData.put("text", label.getText());
labelData.put("confidence", label.getConfidence());
labelData.put("index", label.getIndex());
labels.add(labelData);
}

result.success(labels);
})
.addOnFailureListener(e -> result.error("ImageLabelDetectorError", e.toString(), null));
;
}

@Override
public void onError(Exception e) {
result.error("Error", e.getMessage(), e);
}
}
);

return;
} else {
String error = "Invalid model type: " + type;
result.error(type, error, error);
Expand Down Expand Up @@ -131,24 +172,6 @@ private CustomImageLabelerOptions getLocalOptions(Map<String, Object> labelerOpt
.build();
}

//Options for labeler to work with custom model.
private CustomImageLabelerOptions getRemoteOptions(Map<String, Object> labelerOptions) {
float confidenceThreshold = (float) (double) labelerOptions.get("confidenceThreshold");
int maxCount = (int) labelerOptions.get("maxCount");
String name = (String) labelerOptions.get("modelName");

FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name).build();
CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource).build();
if (!genericModelManager.isModelDownloaded(remoteModel)) {
return null;
}

return new CustomImageLabelerOptions.Builder(remoteModel)
.setConfidenceThreshold(confidenceThreshold)
.setMaxResultCount(maxCount)
.build();
}

private void closeDetector(MethodCall call) {
String id = call.argument("id");
ImageLabeler imageLabeler = instances.get(id);
Expand Down
Loading