Skip to content

Commit

Permalink
Merge pull request #110 from AikidoSec/AIK-4330
Browse files Browse the repository at this point in the history
AIK-4330 Java fix SSRF not working when IPC fails
  • Loading branch information
willem-delbare authored Jan 29, 2025
2 parents 8ad3e9c + 5064496 commit fa7a2b9
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 103 deletions.
18 changes: 10 additions & 8 deletions .github/workflows/end2end.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ jobs:
strategy:
matrix:
app:
- { name: SpringBootPostgres, test_file: end2end/spring_boot_postgres.py }
- { name: SpringBootMySQL, test_file: end2end/spring_boot_mysql.py }
- { name: SpringBootMSSQL, test_file: end2end/spring_boot_mssql.py }
- { name: SpringWebfluxSampleApp, test_file: end2end/spring_webflux_postgres.py }
- { name: SpringMVCPostgresKotlin, test_file: end2end/spring_mvc_postgres_kotlin.py }
- { name: SpringMVCPostgresGroovy, test_file: end2end/spring_mvc_postgres_groovy.py }
- { name: JavalinPostgres, test_file: end2end/javalin_postgres.py }
- { name: JavalinMySQLKotlin, test_file: end2end/javalin_mysql_kotlin.py }
- { name: SpringBootPostgres, test_file: end2end/spring_boot_postgres.py, with_ipc: true }
- { name: SpringBootPostgres, test_file: end2end/spring_boot_postgres.py, with_ipc: false }
- { name: SpringBootMySQL, test_file: end2end/spring_boot_mysql.py, with_ipc: true }
- { name: SpringBootMSSQL, test_file: end2end/spring_boot_mssql.py, with_ipc: true }
- { name: SpringWebfluxSampleApp, test_file: end2end/spring_webflux_postgres.py, with_ipc: true }
- { name: SpringMVCPostgresKotlin, test_file: end2end/spring_mvc_postgres_kotlin.py, with_ipc: true }
- { name: SpringMVCPostgresGroovy, test_file: end2end/spring_mvc_postgres_groovy.py, with_ipc: true }
- { name: JavalinPostgres, test_file: end2end/javalin_postgres.py, with_ipc: true }
- { name: JavalinMySQLKotlin, test_file: end2end/javalin_mysql_kotlin.py, with_ipc: true }
java-version: [17, 18, 19, 20, 21]
steps:
- name: Download build artifacts
Expand All @@ -63,6 +64,7 @@ jobs:
docker build -t mock_core .
docker run --name mock_core -d -p 5000:5000 mock_core
- name: Setup IPC Folder
if: matrix.app.with_ipc
run: mkdir /opt/aikido && chmod a+rw /opt/aikido
- name: Start databases
working-directory: ./sample-apps/databases
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public static void after(

try {
// Load the class from the JAR
Class<?> clazz = classLoader.loadClass("dev.aikido.agent_api.collectors.HostnameCollector");
Class<?> clazz = classLoader.loadClass("dev.aikido.agent_api.collectors.DNSRecordCollector");

// Run report with "argument"
for (Method method2: clazz.getMethods()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import dev.aikido.agent_api.background.utilities.ThreadIPCClient;
import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.storage.Hostnames;
import dev.aikido.agent_api.thread_cache.ThreadCache;
import dev.aikido.agent_api.vulnerabilities.Attack;
import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFDetector;
import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFException;
Expand All @@ -18,25 +17,25 @@
import static dev.aikido.agent_api.background.utilities.ThreadIPCClientFactory.getDefaultThreadIPCClient;
import static dev.aikido.agent_api.helpers.ShouldBlockHelper.shouldBlock;

public final class HostnameCollector {
private HostnameCollector() {}
private static final Logger logger = LogManager.getLogger(HostnameCollector.class);
public final class DNSRecordCollector {
private DNSRecordCollector() {}
private static final Logger logger = LogManager.getLogger(DNSRecordCollector.class);
public static void report(String hostname, InetAddress[] inetAddresses) {
try {
logger.trace("HostnameCollector called with %s & inet addresses: %s", hostname, List.of(inetAddresses));

logger.trace("DNSRecordCollector called with %s & inet addresses: %s", hostname, List.of(inetAddresses));

// Convert inetAddresses array to a List of IP strings :
List<String> ipAddresses = new ArrayList<>();
for (InetAddress inetAddress : inetAddresses) {
ipAddresses.add(inetAddress.getHostAddress());
}
// Currently using hostnames from thread cache, might not be as accurate as using Context-dependant hostnames.
if (ThreadCache.get() == null || ThreadCache.get().getHostnames() == null) {
logger.trace("Thread cache is empty, returning.");

// Fetch hostnames from Context (this is to get port number e.g.)
if (Context.get() == null || Context.get().getHostnames() == null) {
logger.trace("Context not defined, returning.");
return;
}
for (Hostnames.HostnameEntry hostnameEntry : ThreadCache.get().getHostnames().asArray()) {
for (Hostnames.HostnameEntry hostnameEntry : Context.get().getHostnames().asArray()) {
if (!hostnameEntry.getHostname().equals(hostname)) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.aikido.agent_api.collectors;

import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.thread_cache.ThreadCache;
import dev.aikido.agent_api.thread_cache.ThreadCacheObject;
import dev.aikido.agent_api.helpers.logging.LogManager;
Expand All @@ -14,15 +16,29 @@ public final class URLCollector {

private URLCollector() {}
public static void report(URL url) {
ThreadCacheObject threadCache = ThreadCache.get();
if(threadCache != null && url != null) {
if(url != null) {
if (!url.getProtocol().startsWith("http")) {
return; // Non-HTTP(S) URL
}
logger.trace("Adding a new URL to the cache: %s", url);
int port = getPortFromURL(url);
threadCache.getHostnames().add(url.getHost(), port);
ThreadCache.set(threadCache);

// We store hostname and port in two places, Thread Cache and Context. Thread Cache is for reporting
// outbound domains. Context is to have a map of hostnames with used port numbers to detect SSRF attacks.

// Add to thread cache :
ThreadCacheObject threadCache = ThreadCache.get();
if (threadCache != null) {
threadCache.getHostnames().add(url.getHost(), port);
ThreadCache.set(threadCache);
}

// Add to context :
ContextObject context = Context.get();
if (context != null) {
context.getHostnames().add(url.getHost(), port);
Context.set(context);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.aikido.agent_api.storage.Hostnames;
import dev.aikido.agent_api.storage.RedirectNode;

import java.util.ArrayList;
Expand All @@ -26,6 +27,10 @@ public class ContextObject {
protected transient ArrayList<RedirectNode> redirectStartNodes;
protected transient Map<String, Map<String, String>> cache = new HashMap<>();

// We store hostnames in the context object so we can match a given hostname (by DNS request)
// with its port number (which we know by instrumenting the URLs that get requested).
protected transient Hostnames hostnames = new Hostnames(1000); // max 1000 entries

public boolean middlewareExecuted() {return executedMiddleware; }
public void setExecutedMiddleware(boolean value) { executedMiddleware = value; }

Expand Down Expand Up @@ -69,6 +74,7 @@ public HashMap<String, List<String>> getCookies() {
return cookies;
}
public Map<String, Map<String, String>> getCache() { return cache; }
public Hostnames getHostnames() { return hostnames; }

public String toJson() {
Gson gson = new GsonBuilder().setPrettyPrinting().create();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package collectors;

import dev.aikido.agent_api.collectors.HostnameCollector;
import dev.aikido.agent_api.collectors.DNSRecordCollector;
import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.storage.Hostnames;
import dev.aikido.agent_api.thread_cache.ThreadCache;
import dev.aikido.agent_api.thread_cache.ThreadCacheObject;
import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFException;
import org.junit.jupiter.api.*;
import org.junitpioneer.jupiter.SetEnvironmentVariable;
Expand All @@ -14,12 +14,11 @@
import java.net.UnknownHostException;
import java.util.List;

import static dev.aikido.agent_api.helpers.UnixTimeMS.getUnixTimeMS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.*;

public class HostnameCollectorTest {
public class DNSRecordCollectorTest {
InetAddress inetAddress1;
InetAddress inetAddress2;
@BeforeEach
Expand All @@ -32,52 +31,48 @@ void setup() throws UnknownHostException {
@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "token")
@Test
public void testThreadCacheNull() {
// Early return because of Thread Cache being null :
HostnameCollector.report("dev.aikido", new InetAddress[]{
// Early return because of Context being null :
DNSRecordCollector.report("dev.aikido", new InetAddress[]{
inetAddress1, inetAddress2
});
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "token")
@Test
public void testThreadCacheHostnames() {
ThreadCacheObject myThreadCache = mock(ThreadCacheObject.class);
when(myThreadCache.getLastRenewedAtMS()).thenReturn(getUnixTimeMS());
ThreadCache.set(myThreadCache);
HostnameCollector.report("dev.aikido", new InetAddress[]{
ContextObject myContextObject = mock(ContextObject.class);
Context.set(myContextObject);
DNSRecordCollector.report("dev.aikido", new InetAddress[]{
inetAddress1, inetAddress2
});
verify(myThreadCache).getHostnames();

myThreadCache = mock(ThreadCacheObject.class);
when(myThreadCache.getLastRenewedAtMS()).thenReturn(getUnixTimeMS());
verify(myContextObject).getHostnames();

myContextObject = mock(ContextObject.class);
Hostnames hostnames = new Hostnames(20);
when(myThreadCache.getHostnames()).thenReturn(hostnames);
when(myContextObject.getHostnames()).thenReturn(hostnames);

Context.set(myContextObject);

ThreadCache.set(myThreadCache);
HostnameCollector.report("dev.aikido", new InetAddress[]{
DNSRecordCollector.report("dev.aikido", new InetAddress[]{
inetAddress1, inetAddress2
});
verify(myThreadCache, times(2)).getHostnames();
verify(myContextObject, times(2)).getHostnames();
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "token")
@Test
public void testHostnameSame() {
ThreadCacheObject myThreadCache = mock(ThreadCacheObject.class);
when(myThreadCache.getLastRenewedAtMS()).thenReturn(getUnixTimeMS());

ContextObject myContextObject = mock(ContextObject.class);
Hostnames hostnames = new Hostnames(20);
hostnames.add("dev.aikido.not", 80);
hostnames.add("dev.aikido", 80);
when(myThreadCache.getHostnames()).thenReturn(hostnames);
when(myContextObject.getHostnames()).thenReturn(hostnames);

ThreadCache.set(myThreadCache);
HostnameCollector.report("dev.aikido", new InetAddress[]{
Context.set(myContextObject);
DNSRecordCollector.report("dev.aikido", new InetAddress[]{
inetAddress1, inetAddress2
});
verify(myThreadCache, times(2)).getHostnames();
verify(myContextObject, times(2)).getHostnames();
}

public static class SampleContextObject extends EmptySampleContextObject {
Expand All @@ -92,22 +87,16 @@ public SampleContextObject() {
@SetEnvironmentVariable(key = "AIKIDO_BLOCK", value = "1")
@Test
public void testHostnameSameWithContextAsAttack() {
ThreadCacheObject myThreadCache = mock(ThreadCacheObject.class);
when(myThreadCache.getLastRenewedAtMS()).thenReturn(getUnixTimeMS());

Hostnames hostnames = new Hostnames(20);
hostnames.add("dev.aikido.not", 80);
hostnames.add("dev.aikido", 80);
when(myThreadCache.getHostnames()).thenReturn(hostnames);
ContextObject myContextObject = new SampleContextObject();
myContextObject.getHostnames().add("dev.aikido.not", 80);
myContextObject.getHostnames().add("dev.aikido", 80);
Context.set(myContextObject);

ThreadCache.set(myThreadCache);
Context.set(new SampleContextObject());
Exception exception = assertThrows(SSRFException.class, () -> {
HostnameCollector.report("dev.aikido", new InetAddress[]{
DNSRecordCollector.report("dev.aikido", new InetAddress[]{
inetAddress1, inetAddress2
});
});
verify(myThreadCache, times(2)).getHostnames();
assertEquals("Aikido Zen has blocked a server-side request forgery", exception.getMessage());
}

Expand Down
53 changes: 53 additions & 0 deletions agent_api/src/test/java/collectors/URLCollectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.net.URL;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static utils.EmtpyThreadCacheObject.getEmptyThreadCacheObject;

public class URLCollectorTest {
Expand Down Expand Up @@ -55,6 +56,11 @@ public void testNewUrlConnectionWithHttp() throws IOException {
assertEquals(1, hostnameArray.length);
assertEquals(80, hostnameArray[0].getPort());
assertEquals("app.local.aikido.io", hostnameArray[0].getHostname());

Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray();
assertEquals(1, hostnameArray2.length);
assertEquals(80, hostnameArray2[0].getPort());
assertEquals("app.local.aikido.io", hostnameArray2[0].getHostname());
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token")
Expand All @@ -67,6 +73,11 @@ public void testNewUrlConnectionHttps() throws IOException {
assertEquals(1, hostnameArray.length);
assertEquals(443, hostnameArray[0].getPort());
assertEquals("aikido.dev", hostnameArray[0].getHostname());

Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray();
assertEquals(1, hostnameArray2.length);
assertEquals(443, hostnameArray2[0].getPort());
assertEquals("aikido.dev", hostnameArray2[0].getHostname());
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token")
Expand All @@ -77,5 +88,47 @@ public void testNewUrlConnectionFaultyProtocol() throws IOException {
URLCollector.report(new URL("ftp://localhost:8080"));
Hostnames.HostnameEntry[] hostnameArray = ThreadCache.get().getHostnames().asArray();
assertEquals(0, hostnameArray.length);
Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray();
assertEquals(0, hostnameArray2.length);
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token")
@SetEnvironmentVariable(key = "AIKIDO_BLOCK", value = "true")
@Test
public void testWithNullURL() throws IOException {
setContextAndLifecycle("");
URLCollector.report(null);
Hostnames.HostnameEntry[] hostnameArray = ThreadCache.get().getHostnames().asArray();
assertEquals(0, hostnameArray.length);
Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray();
assertEquals(0, hostnameArray2.length);
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token")
@SetEnvironmentVariable(key = "AIKIDO_BLOCK", value = "true")
@Test
public void testWithNullContext() throws IOException {
setContextAndLifecycle("");
Context.reset();
URLCollector.report(new URL("https://aikido.dev"));
Hostnames.HostnameEntry[] hostnameArray = ThreadCache.get().getHostnames().asArray();
assertEquals(1, hostnameArray.length);
assertEquals(443, hostnameArray[0].getPort());
assertEquals("aikido.dev", hostnameArray[0].getHostname());
assertNull(Context.get());
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token")
@SetEnvironmentVariable(key = "AIKIDO_BLOCK", value = "true")
@Test
public void testWithNullThreadCache() throws IOException {
setContextAndLifecycle("");
ThreadCache.reset();
URLCollector.report(new URL("https://aikido.dev"));
Hostnames.HostnameEntry[] hostnameArray = Context.get().getHostnames().asArray();
assertEquals(1, hostnameArray.length);
assertEquals(443, hostnameArray[0].getPort());
assertEquals("aikido.dev", hostnameArray[0].getHostname());
assertNull(ThreadCache.get());
}
}
11 changes: 0 additions & 11 deletions agent_api/src/test/java/wrappers/ApacheHttpClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,6 @@ public void testSSRFWithoutPortAndWithoutContext() throws Exception {
});
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token-2")
@SetEnvironmentVariable(key = "AIKIDO_BLOCK", value = "true")
@Test
public void testSSRFWithoutPortAndWithoutThreadCache() throws Exception {
setContextAndLifecycle("http://localhost:80");
ThreadCache.set(null);
assertThrows(ConnectException.class, () -> {
fetchResponse("http://localhost/weirdroute");
});
}

private void fetchResponse(String urlString) throws IOException {
HttpGet request = new HttpGet(urlString);
request.addHeader("Authorization", "Bearer invalid-token-2");
Expand Down
11 changes: 0 additions & 11 deletions agent_api/src/test/java/wrappers/InetAddressTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,6 @@ public void testSSRFWithoutPortAndWithoutContext() throws Exception {
});
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token-2")
@SetEnvironmentVariable(key = "AIKIDO_BLOCK", value = "true")
@Test
public void testSSRFWithoutPortAndWithoutThreadCache() throws Exception {
setContextAndLifecycle("http://localhost:80");
ThreadCache.set(null);
assertThrows(ConnectException.class, () -> {
fetchResponse("http://localhost/weirdroute");
});
}

@SetEnvironmentVariable(key = "AIKIDO_TOKEN", value = "invalid-token-2")
@SetEnvironmentVariable(key = "AIKIDO_BLOCK", value = "true")
@Test
Expand Down
Loading

0 comments on commit fa7a2b9

Please sign in to comment.