diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/FirebaseAISamples.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/FirebaseAISamples.kt
index c61a12aa0a..c4cdd22793 100644
--- a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/FirebaseAISamples.kt
+++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/FirebaseAISamples.kt
@@ -1,15 +1,32 @@
 package com.google.firebase.quickstart.ai
 
+import android.content.Context
+import android.content.res.Resources
+import android.graphics.Bitmap
+import android.graphics.BitmapFactory
+import com.google.firebase.ai.ImagenModel
+import com.google.firebase.ai.type.Dimensions
 import com.google.firebase.ai.type.FunctionDeclaration
 import com.google.firebase.ai.type.GenerativeBackend
+import com.google.firebase.ai.type.ImagenBackgroundMask
+import com.google.firebase.ai.type.ImagenEditMode
+import com.google.firebase.ai.type.ImagenEditingConfig
+import com.google.firebase.ai.type.ImagenMaskReference
+import com.google.firebase.ai.type.ImagenRawImage
+import com.google.firebase.ai.type.ImagenStyleReference
+import com.google.firebase.ai.type.ImagenSubjectReference
+import com.google.firebase.ai.type.ImagenSubjectReferenceType
+import com.google.firebase.ai.type.PublicPreviewAPI
 import com.google.firebase.ai.type.ResponseModality
 import com.google.firebase.ai.type.Schema
 import com.google.firebase.ai.type.Tool
 import com.google.firebase.ai.type.content
 import com.google.firebase.ai.type.generationConfig
+import com.google.firebase.ai.type.toImagenInlineImage
 import com.google.firebase.quickstart.ai.ui.navigation.Category
 import com.google.firebase.quickstart.ai.ui.navigation.Sample
 
+@OptIn(PublicPreviewAPI::class)
 val FIREBASE_AI_SAMPLES = listOf(
     Sample(
         title = "Travel tips",
@@ -131,6 +148,93 @@ val FIREBASE_AI_SAMPLES = listOf(
             text(
                 "A photo of a modern building with water in the background"
             )
+        },
+        allowEmptyPrompt = false,
+        generateImages = { model: ImagenModel, inputText: String, _: Bitmap? ->
+            model.generateImages(
+                inputText
+            )
+        }
+    ),
+    Sample(
+        title = "Imagen 3 - Inpainting (Vertex AI)",
+        description = "Replace the background of an image using Imagen 3",
+        modelName = "imagen-3.0-capability-001",
+        backend = GenerativeBackend.vertexAI(),
+        navRoute = "imagen",
+        categories = listOf(Category.IMAGE),
+        initialPrompt = content { text("A sunny beach") },
+        includeAttach = true,
+        allowEmptyPrompt = true,
+        generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
+            model.editImage(
+                listOf(ImagenRawImage(bitmap!!.toImagenInlineImage()), ImagenBackgroundMask()),
+                inputText,
+                ImagenEditingConfig(ImagenEditMode.INPAINT_INSERTION)
+            )
+        }
+    ),
+    Sample(
+        title = "Imagen 3 - Outpainting (Vertex AI)",
+        description = "Expand an image by drawing in more background",
+        modelName = "imagen-3.0-capability-001",
+        backend = GenerativeBackend.vertexAI(),
+        navRoute = "imagen",
+        categories = listOf(Category.IMAGE),
+        initialPrompt = content { text("") },
+        includeAttach = true,
+        allowEmptyPrompt = true,
+        generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
+            val dimensions = Dimensions(bitmap!!.width * 2, bitmap.height * 2)
+            model.editImage(
+                ImagenMaskReference.generateMaskAndPadForOutpainting(bitmap.toImagenInlineImage(), dimensions),
+                inputText,
+                ImagenEditingConfig(ImagenEditMode.OUTPAINT)
+            )
+        }
+    ),
+    Sample(
+        title = "Imagen 3 - Subject Reference (Vertex AI)",
+        description = "Generate an image using a referenced subject (must be an animal)",
+        modelName = "imagen-3.0-capability-001",
+        backend = GenerativeBackend.vertexAI(),
+        navRoute = "imagen",
+        categories = listOf(Category.IMAGE),
+        initialPrompt = content { text(" flying through space") },
+        includeAttach = true,
+        allowEmptyPrompt = false,
+        generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
+            model.editImage(
+                listOf(
+                    ImagenSubjectReference(
+                        referenceId = 1,
+                        image = bitmap!!.toImagenInlineImage(),
+                        subjectType = ImagenSubjectReferenceType.ANIMAL,
+                        description = "An animal"
+                    )
+                ),
+                "Create an image about An animal [1] to match the description: " +
+                    inputText.replace("", "An animal [1]"),
+            )
+        }
+    ),
+    Sample(
+        title = "Imagen 3 - Style Transfer (Vertex AI)",
+        description = "Change the art style of a cat picture using a reference",
+        modelName = "imagen-3.0-capability-001",
+        backend = GenerativeBackend.vertexAI(),
+        navRoute = "imagen",
+        categories = listOf(Category.IMAGE),
+        initialPrompt = content { text("A picture of a cat") },
+        includeAttach = true,
+        allowEmptyPrompt = true,
+        generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
+            model.editImage(
+                listOf(
+                    ImagenRawImage(MainActivity.catImage.toImagenInlineImage()),
+                    ImagenStyleReference(bitmap!!.toImagenInlineImage(), 1, "an art style")),
+                "Generate an image in an art style [1] based on the following caption: $inputText",
+            )
         }
     ),
     Sample(
diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt
index 998e4612a7..a1508d0d2d 100644
--- a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt
+++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt
@@ -1,5 +1,7 @@
 package com.google.firebase.quickstart.ai
 
+import android.graphics.Bitmap
+import android.graphics.BitmapFactory
 import android.os.Bundle
 import androidx.activity.ComponentActivity
 import androidx.activity.compose.setContent
@@ -22,6 +24,7 @@ import androidx.navigation.NavDestination
 import androidx.navigation.compose.NavHost
 import androidx.navigation.compose.composable
 import androidx.navigation.compose.rememberNavController
+import com.google.firebase.ai.type.toImagenInlineImage
 import com.google.firebase.quickstart.ai.feature.live.StreamRealtimeRoute
 import com.google.firebase.quickstart.ai.feature.live.StreamRealtimeScreen
 import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenRoute
@@ -36,6 +39,7 @@ class MainActivity : ComponentActivity() {
     override fun onCreate(savedInstanceState: Bundle?) {
         super.onCreate(savedInstanceState)
         enableEdgeToEdge()
+        catImage = BitmapFactory.decodeResource(applicationContext.resources, R.drawable.cat)
 
         setContent {
             val navController = rememberNavController()
@@ -110,4 +114,7 @@ class MainActivity : ComponentActivity() {
             })
         }
     }
+    companion object {
+        lateinit var catImage: Bitmap
+    }
 }
diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/media/imagen/ImagenScreen.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/media/imagen/ImagenScreen.kt
index e5d654a895..bfe7896c05 100644
--- a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/media/imagen/ImagenScreen.kt
+++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/media/imagen/ImagenScreen.kt
@@ -1,5 +1,10 @@
 package com.google.firebase.quickstart.ai.feature.media.imagen
 
+import android.net.Uri
+import android.provider.OpenableColumns
+import android.text.format.Formatter
+import androidx.activity.compose.rememberLauncherForActivityResult
+import androidx.activity.result.contract.ActivityResultContracts
 import androidx.compose.foundation.Image
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
@@ -19,14 +24,20 @@ import androidx.compose.material3.TextButton
 import androidx.compose.runtime.Composable
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.runtime.saveable.rememberSaveable
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
 import androidx.compose.ui.graphics.asImageBitmap
+import androidx.compose.ui.platform.LocalContext
 import androidx.compose.ui.unit.dp
 import androidx.lifecycle.compose.collectAsStateWithLifecycle
 import androidx.lifecycle.viewmodel.compose.viewModel
+import com.google.firebase.quickstart.ai.feature.text.Attachment
+import com.google.firebase.quickstart.ai.feature.text.AttachmentsList
+import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.launch
 import kotlinx.serialization.Serializable
 
 @Serializable
@@ -40,6 +51,34 @@ fun ImagenScreen(
     val errorMessage by imagenViewModel.errorMessage.collectAsStateWithLifecycle()
     val isLoading by imagenViewModel.isLoading.collectAsStateWithLifecycle()
     val generatedImages by imagenViewModel.generatedBitmaps.collectAsStateWithLifecycle()
+    val includeAttach by imagenViewModel.includeAttach.collectAsStateWithLifecycle()
+    val allowEmptyPrompt by imagenViewModel.allowEmptyPrompt.collectAsStateWithLifecycle()
+    val attachedImage by imagenViewModel.attachedImage.collectAsStateWithLifecycle()
+    val context = LocalContext.current
+    val contentResolver = context.contentResolver
+    val scope = rememberCoroutineScope()
+    val openDocument = rememberLauncherForActivityResult(ActivityResultContracts.OpenDocument()) { optionalUri: Uri? ->
+        optionalUri?.let { uri ->
+            var fileName: String? = null
+            // Fetch file name and size
+            contentResolver.query(uri, null, null, null, null)?.use { cursor ->
+                val nameIndex = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
+                val sizeIndex = cursor.getColumnIndex(OpenableColumns.SIZE)
+                cursor.moveToFirst()
+                val humanReadableSize = Formatter.formatShortFileSize(
+                    context, cursor.getLong(sizeIndex)
+                )
+                fileName = "${cursor.getString(nameIndex)} ($humanReadableSize)"
+            }
+
+            contentResolver.openInputStream(uri)?.use { stream ->
+                val bytes = stream.readBytes()
+                scope.launch {
+                    imagenViewModel.attachImage(bytes)
+                }
+            }
+        }
+    }
 
     Column(
         modifier = Modifier
@@ -59,9 +98,22 @@ fun ImagenScreen(
                 .padding(16.dp)
                 .fillMaxWidth()
         )
+        if (includeAttach) {
+            if (attachedImage != null) {
+                AttachmentsList(listOf(Attachment("", attachedImage)))
+            }
+            TextButton(
+                onClick = {
+                    openDocument.launch(arrayOf("image/*"))
+                },
+                modifier = Modifier
+                    .padding(end = 16.dp, bottom = 16.dp)
+                    .align(Alignment.End)
+            ) { Text("Attach") }
+        }
         TextButton(
             onClick = {
-                if (imagenPrompt.isNotBlank()) {
+                if (allowEmptyPrompt || imagenPrompt.isNotBlank()) {
                     imagenViewModel.generateImages(imagenPrompt)
                 }
             },
diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/media/imagen/ImagenViewModel.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/media/imagen/ImagenViewModel.kt
index bd1b58b018..9ca82bdaec 100644
--- a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/media/imagen/ImagenViewModel.kt
+++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/media/imagen/ImagenViewModel.kt
@@ -1,6 +1,7 @@
 package com.google.firebase.quickstart.ai.feature.media.imagen
 
 import android.graphics.Bitmap
+import android.graphics.BitmapFactory
 import androidx.lifecycle.SavedStateHandle
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
@@ -8,8 +9,9 @@ import androidx.navigation.toRoute
 import com.google.firebase.Firebase
 import com.google.firebase.ai.ImagenModel
 import com.google.firebase.ai.ai
-import com.google.firebase.ai.type.GenerativeBackend
 import com.google.firebase.ai.type.ImagenAspectRatio
+import com.google.firebase.ai.type.ImagenEditMode
+import com.google.firebase.ai.type.ImagenEditingConfig
 import com.google.firebase.ai.type.ImagenImageFormat
 import com.google.firebase.ai.type.ImagenPersonFilterLevel
 import com.google.firebase.ai.type.ImagenSafetyFilterLevel
@@ -20,6 +22,7 @@ import com.google.firebase.ai.type.imagenGenerationConfig
 import com.google.firebase.quickstart.ai.FIREBASE_AI_SAMPLES
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.first
 import kotlinx.coroutines.launch
 
 @OptIn(PublicPreviewAPI::class)
@@ -36,6 +39,15 @@ class ImagenViewModel(
     private val _isLoading = MutableStateFlow(false)
     val isLoading: StateFlow<Boolean> = _isLoading
 
+    private val _includeAttach = MutableStateFlow(sample.includeAttach)
+    val includeAttach: StateFlow<Boolean> = _includeAttach
+
+    private val _allowEmptyPrompt = MutableStateFlow(sample.allowEmptyPrompt)
+    val allowEmptyPrompt: StateFlow<Boolean> = _allowEmptyPrompt
+
+    private val _attachedImage = MutableStateFlow<Bitmap?>(null)
+    val attachedImage: StateFlow<Bitmap?> = _attachedImage
+
     private val _generatedBitmaps = MutableStateFlow<List<Bitmap>>(listOf())
     val generatedBitmaps: StateFlow<List<Bitmap>> = _generatedBitmaps
 
@@ -45,7 +57,6 @@ class ImagenViewModel(
     init {
         val config = imagenGenerationConfig {
            numberOfImages = 4
-           aspectRatio = ImagenAspectRatio.SQUARE_1x1
            imageFormat = ImagenImageFormat.png()
        }
        val settings = ImagenSafetySettings(
@@ -53,7 +64,7 @@ class ImagenViewModel(
            personFilterLevel = ImagenPersonFilterLevel.BLOCK_ALL
        )
        imagenModel = Firebase.ai(
-           backend = GenerativeBackend.googleAI()
+           backend = sample.backend
        ).imagenModel(
            modelName = sample.modelName ?: "imagen-3.0-generate-002",
            generationConfig = config,
@@ -65,9 +76,7 @@ class ImagenViewModel(
         viewModelScope.launch {
             _isLoading.value = true
             try {
-                val imageResponse = imagenModel.generateImages(
-                    inputText
-                )
+                val imageResponse = sample.generateImages!!(imagenModel, inputText, attachedImage.first())
                 _generatedBitmaps.value = imageResponse.images.map { it.asBitmap() }
                 _errorMessage.value = null // clear error message
             } catch (e: Exception) {
@@ -77,4 +86,10 @@ class ImagenViewModel(
             }
         }
     }
+
+    suspend fun attachImage(
+        fileInBytes: ByteArray,
+    ) {
+        _attachedImage.emit(BitmapFactory.decodeByteArray(fileInBytes, 0, fileInBytes.size))
+    }
 }
diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/Sample.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/Sample.kt
index 903f6a1114..1b7e76052b 100644
--- a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/Sample.kt
+++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/Sample.kt
@@ -1,8 +1,14 @@
 package com.google.firebase.quickstart.ai.ui.navigation
 
+import android.content.Context
+import android.graphics.Bitmap
+import com.google.firebase.ai.ImagenModel
 import com.google.firebase.ai.type.Content
 import com.google.firebase.ai.type.GenerationConfig
 import com.google.firebase.ai.type.GenerativeBackend
+import com.google.firebase.ai.type.ImagenGenerationResponse
+import com.google.firebase.ai.type.ImagenInlineImage
+import com.google.firebase.ai.type.PublicPreviewAPI
 import com.google.firebase.ai.type.Tool
 import java.util.UUID
 
@@ -17,6 +23,7 @@ enum class Category(
     FUNCTION_CALLING("Function calling"),
 }
 
+@OptIn(PublicPreviewAPI::class)
 data class Sample(
     val id: String = UUID.randomUUID().toString(), // used for navigation
     val title: String,
@@ -30,5 +37,8 @@ data class Sample(
     val systemInstructions: Content? = null,
     val generationConfig: GenerationConfig? = null,
     val chatHistory: List<Content> = emptyList(),
-    val tools: List<Tool>? = null
+    val tools: List<Tool>? = null,
+    val includeAttach: Boolean = false,
+    val allowEmptyPrompt: Boolean = false,
+    val generateImages: (suspend (ImagenModel, String, Bitmap?) -> ImagenGenerationResponse<ImagenInlineImage>)? = null
 )
diff --git a/firebase-ai/app/src/main/res/drawable/cat.jpeg b/firebase-ai/app/src/main/res/drawable/cat.jpeg
new file mode 100644
index 0000000000..0374ea894f
Binary files /dev/null and b/firebase-ai/app/src/main/res/drawable/cat.jpeg differ