Skip to content

Add imagen editing cases to the Firebase AI quickstart #2702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
package com.google.firebase.quickstart.ai

import android.content.Context
import android.content.res.Resources
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import com.google.firebase.ai.ImagenModel
import com.google.firebase.ai.type.Dimensions
import com.google.firebase.ai.type.FunctionDeclaration
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.ImagenBackgroundMask
import com.google.firebase.ai.type.ImagenEditMode
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenMaskReference
import com.google.firebase.ai.type.ImagenRawImage
import com.google.firebase.ai.type.ImagenStyleReference
import com.google.firebase.ai.type.ImagenSubjectReference
import com.google.firebase.ai.type.ImagenSubjectReferenceType
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.ResponseModality
import com.google.firebase.ai.type.Schema
import com.google.firebase.ai.type.Tool
import com.google.firebase.ai.type.content
import com.google.firebase.ai.type.generationConfig
import com.google.firebase.ai.type.toImagenInlineImage
import com.google.firebase.quickstart.ai.ui.navigation.Category
import com.google.firebase.quickstart.ai.ui.navigation.Sample

@OptIn(PublicPreviewAPI::class)
val FIREBASE_AI_SAMPLES = listOf(
Sample(
title = "Travel tips",
Expand Down Expand Up @@ -131,6 +148,93 @@ val FIREBASE_AI_SAMPLES = listOf(
text(
"A photo of a modern building with water in the background"
)
},
allowEmptyPrompt = false,
generateImages = { model: ImagenModel, inputText: String, _: Bitmap? ->
model.generateImages(
inputText
)
}
),
// Inpainting sample: pairs the attached image with a background mask and asks the
// capability model to run an INPAINT_INSERTION edit driven by the prompt text.
Sample(
    title = "Imagen 3 - Inpainting (Vertex AI)",
    description = "Replace the background of an image using Imagen 3",
    modelName = "imagen-3.0-capability-001",
    backend = GenerativeBackend.vertexAI(),
    navRoute = "imagen",
    categories = listOf(Category.IMAGE),
    initialPrompt = content { text("A sunny beach") },
    includeAttach = true,
    allowEmptyPrompt = true,
    generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
        // The bitmap is nullable because the user may not have attached anything yet;
        // fail with a readable message instead of a bare NullPointerException.
        val sourceImage = requireNotNull(bitmap) {
            "Attach an image before running the inpainting sample."
        }
        model.editImage(
            listOf(ImagenRawImage(sourceImage.toImagenInlineImage()), ImagenBackgroundMask()),
            inputText,
            ImagenEditingConfig(ImagenEditMode.INPAINT_INSERTION)
        )
    }
),
// Outpainting sample: doubles the attached image's width and height and lets the
// model fill in the padded area using an OUTPAINT edit.
Sample(
    title = "Imagen 3 - Outpainting (Vertex AI)",
    description = "Expand an image by drawing in more background",
    modelName = "imagen-3.0-capability-001",
    backend = GenerativeBackend.vertexAI(),
    navRoute = "imagen",
    categories = listOf(Category.IMAGE),
    initialPrompt = content { text("") },
    includeAttach = true,
    allowEmptyPrompt = true,
    generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
        // The bitmap is nullable because the user may not have attached anything yet;
        // fail with a readable message instead of a bare NullPointerException.
        val sourceImage = requireNotNull(bitmap) {
            "Attach an image before running the outpainting sample."
        }
        // Target canvas is twice the source size in each dimension.
        val dimensions = Dimensions(sourceImage.width * 2, sourceImage.height * 2)
        model.editImage(
            ImagenMaskReference.generateMaskAndPadForOutpainting(
                sourceImage.toImagenInlineImage(),
                dimensions
            ),
            inputText,
            ImagenEditingConfig(ImagenEditMode.OUTPAINT)
        )
    }
),
// Subject-reference sample: registers the attached image as subject [1] (typed as an
// animal) and rewrites the "<subject>" placeholder in the prompt to point at it.
Sample(
    title = "Imagen 3 - Subject Reference (Vertex AI)",
    description = "generate an image using a referenced subject (must be an animal)",
    modelName = "imagen-3.0-capability-001",
    backend = GenerativeBackend.vertexAI(),
    navRoute = "imagen",
    categories = listOf(Category.IMAGE),
    initialPrompt = content { text("<subject> flying through space") },
    includeAttach = true,
    allowEmptyPrompt = false,
    generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
        // The bitmap is nullable because the user may not have attached anything yet;
        // fail with a readable message instead of a bare NullPointerException.
        val subjectImage = requireNotNull(bitmap) {
            "Attach an image of the subject before running the subject reference sample."
        }
        model.editImage(
            listOf(
                ImagenSubjectReference(
                    referenceId = 1,
                    image = subjectImage.toImagenInlineImage(),
                    subjectType = ImagenSubjectReferenceType.ANIMAL,
                    description = "An animal"
                )
            ),
            "Create an image about An animal [1] to match the description: " +
                inputText.replace("<subject>", "An animal [1]"),
        )
    }
),
// Style-transfer sample: uses the bundled cat bitmap (MainActivity.catImage) as the raw
// image and the attached image as style reference [1].
Sample(
    title = "Imagen 3 - Style Transfer (Vertex AI)",
    description = "Change the art style of an cat picture using a reference",
    modelName = "imagen-3.0-capability-001",
    backend = GenerativeBackend.vertexAI(),
    navRoute = "imagen",
    categories = listOf(Category.IMAGE),
    initialPrompt = content { text("A picture of a cat") },
    includeAttach = true,
    allowEmptyPrompt = true,
    generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
        // The bitmap is nullable because the user may not have attached anything yet;
        // fail with a readable message instead of a bare NullPointerException.
        val styleImage = requireNotNull(bitmap) {
            "Attach a style reference image before running the style transfer sample."
        }
        model.editImage(
            listOf(
                ImagenRawImage(MainActivity.catImage.toImagenInlineImage()),
                ImagenStyleReference(styleImage.toImagenInlineImage(), 1, "an art style")
            ),
            "Generate an image in an art style [1] based on the following caption: $inputText",
        )
    }
),
Sample(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.google.firebase.quickstart.ai

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.os.Bundle
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
Expand All @@ -22,6 +24,7 @@ import androidx.navigation.NavDestination
import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.compose.rememberNavController
import com.google.firebase.ai.type.toImagenInlineImage
import com.google.firebase.quickstart.ai.feature.live.StreamRealtimeRoute
import com.google.firebase.quickstart.ai.feature.live.StreamRealtimeScreen
import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenRoute
Expand All @@ -36,6 +39,7 @@ class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
enableEdgeToEdge()
catImage = BitmapFactory.decodeResource(applicationContext.resources, R.drawable.cat)
setContent {
val navController = rememberNavController()

Expand Down Expand Up @@ -110,4 +114,7 @@ class MainActivity : ComponentActivity() {
})
}
}
companion object {
    /**
     * Bitmap decoded from `R.drawable.cat` in [onCreate]. It is `lateinit`, so reading it
     * before the activity has been created throws.
     */
    lateinit var catImage: Bitmap
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package com.google.firebase.quickstart.ai.feature.media.imagen

import android.net.Uri
import android.provider.OpenableColumns
import android.text.format.Formatter
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
Expand All @@ -19,14 +24,20 @@ import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.saveable.rememberSaveable
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.firebase.quickstart.ai.feature.text.Attachment
import com.google.firebase.quickstart.ai.feature.text.AttachmentsList
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable

@Serializable
Expand All @@ -40,6 +51,34 @@ fun ImagenScreen(
val errorMessage by imagenViewModel.errorMessage.collectAsStateWithLifecycle()
val isLoading by imagenViewModel.isLoading.collectAsStateWithLifecycle()
val generatedImages by imagenViewModel.generatedBitmaps.collectAsStateWithLifecycle()
val includeAttach by imagenViewModel.includeAttach.collectAsStateWithLifecycle()
val allowEmptyPrompt by imagenViewModel.allowEmptyPrompt.collectAsStateWithLifecycle()
val attachedImage by imagenViewModel.attachedImage.collectAsStateWithLifecycle()
val context = LocalContext.current
val contentResolver = context.contentResolver
val scope = rememberCoroutineScope()
// Launcher for the system document picker. When the user picks a file, its bytes are
// read and handed to the view model as the attached image; a cancelled picker (null
// Uri) is ignored.
//
// The display-name/size lookup that used to precede the read was removed: its result
// was never used, and it read from the cursor without checking moveToFirst() or the
// column indices, so it could throw for providers that return an empty cursor or omit
// those columns.
val openDocument = rememberLauncherForActivityResult(ActivityResultContracts.OpenDocument()) { optionalUri: Uri? ->
    optionalUri?.let { uri ->
        contentResolver.openInputStream(uri)?.use { stream ->
            val bytes = stream.readBytes()
            // attachImage is a suspend function, so it is called from the composition-scoped coroutine.
            scope.launch {
                imagenViewModel.attachImage(bytes)
            }
        }
    }
}

Column(
modifier = Modifier
Expand All @@ -59,9 +98,22 @@ fun ImagenScreen(
.padding(16.dp)
.fillMaxWidth()
)
// Attachment controls are only rendered for samples that opt in via includeAttach.
if (includeAttach) {
// Show the currently attached image, if any, as a single unnamed attachment.
if (attachedImage != null) {
AttachmentsList(listOf(Attachment("", attachedImage)))
}
// "Attach" opens the document picker restricted to image MIME types.
TextButton(
onClick = {
openDocument.launch(arrayOf("image/*"))
},
modifier = Modifier
.padding(end = 16.dp, bottom = 16.dp)
.align(Alignment.End)
) { Text("Attach") }
}
TextButton(
onClick = {
if (imagenPrompt.isNotBlank()) {
if (allowEmptyPrompt || imagenPrompt.isNotBlank()) {
imagenViewModel.generateImages(imagenPrompt)
}
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package com.google.firebase.quickstart.ai.feature.media.imagen

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import androidx.lifecycle.SavedStateHandle
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import androidx.navigation.toRoute
import com.google.firebase.Firebase
import com.google.firebase.ai.ImagenModel
import com.google.firebase.ai.ai
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.ImagenAspectRatio
import com.google.firebase.ai.type.ImagenEditMode
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenImageFormat
import com.google.firebase.ai.type.ImagenPersonFilterLevel
import com.google.firebase.ai.type.ImagenSafetyFilterLevel
Expand All @@ -20,6 +22,7 @@ import com.google.firebase.ai.type.imagenGenerationConfig
import com.google.firebase.quickstart.ai.FIREBASE_AI_SAMPLES
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch

@OptIn(PublicPreviewAPI::class)
Expand All @@ -36,6 +39,15 @@ class ImagenViewModel(
private val _isLoading = MutableStateFlow(false)
val isLoading: StateFlow<Boolean> = _isLoading

// Whether the screen should offer an "Attach" button; seeded from the sample definition.
private val _includeAttach = MutableStateFlow(sample.includeAttach)
val includeAttach: StateFlow<Boolean>
    get() = _includeAttach

// Whether image generation may be triggered with a blank prompt; seeded from the sample definition.
private val _allowEmptyPrompt = MutableStateFlow(sample.allowEmptyPrompt)
val allowEmptyPrompt: StateFlow<Boolean>
    get() = _allowEmptyPrompt

// The image the user has attached, or null until attachImage() has been called.
private val _attachedImage = MutableStateFlow<Bitmap?>(null)
val attachedImage: StateFlow<Bitmap?>
    get() = _attachedImage

private val _generatedBitmaps = MutableStateFlow(listOf<Bitmap>())
val generatedBitmaps: StateFlow<List<Bitmap>> = _generatedBitmaps

Expand All @@ -45,15 +57,14 @@ class ImagenViewModel(
init {
val config = imagenGenerationConfig {
numberOfImages = 4
aspectRatio = ImagenAspectRatio.SQUARE_1x1
imageFormat = ImagenImageFormat.png()
}
val settings = ImagenSafetySettings(
safetyFilterLevel = ImagenSafetyFilterLevel.BLOCK_LOW_AND_ABOVE,
personFilterLevel = ImagenPersonFilterLevel.BLOCK_ALL
)
imagenModel = Firebase.ai(
backend = GenerativeBackend.googleAI()
backend = sample.backend
).imagenModel(
modelName = sample.modelName ?: "imagen-3.0-generate-002",
generationConfig = config,
Expand All @@ -65,9 +76,7 @@ class ImagenViewModel(
viewModelScope.launch {
_isLoading.value = true
try {
val imageResponse = imagenModel.generateImages(
inputText
)
val imageResponse = sample.generateImages!!(imagenModel, inputText, attachedImage.first())
_generatedBitmaps.value = imageResponse.images.map { it.asBitmap() }
_errorMessage.value = null // clear error message
} catch (e: Exception) {
Expand All @@ -77,4 +86,10 @@ class ImagenViewModel(
}
}
}

/**
 * Decodes [fileInBytes] into a [Bitmap] and publishes it as the attached image.
 *
 * NOTE(review): `BitmapFactory.decodeByteArray` is documented to return null when the
 * data cannot be decoded; in that case null is published here — confirm that callers
 * handle a missing attachment.
 *
 * @param fileInBytes the complete encoded image file contents.
 */
suspend fun attachImage(fileInBytes: ByteArray) {
    val decoded = BitmapFactory.decodeByteArray(fileInBytes, 0, fileInBytes.size)
    _attachedImage.emit(decoded)
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package com.google.firebase.quickstart.ai.ui.navigation

import android.content.Context
import android.graphics.Bitmap
import com.google.firebase.ai.ImagenModel
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.GenerationConfig
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.ImagenGenerationResponse
import com.google.firebase.ai.type.ImagenInlineImage
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.Tool
import java.util.UUID

Expand All @@ -17,6 +23,7 @@ enum class Category(
FUNCTION_CALLING("Function calling"),
}

@OptIn(PublicPreviewAPI::class)
data class Sample(
val id: String = UUID.randomUUID().toString(), // used for navigation
val title: String,
Expand All @@ -30,5 +37,8 @@ data class Sample(
val systemInstructions: Content? = null,
val generationConfig: GenerationConfig? = null,
val chatHistory: List<Content> = emptyList(),
val tools: List<Tool>? = null
val tools: List<Tool>? = null,
val includeAttach: Boolean = false,
val allowEmptyPrompt: Boolean = false,
val generateImages: (suspend (ImagenModel, String, Bitmap?) -> ImagenGenerationResponse<ImagenInlineImage>)? = null
)
Binary file added firebase-ai/app/src/main/res/drawable/cat.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading