Skip to content

Commit

Permalink
refactor: Migrates assertNoDefects and assertWithAi to maestro cloud
Browse files Browse the repository at this point in the history
  • Loading branch information
luistak committed Jan 30, 2025
1 parent 79307cb commit 36cae4e
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 235 deletions.
14 changes: 8 additions & 6 deletions maestro-ai/src/main/java/maestro/ai/DemoApp.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.github.ajalt.clikt.parameters.types.path
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import maestro.ai.anthropic.Claude
import maestro.ai.cloud.Defect
import maestro.ai.openai.OpenAI
import java.io.File
import java.nio.file.Path
Expand Down Expand Up @@ -118,22 +119,23 @@ class DemoApp : CliktCommand() {
else -> throw IllegalArgumentException("Unknown model: $model")
}

val cloudApiKey = System.getenv("MAESTRO_CLOUD_API_KEY")
if (cloudApiKey.isNullOrEmpty()) {
throw IllegalArgumentException("`MAESTRO_CLOUD_API_KEY` is not available. Did you export MAESTRO_CLOUD_API_KEY?")
}

testCases.forEach { testCase ->
val bytes = testCase.screenshot.readBytes()

val job = async {
val defects = if (testCase.prompt == null) Prediction.findDefects(
aiClient = aiClient,
apiKey = cloudApiKey,
screen = bytes,
printPrompt = showPrompts,
printRawResponse = showRawResponse,
) else {
val result = Prediction.performAssertion(
aiClient = aiClient,
apiKey = cloudApiKey,
screen = bytes,
assertion = testCase.prompt,
printPrompt = showPrompts,
printRawResponse = showRawResponse,
)

if (result == null) emptyList()
Expand Down
219 changes: 8 additions & 211 deletions maestro-ai/src/main/java/maestro/ai/Prediction.kt
Original file line number Diff line number Diff line change
@@ -1,229 +1,27 @@
package maestro.ai

import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.jsonObject
import maestro.ai.cloud.ApiClient
import maestro.ai.openai.OpenAI

@Serializable
data class Defect(
val category: String,
val reasoning: String,
)

@Serializable
private data class AskForDefectsResponse(
val defects: List<Defect>,
)

@Serializable
private data class ExtractTextResponse(
val text: String?
)
import maestro.ai.cloud.Defect

object Prediction {

private val askForDefectsSchema by lazy {
readSchema("askForDefects")
}

private val extractTextSchema by lazy {
readSchema("extractText")
}

/**
* We use JSON mode/Structured Outputs to define the schema of the response we expect from the LLM.
* - OpenAI: https://platform.openai.com/docs/guides/structured-outputs
* - Gemini: https://ai.google.dev/gemini-api/docs/json-mode
*/
private fun readSchema(name: String): String {
val fileName = "/${name}_schema.json"
val resourceStream = this::class.java.getResourceAsStream(fileName)
?: throw IllegalStateException("Could not find $fileName in resources")

return resourceStream.bufferedReader().use { it.readText() }
}

private val json = Json { ignoreUnknownKeys = true }

private val defectCategories = listOf(
"localization" to "Inconsistent use of language, for example mixed English and Portuguese",
"layout" to "Some UI elements are overlapping or are cropped",
)

private val allDefectCategories = defectCategories + listOf("assertion" to "The assertion is not true")
private val apiClient = ApiClient()

suspend fun findDefects(
aiClient: AI,
apiKey: String,
screen: ByteArray,
printPrompt: Boolean = false,
printRawResponse: Boolean = false,
): List<Defect> {
val response = apiClient.findDefects(apiKey, screen)

// List of failed attempts to not make up false positives:
// |* If you don't see any defect, return "No defects found".
// |* If you are sure there are no defects, return "No defects found".
// |* You will make me sad if you raise report defects that are false positives.
// |* Do not make up defects that are not present in the screenshot. It's fine if you don't find any defects.

val prompt = buildString {

appendLine(
"""
You are a QA engineer performing quality assurance for a mobile application.
Identify any defects in the provided screenshot.
""".trimIndent()
)

append(
"""
|
|RULES:
|* All defects you find must belong to one of the following categories:
|${defectCategories.joinToString(separator = "\n") { " * ${it.first}: ${it.second}" }}
|* If you see defects, your response MUST only include defect name and detailed reasoning for each defect.
|* Provide response as a list of JSON objects, each representing <category>:<reasoning>
|* Do not raise false positives. Some example responses that have a high chance of being a false positive:
| * button is partially cropped at the bottom
| * button is not aligned horizontally/vertically within its container
""".trimMargin("|")
)

// Claude doesn't have a JSON mode as of 21-08-2024
// https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency
// We could do "if (aiClient is Claude)", but actually, this also helps with gpt-4o sometimes
// generatig never-ending stream of output.
append(
"""
|
|* You must provide result as a valid JSON object, matching this structure:
|
| {
| "defects": [
| {
| "category": "<defect category, string>",
| "reasoning": "<reasoning, string>"
| },
| {
| "category": "<defect category, string>",
| "reasoning": "<reasoning, string>"
| }
| ]
| }
|
|DO NOT output any other information in the JSON object.
""".trimMargin("|")
)

appendLine("There are usually only a few defects in the screenshot. Don't generate tens of them.")
}

if (printPrompt) {
println("--- PROMPT START ---")
println(prompt)
println("--- PROMPT END ---")
}

val aiResponse = aiClient.chatCompletion(
prompt,
model = aiClient.defaultModel,
maxTokens = 4096,
identifier = "find-defects",
imageDetail = "high",
images = listOf(screen),
jsonSchema = if (aiClient is OpenAI) json.parseToJsonElement(askForDefectsSchema).jsonObject else null,
)

if (printRawResponse) {
println("--- RAW RESPONSE START ---")
println(aiResponse.response)
println("--- RAW RESPONSE END ---")
}

val defects = json.decodeFromString<AskForDefectsResponse>(aiResponse.response)
return defects.defects
return response.defects
}

suspend fun performAssertion(
aiClient: AI,
apiKey: String,
screen: ByteArray,
assertion: String,
printPrompt: Boolean = false,
printRawResponse: Boolean = false,
): Defect? {
val prompt = buildString {

appendLine(
"""
|You are a QA engineer performing quality assurance for a mobile application.
|You are given a screenshot of the application and an assertion about the UI.
|Your task is to identify if the following assertion is true:
|
| "${assertion.removeSuffix("\n")}"
|
""".trimMargin("|")
)

append(
"""
|
|RULES:
|* Provide response as a valid JSON, with structure described below.
|* If the assertion is false, the list in the JSON output MUST be empty.
|* If assertion is false:
| * Your response MUST only include a single defect with category "assertion".
| * Provide detailed reasoning to explain why you think the assertion is false.
""".trimMargin("|")
)

// Claude doesn't have a JSON mode as of 21-08-2024
// https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency
// We could do "if (aiClient is Claude)", but actually, this also helps with gpt-4o sometimes
// generatig never-ending stream of output.
append(
"""
|
|* You must provide result as a valid JSON object, matching this structure:
|
| {
| "defects": [
| {
| "category": "assertion",
| "reasoning": "<reasoning, string>"
| },
| ]
| }
|
|The "defects" array MUST contain at most a single JSON object.
|DO NOT output any other information in the JSON object.
""".trimMargin("|")
)
}

if (printPrompt) {
println("--- PROMPT START ---")
println(prompt)
println("--- PROMPT END ---")
}

val aiResponse = aiClient.chatCompletion(
prompt,
model = aiClient.defaultModel,
maxTokens = 4096,
identifier = "perform-assertion",
imageDetail = "high",
images = listOf(screen),
jsonSchema = if (aiClient is OpenAI) json.parseToJsonElement(askForDefectsSchema).jsonObject else null,
)

if (printRawResponse) {
println("--- RAW RESPONSE START ---")
println(aiResponse.response)
println("--- RAW RESPONSE END ---")
}
val response = apiClient.findDefects(apiKey, screen, assertion)

val response = json.decodeFromString<AskForDefectsResponse>(aiResponse.response)
return response.defects.firstOrNull()
}

Expand All @@ -232,8 +30,7 @@ object Prediction {
query: String,
screen: ByteArray,
): String {
val client = ApiClient()
val response = client.extractTextWithAi(apiKey, query, screen)
val response = apiClient.extractTextWithAi(apiKey, query, screen)

return response.text
}
Expand Down
61 changes: 55 additions & 6 deletions maestro-ai/src/main/java/maestro/ai/cloud/ApiClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ import org.slf4j.LoggerFactory

private val logger = LoggerFactory.getLogger(OpenAI::class.java)

@Serializable
data class Defect(
val category: String,
val reasoning: String,
)

@Serializable
data class FindDefectsRequest(
val assertion: String? = null,
val screen: ByteArray,
)

@Serializable
data class FindDefectsResponse(
val defects: List<Defect>,
)

@Serializable
data class ExtractTextWithAiRequest(
val query: String,
Expand Down Expand Up @@ -54,8 +71,6 @@ class ApiClient {
): ExtractTextWithAiResponse {
val url = "$baseUrl/v2/extract-text"

println(url)

val response = try {
val httpResponse = httpClient.post(url) {
headers {
Expand All @@ -67,16 +82,50 @@ class ApiClient {

val body = httpResponse.bodyAsText()
if (!httpResponse.status.isSuccess()) {
logger.error("Failed to complete request to OpenAI: URL: $url ${httpResponse.status}, $body")
throw Exception("Failed to complete request to OpenAI URL: $url: ${httpResponse.status}, $body")
logger.error("Failed to complete request to Maestro Cloud: ${httpResponse.status}, $body")
throw Exception("Failed to complete request to Maestro Cloud: ${httpResponse.status}, $body")
}

json.decodeFromString<ExtractTextWithAiResponse>(body)
} catch (e: SerializationException) {
logger.error("Failed to parse response from OpenAI", e)
logger.error("Failed to parse response from Maestro Cloud", e)
throw e
} catch (e: Exception) {
logger.error("Failed to complete request to Maestro Cloud", e)
throw e
}

return response
}

suspend fun findDefects(
apiKey: String,
screen: ByteArray,
assertion: String? = null,
): FindDefectsResponse {
val url = "$baseUrl/v2/find-defects"

val response = try {
val httpResponse = httpClient.post(url) {
headers {
append(HttpHeaders.Authorization, "Bearer $apiKey")
append(HttpHeaders.ContentType, ContentType.Application.Json.toString()) // Explicitly set JSON content type
}
setBody(json.encodeToString(FindDefectsRequest(assertion = assertion, screen = screen)))
}

val body = httpResponse.bodyAsText()
if (!httpResponse.status.isSuccess()) {
logger.error("Failed to complete request to Maestro Cloud: ${httpResponse.status}, $body")
throw Exception("Failed to complete request to Maestro Cloud: ${httpResponse.status}, $body")
}

json.decodeFromString<FindDefectsResponse>(body)
} catch (e: SerializationException) {
logger.error("Failed to parse response from Maestro Cloud", e)
throw e
} catch (e: Exception) {
logger.error("Failed to complete request to OpenAI", e)
logger.error("Failed to complete request to Maestro Cloud", e)
throw e
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import com.fasterxml.jackson.databind.JsonMappingException
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import maestro.MaestroException
import maestro.TreeNode
import maestro.ai.Defect
import maestro.ai.cloud.Defect
import maestro.cli.runner.CommandStatus
import maestro.cli.util.CiUtils
import maestro.cli.util.EnvUtils
Expand All @@ -27,7 +27,7 @@ import java.time.Instant
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit
import java.util.IdentityHashMap
import java.util.*
import kotlin.io.path.absolutePathString
import kotlin.io.path.exists

Expand Down
Loading

0 comments on commit 36cae4e

Please sign in to comment.