Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ interface RemoteConfigDataSource {
fun getPromoVideoLink(): String

fun getDancingDroidLink(): String

fun useImagen(): Boolean

fun getFineTunedModelName(): String
}

@Singleton
Expand Down Expand Up @@ -83,4 +87,11 @@ class RemoteConfigDataSourceImpl @Inject constructor() : RemoteConfigDataSource
override fun getDancingDroidLink(): String {
return remoteConfig.getString("dancing_droid_gif_link")
}

override fun useImagen(): Boolean {
return remoteConfig.getBoolean("use_imagen")
}
override fun getFineTunedModelName(): String {
return remoteConfig.getString("fine_tuned_model_name")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import com.google.firebase.ai.type.ImagenSafetySettings
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.SafetySetting
import com.google.firebase.ai.type.Schema
import com.google.firebase.ai.type.asImageOrNull
import com.google.firebase.ai.type.content
import com.google.firebase.ai.type.generationConfig
import kotlinx.serialization.json.Json
Expand Down Expand Up @@ -138,21 +139,36 @@ class FirebaseAiDataSourceImpl @Inject constructor(
image,
)
}
private fun createFineTunedModel(): GenerativeModel {
return Firebase.ai.generativeModel(
remoteConfigDataSource.getFineTunedModelName(),
safetySettings = listOf(
SafetySetting(HarmCategory.HARASSMENT, HarmBlockThreshold.LOW_AND_ABOVE),
SafetySetting(HarmCategory.HATE_SPEECH, HarmBlockThreshold.LOW_AND_ABOVE),
SafetySetting(HarmCategory.SEXUALLY_EXPLICIT, HarmBlockThreshold.LOW_AND_ABOVE),
SafetySetting(HarmCategory.DANGEROUS_CONTENT, HarmBlockThreshold.LOW_AND_ABOVE),
SafetySetting(HarmCategory.CIVIC_INTEGRITY, HarmBlockThreshold.LOW_AND_ABOVE),
),
)
}

override suspend fun generateImageFromPromptAndSkinTone(prompt: String, skinTone: String): Bitmap {
val generativeModel = createGenerativeImageModel()
// Retrieve the base prompt template from Remote Config
val basePromptTemplate = remoteConfigDataSource.promptImageGenerationWithSkinTone()

// Perform the substitution
val imageGenerationPrompt = basePromptTemplate
.replace("{prompt}", prompt)
.replace("{skinTone}", skinTone)

return executeImageGeneration(
generativeModel,
imageGenerationPrompt,
)
if (remoteConfigDataSource.useImagen()) {
val generativeModel = createGenerativeImageModel()
return executeImageGeneration(
generativeModel,
imageGenerationPrompt,
)
} else {
val fineTunedModel = createFineTunedModel()
val response = fineTunedModel.generateContent(imageGenerationPrompt)
return response.candidates.firstOrNull()?.content?.parts?.firstOrNull()?.asImageOrNull()
?: throw IllegalStateException("Could not extract image from fine-tuned model response")
}
}

private suspend fun executeTextValidation(
Expand Down
Loading