diff --git a/packages/google_mlkit_commons/android/src/main/java/com/google_mlkit_commons/GenericModelManager.java b/packages/google_mlkit_commons/android/src/main/java/com/google_mlkit_commons/GenericModelManager.java index 6cd19193..6b3a45e6 100644 --- a/packages/google_mlkit_commons/android/src/main/java/com/google_mlkit_commons/GenericModelManager.java +++ b/packages/google_mlkit_commons/android/src/main/java/com/google_mlkit_commons/GenericModelManager.java @@ -1,17 +1,9 @@ package com.google_mlkit_commons; -import com.google.android.gms.tasks.Task; -import com.google.android.gms.tasks.Tasks; import com.google.mlkit.common.model.DownloadConditions; import com.google.mlkit.common.model.RemoteModel; import com.google.mlkit.common.model.RemoteModelManager; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - import io.flutter.plugin.common.MethodCall; import io.flutter.plugin.common.MethodChannel; @@ -20,13 +12,22 @@ public class GenericModelManager { private static final String DELETE = "delete"; private static final String CHECK = "check"; - public RemoteModelManager remoteModelManager = RemoteModelManager.getInstance(); + public interface CheckModelIsDownloadedCallback { + void onCheckResult(Boolean isDownloaded); + + void onError(Exception e); + } - //To avoid downloading models in the main thread as they are around 20MB and may crash the app. - private final ExecutorService executorService = Executors.newCachedThreadPool(); + public RemoteModelManager remoteModelManager = RemoteModelManager.getInstance(); public void manageModel(final RemoteModel model, final MethodCall call, final MethodChannel.Result result) { String task = call.argument("task"); + + if (task == null) { + result.notImplemented(); + return; + } + switch (task) { case DOWNLOAD: boolean isWifiReqRequired = call.argument("wifi"); @@ -41,9 +42,20 @@ public void manageModel(final RemoteModel model, final MethodCall call, final Me deleteModel(model, result); break; case CHECK: - Boolean downloaded = isModelDownloaded(model); - if (downloaded != null) result.success(downloaded); - else result.error("error", null, null); + isModelDownloaded( + model, + new CheckModelIsDownloadedCallback() { + @Override + public void onCheckResult(Boolean isDownloaded) { + result.success(isDownloaded); + } + + @Override + public void onError(Exception e) { + result.error("error", e.toString(), null); + } + } + ); break; default: result.notImplemented(); @@ -51,42 +63,56 @@ public void manageModel(final RemoteModel model, final MethodCall call, final Me } public void downloadModel(RemoteModel remoteModel, DownloadConditions downloadConditions, final MethodChannel.Result result) { - if (isModelDownloaded(remoteModel)) { - result.success("success"); - return; - } - remoteModelManager.download(remoteModel, downloadConditions).addOnSuccessListener(aVoid -> result.success("success")).addOnFailureListener(e -> result.error("error", e.toString(), null)); - } + isModelDownloaded( + remoteModel, + new CheckModelIsDownloadedCallback() { + @Override + public void onCheckResult(Boolean isDownloaded) { + if (isDownloaded) { + result.success("success"); + return; + } - public void deleteModel(RemoteModel remoteModel, final MethodChannel.Result result) { - if (!isModelDownloaded(remoteModel)) { - result.success("success"); - return; - } - remoteModelManager.deleteDownloadedModel(remoteModel).addOnSuccessListener(aVoid -> result.success("success")).addOnFailureListener(e -> result.error("error", e.toString(), null)); - } + remoteModelManager.download(remoteModel, downloadConditions) + .addOnSuccessListener(aVoid -> result.success("success")) + .addOnFailureListener(e -> result.error("error", e.toString(), null)); + } - public Boolean isModelDownloaded(RemoteModel model) { - IsModelDownloaded myCallable = new IsModelDownloaded(remoteModelManager.isModelDownloaded(model)); - Future taskResult = executorService.submit(myCallable); - try { - return taskResult.get(); - } catch (InterruptedException | ExecutionException e) { - e.printStackTrace(); - } - return null; + @Override + public void onError(Exception e) { + result.error("error", e.toString(), null); + } + } + ); } -} -class IsModelDownloaded implements Callable { - final Task booleanTask; + public void deleteModel(RemoteModel remoteModel, final MethodChannel.Result result) { + isModelDownloaded(remoteModel, new CheckModelIsDownloadedCallback() { + @Override + public void onCheckResult(Boolean isDownloaded) { + if (!isDownloaded) { + result.success("success"); + return; + } + remoteModelManager.deleteDownloadedModel(remoteModel) + .addOnSuccessListener(aVoid -> result.success("success")) + .addOnFailureListener(e -> result.error("error", e.toString(), null)); + } - public IsModelDownloaded(Task booleanTask) { - this.booleanTask = booleanTask; + @Override + public void onError(Exception e) { + result.error("error", e.toString(), null); + } + }); } - @Override - public Boolean call() throws Exception { - return Tasks.await(booleanTask); + public void isModelDownloaded(RemoteModel model, CheckModelIsDownloadedCallback callback) { + try { + remoteModelManager.isModelDownloaded(model) + .addOnFailureListener(callback::onError) + .addOnSuccessListener(callback::onCheckResult); + } catch (Exception e) { + callback.onError(e); + } } -} +} \ No newline at end of file diff --git a/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.java b/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.java index 14005216..e100fe4c 100644 --- a/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.java +++ b/packages/google_mlkit_digital_ink_recognition/android/src/main/java/com/google_mlkit_digital_ink_recognition/DigitalInkRecognizer.java @@ -54,11 +54,33 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result) DigitalInkRecognitionModel model = getModel(tag, result); if (model == null) return; - if (!genericModelManager.isModelDownloaded(model)) { - result.error("Model Error", "Model has not been downloaded yet ", null); - return; - } + genericModelManager.isModelDownloaded( + model, + new GenericModelManager.CheckModelIsDownloadedCallback() { + @Override + public void onCheckResult(Boolean isDownloaded) { + if (!isDownloaded) { + result.error("Model Error", "Model has not been downloaded yet ", null); + return; + } + + handleInkDetectionIfModelDownloaded(call, result, model); + } + + @Override + public void onError(Exception e) { + result.error("Model download check failed", e.toString(), e); + } + } + ); + } + + private void handleInkDetectionIfModelDownloaded( + MethodCall call, + final MethodChannel.Result result, + DigitalInkRecognitionModel model + ) { String id = call.argument("id"); com.google.mlkit.vision.digitalink.DigitalInkRecognizer recognizer = instances.get(id); if (recognizer == null) { @@ -164,4 +186,4 @@ private DigitalInkRecognitionModel getModel(String tag, final MethodChannel.Resu } return DigitalInkRecognitionModel.builder(modelIdentifier).build(); } -} +} \ No newline at end of file diff --git a/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/ImageLabelDetector.java b/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/ImageLabelDetector.java index 16a4729d..93bc20e5 100644 --- a/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/ImageLabelDetector.java +++ b/packages/google_mlkit_image_labeling/android/src/main/java/com/google_mlkit_image_labeling/ImageLabelDetector.java @@ -79,12 +79,43 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result) CustomImageLabelerOptions labelerOptions = getLocalOptions(options); imageLabeler = ImageLabeling.getClient(labelerOptions); } else if (type.equals("remote")) { - CustomImageLabelerOptions labelerOptions = getRemoteOptions(options); - if (labelerOptions == null) { - result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet"); - return; - } - imageLabeler = ImageLabeling.getClient(labelerOptions); + float confidenceThreshold = (float) (double) options.get("confidenceThreshold"); + int maxCount = (int) options.get("maxCount"); + String name = (String) options.get("modelName"); + + FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name).build(); + CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource).build(); + + genericModelManager.isModelDownloaded( + remoteModel, + new GenericModelManager.CheckModelIsDownloadedCallback() { + @Override + public void onCheckResult(Boolean isDownloaded) { + if (!isDownloaded) { + result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet"); + return; + } + + startImageLabelDetector( + ImageLabeling.getClient( + new CustomImageLabelerOptions.Builder(remoteModel) + .setConfidenceThreshold(confidenceThreshold) + .setMaxResultCount(maxCount) + .build() + ), + inputImage, + result + ); + } + + @Override + public void onError(Exception e) { + result.error("Model download check failed", e.getMessage(), e); + } + } + ); + + return; } else { String error = "Invalid model type: " + type; result.error(type, error, error); @@ -93,6 +124,10 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result) instances.put(id, imageLabeler); } + startImageLabelDetector(imageLabeler, inputImage, result); + } + + private void startImageLabelDetector(ImageLabeler imageLabeler, InputImage inputImage, MethodChannel.Result result) { imageLabeler.process(inputImage) .addOnSuccessListener(imageLabels -> { List> labels = new ArrayList<>(imageLabels.size()); @@ -131,24 +166,6 @@ private CustomImageLabelerOptions getLocalOptions(Map labelerOpt .build(); } - //Options for labeler to work with custom model. - private CustomImageLabelerOptions getRemoteOptions(Map labelerOptions) { - float confidenceThreshold = (float) (double) labelerOptions.get("confidenceThreshold"); - int maxCount = (int) labelerOptions.get("maxCount"); - String name = (String) labelerOptions.get("modelName"); - - FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name).build(); - CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource).build(); - if (!genericModelManager.isModelDownloaded(remoteModel)) { - return null; - } - - return new CustomImageLabelerOptions.Builder(remoteModel) - .setConfidenceThreshold(confidenceThreshold) - .setMaxResultCount(maxCount) - .build(); - } - private void closeDetector(MethodCall call) { String id = call.argument("id"); ImageLabeler imageLabeler = instances.get(id); diff --git a/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/ObjectDetector.java b/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/ObjectDetector.java index f15cc4ab..45cd674b 100644 --- a/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/ObjectDetector.java +++ b/packages/google_mlkit_object_detection/android/src/main/java/com/google_mlkit_object_detection/ObjectDetector.java @@ -58,7 +58,7 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result } private void handleDetection(MethodCall call, final MethodChannel.Result result) { - Map imageData = (Map) call.argument("imageData"); + Map imageData = call.argument("imageData"); InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); if (inputImage == null) return; @@ -79,12 +79,55 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result) CustomObjectDetectorOptions detectorOptions = getLocalOptions(options); objectDetector = ObjectDetection.getClient(detectorOptions); } else if (type.equals("remote")) { - CustomObjectDetectorOptions detectorOptions = getRemoteOptions(options); - if (detectorOptions == null) { - result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet"); - return; - } - objectDetector = ObjectDetection.getClient(detectorOptions); + int mode = (int) options.get("mode"); + int finalMode = mode == 0 ? + CustomObjectDetectorOptions.STREAM_MODE : + CustomObjectDetectorOptions.SINGLE_IMAGE_MODE; + boolean classify = (boolean) options.get("classify"); + boolean multiple = (boolean) options.get("multiple"); + double threshold = (double) options.get("threshold"); + int maxLabels = (int) options.get("maxLabels"); + String name = (String) options.get("modelName"); + + FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name) + .build(); + CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource) + .build(); + + genericModelManager.isModelDownloaded( + remoteModel, + new GenericModelManager.CheckModelIsDownloadedCallback() { + @Override + public void onCheckResult(Boolean isDownloaded) { + if (!isDownloaded) { + result.error("Error Model has not been downloaded yet", "Model has not been downloaded yet", "Model has not been downloaded yet"); + return; + } + + CustomObjectDetectorOptions.Builder builder = new CustomObjectDetectorOptions.Builder(remoteModel) + .setDetectorMode(finalMode) + .setMaxPerObjectLabelCount(maxLabels) + .setClassificationConfidenceThreshold((float) threshold); + if (classify) builder.enableClassification(); + if (multiple) builder.enableMultipleObjects(); + + CustomObjectDetectorOptions customObjectDetectorOptions = builder.build(); + + startObjectDetection( + ObjectDetection.getClient(customObjectDetectorOptions), + inputImage, + result + ); + } + + @Override + public void onError(Exception e) { + result.error("Model download check failed", e.getMessage(), e); + } + } + ); + + return; } else { String error = "Invalid model type: " + type; result.error(type, error, error); @@ -93,6 +136,14 @@ private void handleDetection(MethodCall call, final MethodChannel.Result result) instances.put(id, objectDetector); } + startObjectDetection(objectDetector, inputImage, result); + } + + private void startObjectDetection( + com.google.mlkit.vision.objects.ObjectDetector objectDetector, + InputImage inputImage, + MethodChannel.Result result + ) { objectDetector.process(inputImage).addOnSuccessListener(detectedObjects -> { List> objects = new ArrayList<>(); for (DetectedObject detectedObject : detectedObjects) { @@ -149,34 +200,6 @@ private CustomObjectDetectorOptions getLocalOptions(Map options) return builder.build(); } - private CustomObjectDetectorOptions getRemoteOptions(Map options) { - int mode = (int) options.get("mode"); - mode = mode == 0 ? - CustomObjectDetectorOptions.STREAM_MODE : - CustomObjectDetectorOptions.SINGLE_IMAGE_MODE; - boolean classify = (boolean) options.get("classify"); - boolean multiple = (boolean) options.get("multiple"); - double threshold = (double) options.get("threshold"); - int maxLabels = (int) options.get("maxLabels"); - String name = (String) options.get("modelName"); - - FirebaseModelSource firebaseModelSource = new FirebaseModelSource.Builder(name) - .build(); - CustomRemoteModel remoteModel = new CustomRemoteModel.Builder(firebaseModelSource) - .build(); - if (!genericModelManager.isModelDownloaded(remoteModel)) { - return null; - } - - CustomObjectDetectorOptions.Builder builder = new CustomObjectDetectorOptions.Builder(remoteModel); - builder.setDetectorMode(mode); - if (classify) builder.enableClassification(); - if (multiple) builder.enableMultipleObjects(); - builder.setMaxPerObjectLabelCount(maxLabels); - builder.setClassificationConfidenceThreshold((float) threshold); - return builder.build(); - } - private void addData(Map addTo, Integer trackingId, Rect rect,