Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion goose-android-agent/app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ plugins {
alias(libs.plugins.android.application)
alias(libs.plugins.kotlin.android)
alias(libs.plugins.kotlin.compose)
kotlin("plugin.serialization") version "1.9.22"
kotlin("plugin.serialization") version "2.2.0"
}

android {
Expand Down Expand Up @@ -41,6 +41,7 @@ android {

testOptions {
unitTests.isReturnDefaultValues = true
unitTests.isIncludeAndroidResources = true
}

lint {
Expand All @@ -49,6 +50,19 @@ android {
}
}

// Unit tests run on the debug variant only. The release variant's merged manifest
// doesn't include the test-only ComponentActivity (contributed by the
// debugImplementation-only ui-test-manifest aar), so Compose UI tests can't run
// against it. Following the same pattern as NowInAndroid, Tivi, and the AndroidX
// reference apps, we disable testReleaseUnitTest entirely.
androidComponents {
beforeVariants(selector().withBuildType("release")) { variant ->
(variant as com.android.build.api.variant.HasHostTestsBuilder)
.hostTests[com.android.build.api.variant.HostTestBuilder.UNIT_TEST_TYPE]
?.enable = false
}
}

dependencies {
implementation(libs.androidx.core.ktx)
implementation(libs.androidx.lifecycle.runtime.ktx)
Expand All @@ -66,6 +80,8 @@ dependencies {
testImplementation(libs.robolectric)
testImplementation(libs.kotlin.test)
testImplementation(libs.androidx.test.core)
testImplementation(platform(libs.androidx.compose.bom))
testImplementation(libs.androidx.ui.test.junit4)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
androidTestImplementation(platform(libs.androidx.compose.bom))
Expand All @@ -84,4 +100,7 @@ dependencies {

// ML Kit for barcode scanning
implementation("com.google.mlkit:barcode-scanning:17.2.0")

// LiteRT-LM for on-device LLM inference
implementation("com.google.ai.edge.litertlm:litertlm-android:0.10.0")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package xyz.block.gosling

import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import org.junit.Assert.*
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import xyz.block.gosling.features.agent.AiModel
import xyz.block.gosling.features.agent.ModelProvider
import xyz.block.gosling.features.agent.ondevice.OnDeviceModelManager
import xyz.block.gosling.features.agent.ondevice.LiteRTInference

@RunWith(AndroidJUnit4::class)
class OnDeviceModelTest {

private val context by lazy {
InstrumentationRegistry.getInstrumentation().targetContext
}

companion object {
private const val TAG = "OnDeviceModelTest"
}

@Before
fun setup() {
// Force reload model config so we pick up the latest JSON
OnDeviceModelManager.reloadModelsConfig(context)
}

@Test
fun knownModels_containsGemma4Only() {
val models = OnDeviceModelManager.getKnownModels(context)

Log.i(TAG, "Known models: ${models.map { it.id }}")
assertEquals("Should have exactly 2 models", 2, models.size)
assertTrue(
"Should contain gemma4-e2b",
models.any { it.id == "on-device/gemma4-e2b" }
)
assertTrue(
"Should contain gemma4-e4b",
models.any { it.id == "on-device/gemma4-e4b" }
)
}

@Test
fun downloadedModels_containsGemma4E2B() {
val downloaded = OnDeviceModelManager.getDownloadedModels(context)

Log.i(TAG, "Downloaded models: ${downloaded.map { it.id }}")
assertTrue(
"Gemma 4 E2B should be downloaded",
downloaded.any { it.id == "on-device/gemma4-e2b" }
)
}

@Test
fun modelRegistration_registersDownloadedModels() {
OnDeviceModelManager.registerDownloadedModels(context)

val onDeviceModels = AiModel.getModelsForProvider(ModelProvider.ON_DEVICE_LITERT)
Log.i(TAG, "Registered on-device models: ${onDeviceModels.map { it.identifier }}")
assertTrue(
"Should have at least one registered on-device model",
onDeviceModels.isNotEmpty()
)
assertTrue(
"Gemma 4 E2B should be registered",
onDeviceModels.any { it.identifier == "on-device/gemma4-e2b" }
)
}

@Test
fun modelPath_resolvesForDownloadedModel() {
val path = OnDeviceModelManager.getModelPath(context, "on-device/gemma4-e2b")

Log.i(TAG, "Model path: $path")
assertNotNull("Model path should not be null", path)
assertTrue("Model path should end with .litertlm", path!!.endsWith(".litertlm"))
}

@Test
fun contextLength_correctForGemma4() {
val ctxE2B = OnDeviceModelManager.getContextLength(context, "on-device/gemma4-e2b")
val ctxE4B = OnDeviceModelManager.getContextLength(context, "on-device/gemma4-e4b")

assertEquals("Gemma 4 E2B context should be 32000", 32000, ctxE2B)
assertEquals("Gemma 4 E4B context should be 32000", 32000, ctxE4B)
}

@Test
fun liteRTInference_isAvailable() {
assertTrue("LiteRT-LM should be available", LiteRTInference.isAvailable())
}

@Test
fun liteRTInference_initializeAndChat() {
val modelPath = OnDeviceModelManager.getModelPath(context, "on-device/gemma4-e2b")
assertNotNull("Model must be downloaded to run this test", modelPath)

Log.i(TAG, "Initializing LiteRT engine with model: $modelPath")
LiteRTInference.initialize(modelPath!!, context.cacheDir.path)

Log.i(TAG, "Creating conversation...")
val conversation = LiteRTInference.createConversation(
systemInstruction = "You are a helpful assistant. Reply briefly.",
tools = emptyList()
)

Log.i(TAG, "Sending test message...")
val response = conversation.sendMessage("Say hello in one sentence.")
val responseText = response.toString()

Log.i(TAG, "Response: $responseText")
assertTrue(
"Response should not be empty",
responseText.isNotBlank()
)

conversation.close()
LiteRTInference.close()
Log.i(TAG, "Test complete")
}
}
18 changes: 18 additions & 0 deletions goose-android-agent/app/src/main/assets/models_litert.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[
{
"id": "on-device/gemma4-e2b",
"displayName": "Gemma 4 E2B",
"fileName": "gemma-4-E2B-it.litertlm",
"downloadUrl": "https://huggingface.co/litert-community/gemma-4-E2B-it-litert-lm/resolve/7fa1d78473894f7e736a21d920c3aa80f950c0db/gemma-4-E2B-it.litertlm?download=true",
"sizeBytes": 2583085056,
"contextLength": 32000
},
{
"id": "on-device/gemma4-e4b",
"displayName": "Gemma 4 E4B",
"fileName": "gemma-4-E4B-it.litertlm",
"downloadUrl": "https://huggingface.co/litert-community/gemma-4-E4B-it-litert-lm/resolve/9695417f248178c63a9f318c6e0c56cb917cb837/gemma-4-E4B-it.litertlm?download=true",
"sizeBytes": 3654467584,
"contextLength": 32000
}
]
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package xyz.block.gosling

import android.app.Application
import xyz.block.gosling.features.agent.ondevice.OnDeviceModelManager
import xyz.block.gosling.features.overlay.OverlayService

class GoslingApplication : Application() {
override fun onCreate() {
super.onCreate()
// Register any previously downloaded on-device models so
// AiModel.fromIdentifier() can resolve them after app restart.
OnDeviceModelManager.registerDownloadedModels(this)
}

companion object {
var isMainActivityRunning = false
var isLauncherActivityRunning = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import org.json.JSONObject
import xyz.block.gosling.features.accessibility.GoslingAccessibilityService
import xyz.block.gosling.features.agent.ToolHandler.callTool
import xyz.block.gosling.features.agent.ToolHandler.getSerializableToolDefinitions
import xyz.block.gosling.features.agent.ondevice.OnDeviceModelManager
import xyz.block.gosling.features.agent.providers.LiteRTProviderHandler
import xyz.block.gosling.features.settings.SettingsStore
import java.io.File
import java.net.HttpURLConnection
Expand Down Expand Up @@ -696,30 +698,60 @@ class Agent : Service() {
private suspend fun callLlm(messages: List<Message>, context: Context): JSONObject {
val settings = SettingsStore(context)
val model = AiModel.fromIdentifier(settings.llmModel)

val processedMessages = removeOutdatedPayloads(messages)

// Get the appropriate provider handler
val providerHandler = getProviderHandler(model.provider)

// Local inference path - no HTTP, no API key needed
if (providerHandler.isLocalProvider()) {
val toolDefinitions = getSerializableToolDefinitions(context, model.provider)

val (text, toolCalls, stats) = providerHandler.executeLocal(
model.identifier,
processedMessages,
toolDefinitions
)

// Build a synthetic JSONObject so downstream parsing in processCommand() works
val syntheticResponse = JSONObject()
syntheticResponse.put("text", text)
if (toolCalls != null) {
val tcArray = org.json.JSONArray()
for (tc in toolCalls) {
val tcObj = JSONObject()
tcObj.put("id", tc.toolId)
tcObj.put("name", tc.name)
tcObj.put("arguments", tc.arguments)
tcArray.put(tcObj)
}
syntheticResponse.put("tool_calls", tcArray)
}
syntheticResponse.put("duration", stats["duration"] ?: 0.0)
return syntheticResponse
}

// HTTP inference path - requires API key
val apiKey = settings.getApiKey(model.provider)

// Check for empty API key early
if (apiKey.isNullOrBlank()) {
updateStatus(AgentStatus.Error("API key is missing. Please add your API key in settings."))
throw ApiKeyException("API key is missing. Please add your API key in settings.")
}

val processedMessages = removeOutdatedPayloads(messages)

// Get the appropriate provider handler
val providerHandler = getProviderHandler(model.provider)

// Get tool definitions using the provider handler
val toolDefinitions = getSerializableToolDefinitions(context, model.provider)

// Create request using provider handler
val requestBody = providerHandler.createRequest(
model.identifier,
processedMessages,
toolDefinitions,
apiKey
)

// Get URL and headers from provider handler
val urlString = providerHandler.getApiUrl(model.identifier, apiKey)
val headers = providerHandler.getHeaders(apiKey)
Expand All @@ -737,6 +769,12 @@ class Agent : Service() {
ModelProvider.OPENAI -> xyz.block.gosling.features.agent.providers.OpenAIProviderHandler()
ModelProvider.GEMINI -> xyz.block.gosling.features.agent.providers.GeminiProviderHandler()
ModelProvider.OPENROUTER -> xyz.block.gosling.features.agent.providers.OpenRouterProviderHandler()
ModelProvider.ON_DEVICE_LITERT -> {
val settings = SettingsStore(this)
val model = AiModel.fromIdentifier(settings.llmModel)
val modelPath = OnDeviceModelManager.getModelPath(this, model.identifier)
LiteRTProviderHandler(modelPath = modelPath, cacheDir = cacheDir.path)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package xyz.block.gosling.features.agent

enum class ModelProvider {
OPENAI,
GEMINI,
OPENROUTER
enum class ModelProvider(val displayName: String) {
OPENAI("OpenAI"),
GEMINI("Gemini"),
OPENROUTER("OpenRouter"),
ON_DEVICE_LITERT("On-Device");

val isOnDevice: Boolean
get() = this == ON_DEVICE_LITERT
}

data class AiModel(
Expand Down Expand Up @@ -36,15 +40,29 @@ data class AiModel(
AiModel("Cohere Command R+", "cohere/command-r-plus", ModelProvider.OPENROUTER)
)

private val onDeviceModels = mutableListOf<AiModel>()

fun registerOnDeviceModel(model: AiModel) {
if (onDeviceModels.none { it.identifier == model.identifier }) {
onDeviceModels.add(model)
}
}

fun unregisterOnDeviceModel(identifier: String) {
onDeviceModels.removeAll { it.identifier == identifier }
}

fun getAllModels(): List<AiModel> = AVAILABLE_MODELS + onDeviceModels

fun fromIdentifier(identifier: String): AiModel {
return AVAILABLE_MODELS.find { it.identifier == identifier }
return getAllModels().find { it.identifier == identifier }
?: AVAILABLE_MODELS.first()
}

fun getProviders(): List<ModelProvider> =
AVAILABLE_MODELS.map { it.provider }.distinct()
fun getModelsForProvider(provider: ModelProvider): List<AiModel> =
AVAILABLE_MODELS.filter { it.provider == provider }
fun getProviders(): List<ModelProvider> =
ModelProvider.entries.toList()

fun getModelsForProvider(provider: ModelProvider): List<AiModel> =
getAllModels().filter { it.provider == provider }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,7 @@ object ToolHandler {
ModelProvider.OPENAI -> xyz.block.gosling.features.agent.providers.OpenAIProviderHandler()
ModelProvider.GEMINI -> xyz.block.gosling.features.agent.providers.GeminiProviderHandler()
ModelProvider.OPENROUTER -> xyz.block.gosling.features.agent.providers.OpenRouterProviderHandler()
ModelProvider.ON_DEVICE_LITERT -> xyz.block.gosling.features.agent.providers.LiteRTProviderHandler()
}
}

Expand Down
Loading