Allow test Kafka Clusters to use SASL SCRAM-SHA and OAUTH bearer
wip - only KafkaCluster API refactored at the moment.

Signed-off-by: kwall <kwall@apache.org>
k-wall committed May 29, 2024
1 parent 237f9b4 commit 9bb2ca6
Showing 5 changed files with 184 additions and 39 deletions.
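The first diff below (KafkaClusterConfig.java, identifiable from the class name) generalises the SASL support that was previously hard-coded to PLAIN. A minimal usage sketch, assuming the Lombok @Builder/@Singular accessor names implied by the field declarations in the diff (builder(), saslMechanism(...) and user(...) are assumptions, not shown in this commit):

    KafkaClusterConfig config = KafkaClusterConfig.builder()
            .saslMechanism("SCRAM-SHA-256") // or "PLAIN" / "OAUTHBEARER"
            .user("alice", "alice-secret")  // SCRAM credentials are registered with the cluster after startup
            .build();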
@@ -14,6 +14,7 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
+ import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
@@ -31,6 +32,9 @@
import org.apache.kafka.common.config.SslConfigs;
import org.apache.kafka.common.config.internals.BrokerSecurityConfigs;
import org.apache.kafka.common.security.auth.SecurityProtocol;
+ import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
+ import org.apache.kafka.common.security.plain.PlainLoginModule;
+ import org.apache.kafka.common.security.scram.ScramLoginModule;
import org.apache.kafka.common.utils.AppInfoParser;
import org.junit.jupiter.api.TestInfo;

@@ -61,6 +65,9 @@ public class KafkaClusterConfig {
public static final String INTERNAL_LISTENER_NAME = "INTERNAL";
public static final String ANON_LISTENER_NAME = "ANON";

+ private static final String SASL_SCRAM_SHA_MECHANISM_PREFIX = "SCRAM-SHA-";
+ private static final String SASL_PLAIN_MECHANISM_NAME = "PLAIN";
+
private TestInfo testInfo;
private KeytoolCertificateGenerator brokerKeytoolCertificateGenerator;
private KeytoolCertificateGenerator clientKeytoolCertificateGenerator;
@@ -90,6 +97,13 @@ public class KafkaClusterConfig {
* will be used.
*/
private final String saslMechanism;

+ /**
+  * name of the login module that will be used for client and broker. if null, the login module will be
+  * derived from the saslMechanism.
+  */
+ private String loginModule;

private final String securityProtocol;
@Builder.Default
private Integer brokersNum = 1;
@@ -105,6 +119,12 @@
@Singular
private final Map<String, String> users;

+ @Singular
+ private final Map<String, String> jaasServerOptions;
+
+ @Singular
+ private final Map<String, String> jaasClientOptions;

@Singular
private final Map<String, String> brokerConfigs;

@@ -164,7 +184,7 @@ public static KafkaClusterConfig fromConstraints(List<Annotation> annotations, T
}
}
if (annotation instanceof SaslPlainAuth.List saslPlainAuthList) {
- builder.saslMechanism("PLAIN");
+ builder.saslMechanism(SASL_PLAIN_MECHANISM_NAME);
sasl = true;
Map<String, String> users = new HashMap<>();
for (var user : saslPlainAuthList.value()) {
@@ -173,7 +193,7 @@ public static KafkaClusterConfig fromConstraints(List<Annotation> annotations, T
builder.users(users);
}
else if (annotation instanceof SaslPlainAuth saslPlainAuth) {
- builder.saslMechanism("PLAIN");
+ builder.saslMechanism(SASL_PLAIN_MECHANISM_NAME);
sasl = true;
builder.users(Map.of(saslPlainAuth.user(), saslPlainAuth.password()));
}
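For reference, a hedged sketch of how the @SaslPlainAuth constraint handled above is written in a test; the JUnit extension wiring and KafkaCluster parameter injection come from the surrounding project, not this diff:

    @SaslPlainAuth(user = "alice", password = "alice-secret")
    @Test
    void exercisesPlainAuth(KafkaCluster cluster) {
        // fromConstraints(...) maps the annotation to saslMechanism("PLAIN") plus a users entry
    }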
@@ -361,18 +381,43 @@ private void configureSasl(Properties server) {
if (saslMechanism != null) {
putConfig(server, "sasl.enabled.mechanisms", saslMechanism);

- var saslPairs = new StringBuilder();
+ var lm = loginModule;
+ if (lm == null) {
+ lm = deriveLoginModuleFromSasl(saslMechanism);
+ }
+
+ var serverOptions = Optional.ofNullable(jaasServerOptions).orElse(Map.of()).entrySet().stream();
+ Stream<Map.Entry<String, String>> userOptions = Stream.empty();
+ // Note SCRAM users are added to the server after startup.
+ if (isSaslPlain()) {
+ userOptions = Optional.ofNullable(users).orElse(Map.of()).entrySet()
+ .stream()
+ .collect(Collectors.toMap(e -> String.format("user_%s", e.getKey()), Map.Entry::getValue)).entrySet().stream();
+ }
+
+ var moduleOptions = Stream.concat(serverOptions, userOptions)
+ .map(e -> String.join("=", e.getKey(), e.getValue()))
+ .collect(Collectors.joining(" "));
+
- Optional.ofNullable(users).orElse(Map.of()).forEach((key, value) -> {
- saslPairs.append(String.format("user_%s", key));
- saslPairs.append("=");
- saslPairs.append(value);
- saslPairs.append(" ");
- });
+ var moduleConfig = String.format("%s required %s;", lm, moduleOptions);
+ var configKey = String.format("listener.name.%s.%s.sasl.jaas.config", EXTERNAL_LISTENER_NAME.toLowerCase(), saslMechanism.toLowerCase(Locale.ROOT));
+
- // TODO support other than PLAIN
- String plainModuleConfig = String.format("org.apache.kafka.common.security.plain.PlainLoginModule required %s;", saslPairs);
- putConfig(server, String.format("listener.name.%s.plain.sasl.jaas.config", EXTERNAL_LISTENER_NAME.toLowerCase()), plainModuleConfig);
+ putConfig(server, configKey, moduleConfig);
}
}

+ private String deriveLoginModuleFromSasl(String saslMechanism) {
+ switch (saslMechanism.toUpperCase(Locale.ROOT)) {
+ case SASL_PLAIN_MECHANISM_NAME -> {
+ return PlainLoginModule.class.getName();
+ }
+ case "SCRAM-SHA-256", "SCRAM-SHA-512" -> {
+ return ScramLoginModule.class.getName();
+ }
+ case "OAUTHBEARER" -> {
+ return OAuthBearerLoginModule.class.getName();
+ }
+ default -> throw new IllegalArgumentException("Cannot derive login module from saslMechanism %s".formatted(saslMechanism));
+ }
+ }
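A worked example of what configureSasl now emits, with illustrative values. For PLAIN with a single user alice/alice-secret the broker receives:

    listener.name.external.plain.sasl.jaas.config = org.apache.kafka.common.security.plain.PlainLoginModule required user_alice=alice-secret;

whereas for SCRAM-SHA-512 the option list is empty, since the users are created after startup via the Admin API:

    listener.name.external.scram-sha-512.sasl.jaas.config = org.apache.kafka.common.security.scram.ScramLoginModule required ;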

@@ -453,9 +498,9 @@ public Map<String, Object> getAnonConnectConfigForCluster(String bootstrapServer
*/
public Map<String, Object> getConnectConfigForCluster(String bootstrapServers) {
if (saslMechanism != null) {
- Map<String, String> users = getUsers();
- if (!users.isEmpty()) {
- Map.Entry<String, String> first = users.entrySet().iterator().next();
+ var externalUsers = getUsers();
+ if (!externalUsers.isEmpty()) {
+ Map.Entry<String, String> first = externalUsers.entrySet().iterator().next();
return getConnectConfigForCluster(bootstrapServers, first.getKey(), first.getValue(), getSecurityProtocol(), getSaslMechanism());
}
else {
@@ -536,28 +581,48 @@ public Map<String, Object> getConnectConfigForCluster(String bootstrapServers, S
}

if (saslMechanism != null) {
- kafkaConfig.put(SaslConfigs.SASL_MECHANISM, saslMechanism);
+
+ var lm = loginModule;
+ if (lm == null) {
+ lm = deriveLoginModuleFromSasl(saslMechanism);
+ }
+
if (securityProtocol == null) {
kafkaConfig.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, SecurityProtocol.SASL_PLAINTEXT.name());
}
+ kafkaConfig.put(SaslConfigs.SASL_MECHANISM, saslMechanism);

- if ("PLAIN".equals(saslMechanism)) {
- if (user != null && password != null) {
- kafkaConfig.put(SaslConfigs.SASL_JAAS_CONFIG,
- String.format("org.apache.kafka.common.security.plain.PlainLoginModule required username=\"%s\" password=\"%s\";",
- user, password));
+ var jaasOptions = new HashMap<>(jaasClientOptions == null ? Map.of() : jaasClientOptions);
+
+ if (isSaslPlain() || isSaslScram()) {
+ if (user != null && !jaasOptions.containsKey("username")) {
+ jaasOptions.put("username", user);
}
+ if (password != null && !jaasOptions.containsKey("password")) {
+ jaasOptions.put("password", password);
+ }
}
- else {
- throw new IllegalStateException(String.format("unsupported SASL mechanism %s", saslMechanism));
- }
+
+ var moduleOptions = jaasOptions.entrySet().stream()
+ .map(e -> String.join("=", e.getKey(), e.getValue()))
+ .collect(Collectors.joining(" "));
+
+ kafkaConfig.put(SaslConfigs.SASL_JAAS_CONFIG, String.format("%s required %s;", lm, moduleOptions));
}

kafkaConfig.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);

return kafkaConfig;
}
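The client side mirrors this. With illustrative values (and noting that option order in the JAAS string follows HashMap iteration order, so it may vary), getConnectConfigForCluster(bootstrap, "alice", "alice-secret", null, "SCRAM-SHA-256") yields, among other entries:

    security.protocol = SASL_PLAINTEXT (defaulted because securityProtocol was null)
    sasl.mechanism = SCRAM-SHA-256
    sasl.jaas.config = org.apache.kafka.common.security.scram.ScramLoginModule required username=alice password=alice-secret;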

+ private boolean isSaslPlain() {
+ return this.saslMechanism != null && this.saslMechanism.toUpperCase(Locale.ROOT).equals(SASL_PLAIN_MECHANISM_NAME);
+ }
+
+ public boolean isSaslScram() {
+ return this.saslMechanism != null && this.saslMechanism.toUpperCase(Locale.ROOT).startsWith(SASL_SCRAM_SHA_MECHANISM_PREFIX);
+ }

/**
* Is the cluster operating using KRaft Controller nodes.
*
40 changes: 31 additions & 9 deletions impl/src/main/java/io/kroxylicious/testing/kafka/common/Utils.java
@@ -24,7 +24,11 @@
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.NewPartitionReassignment;
import org.apache.kafka.clients.admin.NewTopic;
+ import org.apache.kafka.clients.admin.ScramCredentialInfo;
+ import org.apache.kafka.clients.admin.ScramMechanism;
import org.apache.kafka.clients.admin.TopicDescription;
+ import org.apache.kafka.clients.admin.UserScramCredentialAlteration;
+ import org.apache.kafka.clients.admin.UserScramCredentialUpsertion;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.TopicPartitionInfo;
import org.apache.kafka.common.errors.InvalidReplicationFactorException;
@@ -46,6 +50,7 @@
public class Utils {
private static final Logger log = getLogger(Utils.class);
private static final String CONSISTENCY_TEST = "__org_kroxylicious_testing_consistencyTest";
+ private static final int SCRAM_ITERATIONS = 4096;

private Utils() {
}
@@ -55,10 +60,10 @@ private Utils() {
* have at least one replica elsewhere in the cluster.
*
* @param connectionConfig the connection config
* @param fromNodeId nodeId being evacuated
* @param toNodeId replacement nodeId
* @param timeout the timeout
* @param timeUnit the time unit
*/
public static void awaitReassignmentOfKafkaInternalTopicsIfNecessary(Map<String, Object> connectionConfig, int fromNodeId, int toNodeId, int timeout,
TimeUnit timeUnit) {
@@ -113,9 +118,9 @@ public static void awaitReassignmentOfKafkaInternalTopicsIfNecessary(Map<String,
* To verify that all the expected brokers are in the cluster we create a topic with a replication factor equal to the expected number of brokers.
* We then poll describeTopics until
*
* @param connectionConfig the connection config
* @param timeout the timeout
* @param timeUnit the time unit
* @param expectedBrokerCount the expected broker count
*/
public static void awaitExpectedBrokerCountInClusterViaTopic(Map<String, Object> connectionConfig, int timeout, TimeUnit timeUnit, Integer expectedBrokerCount) {
@@ -163,6 +168,22 @@ public static void awaitExpectedBrokerCountInClusterViaTopic(Map<String, Object>
}
}
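The new helper below registers SCRAM users against the running cluster via the Admin API; unlike PLAIN users, SCRAM credentials live in cluster metadata rather than in the listener's JAAS configuration, so they cannot be baked into broker properties before startup. Both cluster implementations call it from start() once the brokers are up, as the later diffs show.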

+ public static void createUsersOnClusterIfNecessary(Map<String, Object> connectionConfig, KafkaClusterConfig clusterConfig) {
+ var users = clusterConfig.getUsers();
+ var saslMechanism = clusterConfig.getSaslMechanism();
+ if (users.isEmpty() || !clusterConfig.isSaslScram()) {
+ return;
+ }
+ var sci = new ScramCredentialInfo(ScramMechanism.fromMechanismName(saslMechanism), SCRAM_ITERATIONS);
+ try (var admin = Admin.create(connectionConfig)) {
+ // TODO fail gracefully if KRaft and the metadata version do not yet support SCRAM
+ admin.alterUserScramCredentials(users.entrySet().stream()
+ .map(e -> new UserScramCredentialUpsertion(e.getKey(), sci, e.getValue()))
+ .map(UserScramCredentialAlteration.class::cast)
+ .toList()).all().toCompletionStage().toCompletableFuture().join();
+ }
+ }

/*
* There are edge cases where deleting the topic isn't possible. Primarily `delete.topic.enable==false`.
* Rather than attempt to detect that in advance (as that requires an RPC) we catch that and return normally
@@ -229,14 +250,15 @@ private static boolean isRetryable(Throwable potentiallyWrapped) {
var throwable = potentiallyWrapped instanceof CompletionException && potentiallyWrapped.getCause() != null ? potentiallyWrapped.getCause() : potentiallyWrapped;
boolean retriable = throwable instanceof RetriableException
&& (throwable.getMessage() == null
|| !throwable.getMessage().contains("The AdminClient is not accepting new calls") /* workaround for KAFKA-15507 */);
return retriable || throwable instanceof InvalidReplicationFactorException
|| (throwable instanceof TopicExistsException && throwable.getMessage().contains("is marked for deletion"));
}

/**
* Factory for {@link Awaitility#await()} preconfigured with defaults.
*
* @param timeout at most timeout
* @param timeUnit at most {@link TimeUnit}
* @return preconfigured factory
*/
@@ -56,8 +56,6 @@
import io.kroxylicious.testing.kafka.common.Utils;
import io.kroxylicious.testing.kafka.internal.AdminSource;

- import static org.apache.kafka.server.common.MetadataVersion.MINIMUM_BOOTSTRAP_VERSION;

/**
* Configures and manages an in process (within the JVM) Kafka cluster.
*/
@@ -134,7 +132,7 @@ private static void prepareLogDirsForKraft(String clusterId, KafkaConfig config,
boolean.class);
// note ignoreFormatter=true so tolerate a log directory which is already formatted. this is
// required to support start/stop.
- formatCommand.invoke(null, LOGGING_PRINT_STREAM, directories, metaProperties, MINIMUM_BOOTSTRAP_VERSION, true);
+ formatCommand.invoke(null, LOGGING_PRINT_STREAM, directories, metaProperties, MetadataVersion.latestProduction(), true);
}
catch (Exception e) {
throw new RuntimeException("failed to prepare log dirs for KRaft", e);
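(A hedged reading, not stated in the commit: formatting the KRaft log dirs at MetadataVersion.latestProduction() rather than MINIMUM_BOOTSTRAP_VERSION appears necessary for the SCRAM support, since KRaft clusters only store SCRAM credentials from metadata version 3.5 onwards.)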
@@ -207,10 +205,14 @@ public synchronized void start() {
tryToStartServerWithRetry(configHolder, server);
servers.put(configHolder.getBrokerNum(), server);
});
+ var anonConnectConfigForCluster = clusterConfig.getAnonConnectConfigForCluster(buildBrokerList(nodeId -> getEndpointPair(Listener.ANON, nodeId)));
Utils.awaitExpectedBrokerCountInClusterViaTopic(
clusterConfig.getAnonConnectConfigForCluster(buildBrokerList(nodeId -> getEndpointPair(Listener.ANON, nodeId))), 120,
TimeUnit.SECONDS,
clusterConfig.getBrokersNum());

+ Utils.createUsersOnClusterIfNecessary(anonConnectConfigForCluster, clusterConfig);

}

private void tryToStartServerWithRetry(KafkaClusterConfig.ConfigHolder configHolder, Server server) {
Expand Down
@@ -334,10 +334,14 @@ public synchronized void start() {
zookeeper.start();
}
Startables.deepStart(nodes.values().stream()).get(READY_TIMEOUT_SECONDS, TimeUnit.SECONDS);

+ var anonConnectConfigForCluster = clusterConfig.getAnonConnectConfigForCluster(buildBrokerList(nodeId -> getEndpointPair(Listener.ANON, nodeId)));
awaitExpectedBrokerCountInClusterViaTopic(
- clusterConfig.getAnonConnectConfigForCluster(buildBrokerList(nodeId -> getEndpointPair(Listener.ANON, nodeId))),
+ anonConnectConfigForCluster,
READY_TIMEOUT_SECONDS, TimeUnit.SECONDS,
clusterConfig.getBrokersNum());

+ Utils.createUsersOnClusterIfNecessary(anonConnectConfigForCluster, clusterConfig);
}
catch (InterruptedException | ExecutionException | TimeoutException e) {
if (e instanceof InterruptedException) {
Expand Down
