diff --git a/packages/example/lib/vision_detector_views/painters/subject_segmentation_painter.dart b/packages/example/lib/vision_detector_views/painters/subject_segmentation_painter.dart index 5f801165..78023c99 100644 --- a/packages/example/lib/vision_detector_views/painters/subject_segmentation_painter.dart +++ b/packages/example/lib/vision_detector_views/painters/subject_segmentation_painter.dart @@ -21,7 +21,7 @@ class SubjectSegmentationPainter extends CustomPainter { void paint(Canvas canvas, Size size) { final int width = mask.width; final int height = mask.height; - final List subjects = mask.subjects; + final List subjects = mask.subjects ?? []; final paint = Paint()..style = PaintingStyle.fill; @@ -30,7 +30,7 @@ class SubjectSegmentationPainter extends CustomPainter { final int startY = subject.startY; final int subjectWidth = subject.subjectWidth; final int subjectHeight = subject.subjectHeight; - final List confidences = subject.confidences; + final List confidences = subject.confidences ?? []; for (int y = 0; y < subjectHeight; y++) { for (int x = 0; y < subjectWidth; x++) { diff --git a/packages/example/lib/vision_detector_views/subject_segmenter_view.dart b/packages/example/lib/vision_detector_views/subject_segmenter_view.dart index 060cb73e..4977e65d 100644 --- a/packages/example/lib/vision_detector_views/subject_segmenter_view.dart +++ b/packages/example/lib/vision_detector_views/subject_segmenter_view.dart @@ -11,7 +11,8 @@ class SubjectSegmenterView extends StatefulWidget { } class _SubjectSegmenterViewState extends State { - final SubjectSegmenter _segmenter = SubjectSegmenter(); + final SubjectSegmenter _segmenter = SubjectSegmenter( + options: SubjectSegmenterOptions(enableForegroundConfidenceMask: true)); bool _canProcess = true; bool _isBusy = false; CustomPaint? _customPaint; @@ -56,8 +57,7 @@ class _SubjectSegmenterViewState extends State { _customPaint = CustomPaint(painter: painter); } else { // TODO: set _customPaint to draw on top of image - _text = 'There is a mask with ${mask.subjects.length} subjects'; - + _text = 'There is a mask with ${mask.subjects?.length} subjects'; _customPaint = null; } _isBusy = false; diff --git a/packages/google_mlkit_subject_segmentation/README.md b/packages/google_mlkit_subject_segmentation/README.md index a0eea991..34364349 100644 --- a/packages/google_mlkit_subject_segmentation/README.md +++ b/packages/google_mlkit_subject_segmentation/README.md @@ -81,7 +81,8 @@ final InputImage inputImage; #### Create an instance of `SubjectSegmenter` ```dart -final segmenter = SubjectSegmenter(); +final options = SubjectSegmenterOptions(); +final segmenter = SubjectSegmenter(options: options); ``` #### Process image diff --git a/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/SubjectSegmenterProcess.java b/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/SubjectSegmenterProcess.java index bab91b18..eee52a71 100644 --- a/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/SubjectSegmenterProcess.java +++ b/packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/SubjectSegmenterProcess.java @@ -1,19 +1,27 @@ package com.google_mlkit_subject_segmentation; import android.content.Context; +import android.graphics.Bitmap; import androidx.annotation.NonNull; import com.google.mlkit.vision.common.InputImage; import com.google.mlkit.vision.segmentation.subject.Subject; import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation; +import com.google.mlkit.vision.segmentation.subject.SubjectSegmentationResult; import com.google.mlkit.vision.segmentation.subject.SubjectSegmenter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import java.nio.FloatBuffer; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import io.flutter.Log; import io.flutter.plugin.common.MethodCall; @@ -27,8 +35,6 @@ public class SubjectSegmenterProcess implements MethodChannel.MethodCallHandler private static final String CLOSE = "vision#closeSubjectSegmenter"; private final Context context; - - private static final String TAG = "Logger"; private int imageWidth; private int imageHeight; @@ -55,55 +61,119 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result } } - private SubjectSegmenter initialize(MethodCall call) { - SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder() - .enableMultipleSubjects(new SubjectSegmenterOptions.SubjectResultOptions.Builder() - .enableConfidenceMask().build()); - SubjectSegmenterOptions options = builder.build(); - return SubjectSegmentation.getClient(options); - } - - private void handleDetection(MethodCall call, MethodChannel.Result result){ - Map imageData = (Map) call.argument("imageData"); - InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result); - if (inputImage == null) return; + private void handleDetection(MethodCall call, MethodChannel.Result result) { + InputImage inputImage = InputImageConverter.getInputImageFromData(call.argument("imageData"), context, result); + if(inputImage == null) return; imageHeight = inputImage.getHeight(); imageWidth = inputImage.getWidth(); + String id = call.argument("id"); - SubjectSegmenter subjectSegmenter = instances.get(id); - if (subjectSegmenter == null) { - subjectSegmenter = initialize(call); - instances.put(id, subjectSegmenter); + SubjectSegmenter subjectSegmenter = getOrCreateSegmenter(id, call); + + subjectSegmenter.process(inputImage) + .addOnSuccessListener(subjectSegmentationResult -> processResult(subjectSegmentationResult, call, result)) + .addOnFailureListener(e -> result.error("Subject segmentation failure!", e.getMessage(), e)); + + } + + private SubjectSegmenter getOrCreateSegmenter(String id, MethodCall call) { + return instances.computeIfAbsent(id, k -> initialize(call)); + } + private SubjectSegmenter initialize(MethodCall call) { + Map options = call.argument("options"); + SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder(); + assert options != null; + configureBuilder(builder, options); + return SubjectSegmentation.getClient(builder.build()); + } + + private void configureBuilder(SubjectSegmenterOptions.Builder builder, Map options) { + if(Boolean.TRUE.equals(options.get("enableForegroundBitmap"))){ + builder.enableForegroundBitmap(); + } + if(Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))){ + builder.enableForegroundConfidenceMask(); } + configureMultipleSubjects(builder, options); + } + + private void configureMultipleSubjects(SubjectSegmenterOptions.Builder builder, Map options) { + boolean enableMultiConfidenceMask = Boolean.TRUE.equals(options.get("enableMultiConfidenceMask")) ; + boolean enableMultiSubjectBitmap = Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap")); + + if(enableMultiConfidenceMask || enableMultiSubjectBitmap) { + SubjectSegmenterOptions.SubjectResultOptions.Builder subjectBuilder = new SubjectSegmenterOptions.SubjectResultOptions.Builder(); + if(enableMultiConfidenceMask) subjectBuilder.enableConfidenceMask(); + if(enableMultiSubjectBitmap) subjectBuilder.enableSubjectBitmap(); + builder.enableMultipleSubjects(subjectBuilder.build()); + } + } + + private void processResult(SubjectSegmentationResult subjectSegmentationResult, MethodCall call, MethodChannel.Result result) { + Map resultMap = new HashMap<>(); + Map options = call.argument("options"); + + assert options != null; + if(Boolean.TRUE.equals(options.get("enableForegroundBitmap"))) { + addForegroundBitmap(resultMap, subjectSegmentationResult.getForegroundBitmap()); + } + + if(Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))){ + addConfidenceMask(resultMap, subjectSegmentationResult.getForegroundConfidenceMask()); + } + if(Boolean.TRUE.equals(options.get("enableMultiConfidenceMask")) || Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap"))) { - subjectSegmenter.process(inputImage) - .addOnSuccessListener( subjectSegmentationResult -> { List> subjectsData = new ArrayList<>(); - for(Subject subject : subjectSegmentationResult.getSubjects()){ - Map subjectData = getStringObjectMap(subject); + for(Subject subject: subjectSegmentationResult.getSubjects()){ + Map subjectData = getStringObjectMap(subject, options); subjectsData.add(subjectData); } - Map map = new HashMap<>(); - map.put("subjects", subjectsData); - map.put("width", imageWidth); - map.put("height", imageHeight); - result.success(map); - }).addOnFailureListener( e -> result.error("Subject segmentation failed!", e.getMessage(), e) ); + resultMap.put("subjects", subjectsData); + } + resultMap.put("width", imageWidth); + resultMap.put("height", imageHeight); + + result.success(resultMap); + } + + private void addForegroundBitmap(Map map, Bitmap bitmap) { + if(bitmap != null) { + map.put("bitmap", getBitmapBytes(bitmap)); + } + } + + private void addConfidenceMask(Map map, FloatBuffer mask) { + if(mask != null) { + map.put("confidences", getConfidences(mask)); + } + } + + private static float[] getConfidences(FloatBuffer floatBuffer) { + float[] confidences = new float[floatBuffer.remaining()]; + floatBuffer.get(confidences); + return confidences; + } + + private static byte[] getBitmapBytes(Bitmap bitmap) { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream); + return outputStream.toByteArray(); } + @NonNull - private static Map getStringObjectMap(Subject subject) { + private static Map getStringObjectMap(Subject subject, Map options) { Map subjectData = new HashMap<>(); subjectData.put("startX", subject.getStartX()); subjectData.put("startY", subject.getStartY()); subjectData.put("width", subject.getWidth()); subjectData.put("height", subject.getHeight()); - - FloatBuffer confidenceMask = subject.getConfidenceMask(); - assert confidenceMask != null; - float[] confidences = new float[confidenceMask.remaining()]; - confidenceMask.get(confidences); - subjectData.put("confidences", confidences); + if(Boolean.TRUE.equals(options.get("enableMultiConfidenceMask"))){ + subjectData.put("confidences", getConfidences(Objects.requireNonNull(subject.getConfidenceMask()))); + } + if(Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap"))) { + subjectData.put("bitmap", getBitmapBytes(Objects.requireNonNull(subject.getBitmap()))); + } return subjectData; } diff --git a/packages/google_mlkit_subject_segmentation/lib/src/subject_segmenter.dart b/packages/google_mlkit_subject_segmentation/lib/src/subject_segmenter.dart index c5e6e534..07ca2d94 100644 --- a/packages/google_mlkit_subject_segmentation/lib/src/subject_segmenter.dart +++ b/packages/google_mlkit_subject_segmentation/lib/src/subject_segmenter.dart @@ -1,6 +1,10 @@ +// ignore_for_file: unnecessary_lambdas + import 'package:flutter/services.dart'; import 'package:google_mlkit_commons/google_mlkit_commons.dart'; +import '../google_mlkit_subject_segmentation.dart'; + /// A detector that performs segmentation on a given [InputImage]. class SubjectSegmenter { /// A platform channel used to communicate with native code for segmentation @@ -10,15 +14,22 @@ class SubjectSegmenter { /// A unique identifier for the segmentation session, generated using the current timestamp final id = DateTime.now().microsecondsSinceEpoch.toString(); + /// The options for the subject segmenter + final SubjectSegmenterOptions options; + + /// Constructor to create an instance of [SubjectSegmention]. + SubjectSegmenter({required this.options}); + /// Processes the given [InputImage] for segmentation. /// /// Sends the [InputImage] data to the natvie platform via the method channel - /// Returns the segmentation mask in the given image or nil if there was an error. + /// Returns the segmentation mask in the given image. Future processImage(InputImage inputImage) async { final results = await _channel .invokeMethod('vision#startSubjectSegmenter', { 'id': id, 'imageData': inputImage.toJson(), + 'options': options.toJson(), }); // Convert the JSON response from the platform into a SubjectSegmenterMask instance. final SubjectSegmenterMask masks = SubjectSegmenterMask.fromJson(results); @@ -33,6 +44,47 @@ class SubjectSegmenter { _channel.invokeMethod('vision#closeSubjectSegmenter', {'id': id}); } +/// Immutable options for configuring features of [SubjectSegmention]. +/// +/// Used to configure features such as foreground confidence mask, foreground bitmap, multi confidence mask +/// or multi subject bitmap +class SubjectSegmenterOptions { + /// Constructor for [SubjectSegmenterOptions]. + /// + /// The parameter to enable options + /// NOTE: To improve memory efficiency, it is recommended to only enable the necessary options. + SubjectSegmenterOptions({ + this.enableForegroundConfidenceMask = true, + this.enableForegroundBitmap = false, + this.enableMultiConfidenceMask = false, + this.enableMultiSubjectBitmap = false, + }); + + /// + /// Enables foreground confidence mask. + final bool enableForegroundConfidenceMask; + + /// + /// Enables foreground bitmap + final bool enableForegroundBitmap; + + /// + /// Enables confidence mask for segmented Subjects + final bool enableMultiConfidenceMask; + + /// + /// Enables subject bitmap for segmented Subjects. + final bool enableMultiSubjectBitmap; + + /// Returns a json representation of an instance of [SubjectSegmenterOptions]. + Map toJson() => { + 'enableForegroundConfidenceMask': enableForegroundConfidenceMask, + 'enableForegroundBitmap': enableForegroundBitmap, + 'enableMultiConfidenceMask': enableMultiConfidenceMask, + 'enableMultiSubjectBitmap': enableMultiSubjectBitmap, + }; +} + /// A data class that represents the segmentation mask returned by the [SubjectSegmenterMask] class SubjectSegmenterMask { /// The width of the segmentation mask @@ -41,27 +93,37 @@ class SubjectSegmenterMask { /// The height of the segmentation mask final int height; + /// The masked bitmap for the input image + final Uint8List? bitmap; + + /// A list of forground confidence mask for the input image + final List? confidences; + /// A list of subjects detected in the image, each respresented by a [Subject] instance - final List subjects; + final List? subjects; /// Constructor to create a instance of [SubjectSegmenterMask]. - /// - /// The [width] and [height] represent the dimensions of the mark, - /// and [subjects] is a list of detected subjects SubjectSegmenterMask({ required this.width, required this.height, - required this.subjects, + this.subjects, + this.bitmap, + this.confidences, }); /// Returns an instance of [SubjectSegmenterMask] from json factory SubjectSegmenterMask.fromJson(Map json) { - final List> list = json['subjects']; - final List subjects = list.map(Subject.fromJson).toList(); + List? subjects; + if (json['subjects'] != null) { + subjects = + json['subjects'].map((json) => Subject.fromJson(json)).toList(); + } return SubjectSegmenterMask( width: json['width'] as int, height: json['height'] as int, subjects: subjects, + confidences: json['confidences'], + bitmap: json['bitmap'], ); } } @@ -81,25 +143,31 @@ class Subject { final int subjectHeight; /// A list of confidence values for the detected subject. - final List confidences; - - Subject( - {required this.startX, - required this.startY, - required this.subjectWidth, - required this.subjectHeight, - required this.confidences}); + final List? confidences; + + /// The masked bitmap of the subject + final Uint8List? bitmap; + + Subject({ + required this.startX, + required this.startY, + required this.subjectWidth, + required this.subjectHeight, + this.confidences, + this.bitmap, + }); /// Creates an instance of [Subject] from a JSON object. /// /// This factory constructor is used to convert JSON data into a [Subject] object. - factory Subject.fromJson(Map json) { return Subject( - startX: json['startX'] as int, - startY: json['startY'] as int, - subjectWidth: json['width'] as int, - subjectHeight: json['height'] as int, - confidences: json['confidences']); + startX: json['startX'] as int, + startY: json['startY'] as int, + subjectWidth: json['width'] as int, + subjectHeight: json['height'] as int, + confidences: json['confidences'], + bitmap: json['bitmap'], + ); } }