Skip to content

Commit e34ce38

Browse files
authored
feat: Subject Segmentation - added SubjectSegmenterOptions (#685)
1 parent 9a51c6e commit e34ce38

File tree

5 files changed

+202
-63
lines changed

5 files changed

+202
-63
lines changed

packages/example/lib/vision_detector_views/painters/subject_segmentation_painter.dart

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class SubjectSegmentationPainter extends CustomPainter {
2121
void paint(Canvas canvas, Size size) {
2222
final int width = mask.width;
2323
final int height = mask.height;
24-
final List<Subject> subjects = mask.subjects;
24+
final List<Subject> subjects = mask.subjects ?? [];
2525

2626
final paint = Paint()..style = PaintingStyle.fill;
2727

@@ -30,7 +30,7 @@ class SubjectSegmentationPainter extends CustomPainter {
3030
final int startY = subject.startY;
3131
final int subjectWidth = subject.subjectWidth;
3232
final int subjectHeight = subject.subjectHeight;
33-
final List<double> confidences = subject.confidences;
33+
final List<double> confidences = subject.confidences ?? [];
3434

3535
for (int y = 0; y < subjectHeight; y++) {
3636
for (int x = 0; y < subjectWidth; x++) {

packages/example/lib/vision_detector_views/subject_segmenter_view.dart

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class SubjectSegmenterView extends StatefulWidget {
1111
}
1212

1313
class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
14-
final SubjectSegmenter _segmenter = SubjectSegmenter();
14+
final SubjectSegmenter _segmenter = SubjectSegmenter(
15+
options: SubjectSegmenterOptions(enableForegroundConfidenceMask: true));
1516
bool _canProcess = true;
1617
bool _isBusy = false;
1718
CustomPaint? _customPaint;
@@ -56,8 +57,7 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
5657
_customPaint = CustomPaint(painter: painter);
5758
} else {
5859
// TODO: set _customPaint to draw on top of image
59-
_text = 'There is a mask with ${mask.subjects.length} subjects';
60-
60+
_text = 'There is a mask with ${mask.subjects?.length} subjects';
6161
_customPaint = null;
6262
}
6363
_isBusy = false;

packages/google_mlkit_subject_segmentation/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ final InputImage inputImage;
8181
#### Create an instance of `SubjectSegmenter`
8282

8383
```dart
84-
final segmenter = SubjectSegmenter();
84+
final options = SubjectSegmenterOptions();
85+
final segmenter = SubjectSegmenter(options: options);
8586
```
8687

8788
#### Process image

packages/google_mlkit_subject_segmentation/android/src/main/java/com/google_mlkit_subject_segmentation/SubjectSegmenterProcess.java

Lines changed: 105 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
11
package com.google_mlkit_subject_segmentation;
22

33
import android.content.Context;
4+
import android.graphics.Bitmap;
45

56
import androidx.annotation.NonNull;
67

78
import com.google.mlkit.vision.common.InputImage;
89
import com.google.mlkit.vision.segmentation.subject.Subject;
910
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation;
11+
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentationResult;
1012
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenter;
1113

14+
import java.io.ByteArrayInputStream;
15+
import java.io.ByteArrayOutputStream;
16+
import java.io.IOException;
17+
import java.io.OutputStream;
18+
import java.lang.reflect.Method;
1219
import java.util.ArrayList;
1320
import java.util.List;
1421
import java.nio.FloatBuffer;
1522
import java.util.HashMap;
1623
import java.util.Map;
24+
import java.util.Objects;
1725

1826
import io.flutter.Log;
1927
import io.flutter.plugin.common.MethodCall;
@@ -27,8 +35,6 @@ public class SubjectSegmenterProcess implements MethodChannel.MethodCallHandler
2735
private static final String CLOSE = "vision#closeSubjectSegmenter";
2836

2937
private final Context context;
30-
31-
private static final String TAG = "Logger";
3238

3339
private int imageWidth;
3440
private int imageHeight;
@@ -55,55 +61,119 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result
5561
}
5662
}
5763

58-
private SubjectSegmenter initialize(MethodCall call) {
59-
SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder()
60-
.enableMultipleSubjects(new SubjectSegmenterOptions.SubjectResultOptions.Builder()
61-
.enableConfidenceMask().build());
62-
SubjectSegmenterOptions options = builder.build();
63-
return SubjectSegmentation.getClient(options);
64-
}
65-
66-
private void handleDetection(MethodCall call, MethodChannel.Result result){
67-
Map<String, Object> imageData = (Map<String, Object>) call.argument("imageData");
68-
InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result);
69-
if (inputImage == null) return;
64+
private void handleDetection(MethodCall call, MethodChannel.Result result) {
65+
InputImage inputImage = InputImageConverter.getInputImageFromData(call.argument("imageData"), context, result);
66+
if(inputImage == null) return;
7067
imageHeight = inputImage.getHeight();
7168
imageWidth = inputImage.getWidth();
69+
7270
String id = call.argument("id");
73-
SubjectSegmenter subjectSegmenter = instances.get(id);
74-
if (subjectSegmenter == null) {
75-
subjectSegmenter = initialize(call);
76-
instances.put(id, subjectSegmenter);
71+
SubjectSegmenter subjectSegmenter = getOrCreateSegmenter(id, call);
72+
73+
subjectSegmenter.process(inputImage)
74+
.addOnSuccessListener(subjectSegmentationResult -> processResult(subjectSegmentationResult, call, result))
75+
.addOnFailureListener(e -> result.error("Subject segmentation failure!", e.getMessage(), e));
76+
77+
}
78+
79+
private SubjectSegmenter getOrCreateSegmenter(String id, MethodCall call) {
80+
return instances.computeIfAbsent(id, k -> initialize(call));
81+
}
82+
private SubjectSegmenter initialize(MethodCall call) {
83+
Map<String, Object> options = call.argument("options");
84+
SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder();
85+
assert options != null;
86+
configureBuilder(builder, options);
87+
return SubjectSegmentation.getClient(builder.build());
88+
}
89+
90+
private void configureBuilder(SubjectSegmenterOptions.Builder builder, Map<String, Object> options) {
91+
if(Boolean.TRUE.equals(options.get("enableForegroundBitmap"))){
92+
builder.enableForegroundBitmap();
93+
}
94+
if(Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))){
95+
builder.enableForegroundConfidenceMask();
7796
}
97+
configureMultipleSubjects(builder, options);
98+
}
99+
100+
private void configureMultipleSubjects(SubjectSegmenterOptions.Builder builder, Map<String, Object> options) {
101+
boolean enableMultiConfidenceMask = Boolean.TRUE.equals(options.get("enableMultiConfidenceMask")) ;
102+
boolean enableMultiSubjectBitmap = Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap"));
103+
104+
if(enableMultiConfidenceMask || enableMultiSubjectBitmap) {
105+
SubjectSegmenterOptions.SubjectResultOptions.Builder subjectBuilder = new SubjectSegmenterOptions.SubjectResultOptions.Builder();
106+
if(enableMultiConfidenceMask) subjectBuilder.enableConfidenceMask();
107+
if(enableMultiSubjectBitmap) subjectBuilder.enableSubjectBitmap();
108+
builder.enableMultipleSubjects(subjectBuilder.build());
109+
}
110+
}
111+
112+
private void processResult(SubjectSegmentationResult subjectSegmentationResult, MethodCall call, MethodChannel.Result result) {
113+
Map<String, Object> resultMap = new HashMap<>();
114+
Map<String, Object> options = call.argument("options");
115+
116+
assert options != null;
117+
if(Boolean.TRUE.equals(options.get("enableForegroundBitmap"))) {
118+
addForegroundBitmap(resultMap, subjectSegmentationResult.getForegroundBitmap());
119+
}
120+
121+
if(Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))){
122+
addConfidenceMask(resultMap, subjectSegmentationResult.getForegroundConfidenceMask());
123+
}
124+
if(Boolean.TRUE.equals(options.get("enableMultiConfidenceMask")) || Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap"))) {
78125

79-
subjectSegmenter.process(inputImage)
80-
.addOnSuccessListener( subjectSegmentationResult -> {
81126
List<Map<String, Object>> subjectsData = new ArrayList<>();
82-
for(Subject subject : subjectSegmentationResult.getSubjects()){
83-
Map<String, Object> subjectData = getStringObjectMap(subject);
127+
for(Subject subject: subjectSegmentationResult.getSubjects()){
128+
Map<String, Object> subjectData = getStringObjectMap(subject, options);
84129
subjectsData.add(subjectData);
85130
}
86-
Map<String, Object> map = new HashMap<>();
87-
map.put("subjects", subjectsData);
88-
map.put("width", imageWidth);
89-
map.put("height", imageHeight);
90-
result.success(map);
91-
}).addOnFailureListener( e -> result.error("Subject segmentation failed!", e.getMessage(), e) );
131+
resultMap.put("subjects", subjectsData);
132+
}
133+
resultMap.put("width", imageWidth);
134+
resultMap.put("height", imageHeight);
135+
136+
result.success(resultMap);
137+
}
138+
139+
private void addForegroundBitmap(Map<String, Object> map, Bitmap bitmap) {
140+
if(bitmap != null) {
141+
map.put("bitmap", getBitmapBytes(bitmap));
142+
}
143+
}
144+
145+
private void addConfidenceMask(Map<String, Object> map, FloatBuffer mask) {
146+
if(mask != null) {
147+
map.put("confidences", getConfidences(mask));
148+
}
149+
}
150+
151+
private static float[] getConfidences(FloatBuffer floatBuffer) {
152+
float[] confidences = new float[floatBuffer.remaining()];
153+
floatBuffer.get(confidences);
154+
return confidences;
155+
}
156+
157+
private static byte[] getBitmapBytes(Bitmap bitmap) {
158+
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
159+
bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream);
160+
return outputStream.toByteArray();
92161
}
93162

163+
94164
@NonNull
95-
private static Map<String, Object> getStringObjectMap(Subject subject) {
165+
private static Map<String, Object> getStringObjectMap(Subject subject, Map<String, Object> options) {
96166
Map<String, Object> subjectData = new HashMap<>();
97167
subjectData.put("startX", subject.getStartX());
98168
subjectData.put("startY", subject.getStartY());
99169
subjectData.put("width", subject.getWidth());
100170
subjectData.put("height", subject.getHeight());
101-
102-
FloatBuffer confidenceMask = subject.getConfidenceMask();
103-
assert confidenceMask != null;
104-
float[] confidences = new float[confidenceMask.remaining()];
105-
confidenceMask.get(confidences);
106-
subjectData.put("confidences", confidences);
171+
if(Boolean.TRUE.equals(options.get("enableMultiConfidenceMask"))){
172+
subjectData.put("confidences", getConfidences(Objects.requireNonNull(subject.getConfidenceMask())));
173+
}
174+
if(Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap"))) {
175+
subjectData.put("bitmap", getBitmapBytes(Objects.requireNonNull(subject.getBitmap())));
176+
}
107177
return subjectData;
108178
}
109179

0 commit comments

Comments
 (0)