diff --git a/pom.xml b/pom.xml index ec5156f0..90164fb9 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ 4.0.0 io.github.brenoepics at4j - 0.0.6 + 0.0.7 Azure Translator For Java A simple Java library to translate text using Azure AI Cognitive Services. @@ -44,8 +44,8 @@ UTF-8 - 9 - 9 + 11 + 11 brenoepic https://sonarcloud.io @@ -83,18 +83,6 @@ - - com.squareup.okhttp3 - okhttp - 4.12.0 - compile - - - com.squareup.okhttp3 - logging-interceptor - 4.12.0 - compile - org.apache.logging.log4j log4j-api @@ -143,6 +131,8 @@ maven-compiler-plugin 3.12.1 + 11 + 11 diff --git a/src/main/java/io/github/brenoepics/at4j/AzureApiBuilder.java b/src/main/java/io/github/brenoepics/at4j/AzureApiBuilder.java index 39624b4f..74d23d6d 100644 --- a/src/main/java/io/github/brenoepics/at4j/AzureApiBuilder.java +++ b/src/main/java/io/github/brenoepics/at4j/AzureApiBuilder.java @@ -2,10 +2,13 @@ import io.github.brenoepics.at4j.azure.BaseURL; import io.github.brenoepics.at4j.core.AzureApiImpl; -import io.github.brenoepics.at4j.util.logging.LoggerUtil; import io.github.brenoepics.at4j.util.logging.PrivacyProtectionLogger; -import okhttp3.OkHttpClient; -import okhttp3.logging.HttpLoggingInterceptor; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.net.ProxySelector; +import java.net.http.HttpClient; +import java.time.Duration; /** * Builder class for constructing instances of AzureApi. @@ -23,6 +26,14 @@ public class AzureApiBuilder { // The subscription region for the Azure API. private String subscriptionRegion; + private ProxySelector proxySelector = null; + + private SSLContext sslContext = null; + + private SSLParameters sslParameters = null; + + private Duration connectTimeout = null; + /** Default constructor initializes the base URL to the global endpoint. */ public AzureApiBuilder() { this.baseURL = BaseURL.GLOBAL; @@ -68,6 +79,62 @@ public AzureApiBuilder region(String subscriptionRegion) { return this; } + /** + * Sets the proxy selector for the Azure API. + * + * @param proxySelector The proxy selector for the Azure API. + * @return The current instance of AzureApiBuilder for method chaining. + * @see ProxySelector + */ + public AzureApiBuilder proxy(ProxySelector proxySelector) { + this.proxySelector = proxySelector; + return this; + } + + /** + * Sets the connect timeout for the Azure API. + * + * @param connectTimeout The connect timeout for the Azure API. + * @return The current instance of AzureApiBuilder for method chaining. + * @see Connection Timeout + */ + public AzureApiBuilder connectTimeout(Duration connectTimeout) { + this.connectTimeout = connectTimeout; + return this; + } + + /** + * Sets the SSL context for the Azure API. + * + * @param sslContext The SSL context for the Azure API. + * @return The current instance of AzureApiBuilder for method chaining. + * @see SSLContext + */ + public AzureApiBuilder sslContext(SSLContext sslContext) { + this.sslContext = sslContext; + return this; + } + + /** + * Sets the SSL parameters for the Azure API. + * + * @param sslParameters The SSL parameters for the Azure API. + * @return The current instance of AzureApiBuilder for method chaining. + * @see SSLParameters + */ + public AzureApiBuilder sslParameters(SSLParameters sslParameters) { + this.sslParameters = sslParameters; + return this; + } + /** * Builds and returns an instance of AzureApi with the configured parameters. * @@ -81,20 +148,24 @@ public AzureApi build() { } // The HTTP client used by the Azure API. - OkHttpClient httpClient = - new OkHttpClient.Builder() - .addInterceptor( - chain -> - chain.proceed( - chain - .request() - .newBuilder() - .addHeader("User-Agent", AT4J.USER_AGENT) - .build())) - .addInterceptor( - new HttpLoggingInterceptor(LoggerUtil.getLogger(OkHttpClient.class)::trace) - .setLevel(HttpLoggingInterceptor.Level.BODY)) - .build(); - return new AzureApiImpl<>(httpClient, baseURL, subscriptionKey, subscriptionRegion); + HttpClient.Builder httpClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1); + + if (proxySelector != null) { + httpClient.proxy(proxySelector); + } + + if (sslContext != null) { + httpClient.sslContext(sslContext); + } + + if (sslParameters != null) { + httpClient.sslParameters(sslParameters); + } + + if (connectTimeout != null) { + httpClient.connectTimeout(connectTimeout); + } + + return new AzureApiImpl<>(httpClient.build(), baseURL, subscriptionKey, subscriptionRegion); } } diff --git a/src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java b/src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java index f650d67e..72233e61 100644 --- a/src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java +++ b/src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java @@ -19,13 +19,12 @@ import io.github.brenoepics.at4j.util.rest.RestMethod; import io.github.brenoepics.at4j.util.rest.RestRequest; +import java.net.http.HttpClient; import java.util.ArrayList; import java.util.Collection; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import okhttp3.OkHttpClient; - /** * This class is an implementation of the AzureApi interface. It provides methods to interact with * Azure's translation API. @@ -33,7 +32,7 @@ public class AzureApiImpl implements AzureApi { /** The Http Client for this instance. */ - private final OkHttpClient httpClient; + private final HttpClient httpClient; /** The BaseURL for this instance. */ private final BaseURL baseURL; @@ -62,7 +61,7 @@ public class AzureApiImpl implements AzureApi { * @param subscriptionRegion The subscription region for this instance. */ public AzureApiImpl( - OkHttpClient httpClient, BaseURL baseURL, String subscriptionKey, String subscriptionRegion) { + HttpClient httpClient, BaseURL baseURL, String subscriptionKey, String subscriptionRegion) { this.httpClient = httpClient; this.baseURL = baseURL; this.subscriptionKey = subscriptionKey; @@ -183,15 +182,14 @@ public CompletableFuture>> getAvailableLanguages( @Override public void disconnect() { this.threadPool.getExecutorService().shutdown(); - this.httpClient.dispatcher().executorService().shutdown(); } /** * Gets the used OkHttpClient. * - * @return OkHttpClient - The used OkHttpClient. + * @return HttpClient - The used HttpClient. */ - public OkHttpClient getHttpClient() { + public HttpClient getHttpClient() { return this.httpClient; } diff --git a/src/main/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitBucket.java b/src/main/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitBucket.java index d4d13026..27aea50e 100644 --- a/src/main/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitBucket.java +++ b/src/main/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitBucket.java @@ -13,7 +13,6 @@ public class RateLimitBucket { private final ConcurrentLinkedQueue> requestQueue = new ConcurrentLinkedQueue<>(); private final RestEndpoint endpoint; - private final String majorUrlParameter; private volatile long rateLimitResetTimestamp = 0; private volatile int rateLimitRemaining = 1; @@ -22,11 +21,9 @@ public class RateLimitBucket { * Creates a RateLimitBucket for the given endpoint / parameter combination. * * @param endpoint The REST endpoint the rate-limit is tracked for. - * @param majorUrlParameter The url parameter this bucket is specific for. Maybe null. */ - public RateLimitBucket(RestEndpoint endpoint, String majorUrlParameter) { + public RateLimitBucket(RestEndpoint endpoint) { this.endpoint = endpoint; - this.majorUrlParameter = majorUrlParameter; } /** @@ -88,23 +85,17 @@ public int getTimeTillSpaceGetsAvailable() { * Checks if a bucket created with the given parameters would equal this bucket. * * @param endpoint The endpoint. - * @param majorUrlParameter The major url parameter. * @return Whether a bucket created with the given parameters would equal this bucket or not. */ - public boolean equals(RestEndpoint endpoint, String majorUrlParameter) { - boolean endpointSame = this.endpoint == endpoint; - boolean majorUrlParameterBothNull = this.majorUrlParameter == null && majorUrlParameter == null; - boolean majorUrlParameterEqual = - this.majorUrlParameter != null && this.majorUrlParameter.equals(majorUrlParameter); - - return endpointSame && (majorUrlParameterBothNull || majorUrlParameterEqual); + public boolean endpointMatches(RestEndpoint endpoint) { + return this.endpoint == endpoint; } @Override public boolean equals(Object obj) { if (obj instanceof RateLimitBucket) { RateLimitBucket otherBucket = (RateLimitBucket) obj; - return equals(otherBucket.endpoint, otherBucket.majorUrlParameter); + return endpointMatches(otherBucket.endpoint); } return false; } @@ -112,18 +103,13 @@ public boolean equals(Object obj) { @Override public int hashCode() { int hash = 42; - int urlParamHash = majorUrlParameter == null ? 0 : majorUrlParameter.hashCode(); int endpointHash = endpoint == null ? 0 : endpoint.hashCode(); - - hash = hash * 11 + urlParamHash; hash = hash * 17 + endpointHash; return hash; } @Override public String toString() { - String str = "Endpoint: " + (endpoint == null ? "global" : endpoint.getEndpointUrl()); - str += ", Major url parameter:" + (majorUrlParameter == null ? "none" : majorUrlParameter); - return str; + return "Endpoint: " + (endpoint == null ? "global" : endpoint.getEndpointUrl()); } } diff --git a/src/main/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitManager.java b/src/main/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitManager.java index e2f162e6..55f8ea9f 100644 --- a/src/main/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitManager.java +++ b/src/main/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitManager.java @@ -8,6 +8,8 @@ import io.github.brenoepics.at4j.util.rest.RestRequestResponseInformationImpl; import io.github.brenoepics.at4j.util.rest.RestRequestResult; +import java.net.http.HttpHeaders; +import java.net.http.HttpResponse; import java.util.HashSet; import java.util.Objects; import java.util.Optional; @@ -15,7 +17,6 @@ import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import okhttp3.Response; import org.apache.logging.log4j.Logger; /** This class manages rate-limits and keeps track of them. */ @@ -30,6 +31,18 @@ public class RateLimitManager { /** All buckets. */ private final Set> buckets = new HashSet<>(); + /** The header for rate-limit remaining information. */ + public static final String RATE_LIMITED_HEADER = "X-RateLimit-Remaining"; + + /** The header name for rate-limit reset information. */ + public static final String RATE_LIMIT_RESET_HEADER = "X-RateLimit-Reset"; + + /** The header name for rate-limit reset information. */ + public static final String RATE_LIMITED_HEADER_CLOUDFLARE = "Retry-after"; + + /** The body name for rate-limit reset information. */ + public static final String RATE_LIMITED_BODY_CLOUDFLARE = "retry_after"; + /** * Creates a new rate-limit manager. * @@ -48,27 +61,32 @@ public RateLimitManager(AzureApiImpl api) { public void queueRequest(RestRequest request) { Optional> searchBucket = searchBucket(request); - if (!searchBucket.isPresent()) { + if (searchBucket.isEmpty()) { return; } - final RateLimitBucket bucket = searchBucket.get(); - - api.getThreadPool() - .getExecutorService() - .submit( - () -> { - RestRequest currentRequest = bucket.peekRequestFromQueue(); - RestRequestResult result = null; - long responseTimestamp = System.currentTimeMillis(); - while (currentRequest != null) { - RestRequestHandler newResult = - handleCurrentRequest(result, currentRequest, bucket, responseTimestamp); - result = newResult.getResult(); - currentRequest = newResult.getCurrentRequest(); - responseTimestamp = newResult.getResponseTimestamp(); - } - }); + api.getThreadPool().getExecutorService().submit(() -> submitRequest(searchBucket.get())); + } + + /** + * Submits the request to the given bucket. + * + * @param bucket The bucket to submit the request to. + */ + private void submitRequest(RateLimitBucket bucket) { + RestRequest currentRequest = bucket.peekRequestFromQueue(); + RestRequestResult result = null; + + long responseTimestamp = System.currentTimeMillis(); + + while (currentRequest != null) { + RestRequestHandler newResult = + handleCurrentRequest(result, currentRequest, bucket, responseTimestamp); + + result = newResult.getResult(); + currentRequest = newResult.getCurrentRequest(); + responseTimestamp = newResult.getResponseTimestamp(); + } } /** @@ -85,17 +103,18 @@ RestRequestHandler handleCurrentRequest( RestRequest currentRequest, RateLimitBucket bucket, long responseTimestamp) { + try { waitUntilSpaceGetsAvailable(bucket); + + // Execute the request and get the result result = currentRequest.executeBlocking(); responseTimestamp = System.currentTimeMillis(); + } catch (Exception e) { responseTimestamp = System.currentTimeMillis(); if (currentRequest.getResult().isDone()) { - logger.warn( - "Received exception for a request that is already done. This should not be able to" - + " happen!", - e); + logger.warn("Exception for a already done request. This should not happen!", e); } if (e instanceof AzureException) { @@ -104,11 +123,8 @@ RestRequestHandler handleCurrentRequest( currentRequest.getResult().completeExceptionally(e); } finally { - try { - // Handle the response + if (result != null && result.getResponse() != null) { handleResponse(currentRequest, result, bucket, responseTimestamp); - } catch (Exception e) { - logger.warn("Encountered unexpected exception.", e); } // The request didn't finish, so let's try again @@ -162,6 +178,13 @@ RestRequest retryRequest(RateLimitBucket bucket) { } } + /** + * Maps the given exception to a {@link RestRequestResult}. + * + * @param t The exception to map. + * @return The mapped exception. + */ + @SuppressWarnings("unchecked") private RestRequestResult mapAzureException(Throwable t) { return ((AzureException) t) .getResponse() @@ -178,30 +201,34 @@ private RestRequestResult mapAzureException(Throwable t) { */ Optional> searchBucket(RestRequest request) { synchronized (buckets) { - RateLimitBucket bucket = - buckets.stream() - .filter( - b -> b.equals(request.getEndpoint(), request.getMajorUrlParameter().orElse(null))) - .findAny() - .orElseGet( - () -> - new RateLimitBucket<>( - request.getEndpoint(), request.getMajorUrlParameter().orElse(null))); + RateLimitBucket bucket = getMatchingBucket(request); // Check if it is already in the queue, send not present if (bucket.peekRequestFromQueue() != null) { return Optional.empty(); } - // Add the bucket to the set of buckets (does nothing if it's already in the set) buckets.add(bucket); - - // Add the request to the bucket's queue bucket.addRequestToQueue(request); return Optional.of(bucket); } } + /** + * Gets the bucket that matches the given request. + * + * @param request The request. + * @return The bucket that matches the request. + */ + private RateLimitBucket getMatchingBucket(RestRequest request) { + synchronized (buckets) { + return buckets.stream() + .filter(b -> b.endpointMatches(request.getEndpoint())) + .findAny() + .orElseGet(() -> new RateLimitBucket<>(request.getEndpoint())); + } + } + /** * Updates the rate-limit information and sets the result if the request was successful. * @@ -210,56 +237,94 @@ Optional> searchBucket(RestRequest request) { * @param bucket The bucket the request belongs to. * @param responseTimestamp The timestamp directly after the response finished. */ - void handleResponse( + private void handleResponse( RestRequest request, RestRequestResult result, RateLimitBucket bucket, long responseTimestamp) { - if (result == null || result.getResponse() == null) { - return; - } + try { + HttpResponse response = result.getResponse(); - Response response = result.getResponse(); - int remaining = - Integer.parseInt(Objects.requireNonNull(response.header("X-RateLimit-Remaining", "1"))); - long reset = - (long) - (Double.parseDouble(Objects.requireNonNull(response.header("X-RateLimit-Reset", "0"))) - * 1000); - - // Check if we received a 429 response - if (result.getResponse().code() != 429) { - // Check if we didn't already complete it exceptionally. - CompletableFuture> requestResult = request.getResult(); - if (!requestResult.isDone()) { - requestResult.complete(result); + // Check if we did not receive a rate-limit response + if (result.getResponse().statusCode() != 429) { + handleRateLimit(request.getResult(), result, bucket, response.headers()); + return; } - // Update bucket information - bucket.setRateLimitRemaining(remaining); - bucket.setRateLimitResetTimestamp(reset); - return; - } + if (response.headers().firstValue("Via").isEmpty()) { + handleCloudFlare(response.headers(), bucket); + return; + } + + long retryAfter = 0; + + if (!result.getJsonBody().isNull()) { + retryAfter = + (long) (result.getJsonBody().get(RATE_LIMITED_BODY_CLOUDFLARE).asDouble() * 1000); + } + + logger.debug("Received a 429 response from Azure! Recalculating time offset..."); - if (response.header("Via") == null) { - logger.warn( - "Hit a CloudFlare API ban! This means you were sending a very large amount of invalid" - + " requests."); - int retryAfter = - Integer.parseInt(Objects.requireNonNull(response.header("Retry-after"))) * 1000; - bucket.setRateLimitRemaining(retryAfter); + bucket.setRateLimitRemaining(0); bucket.setRateLimitResetTimestamp(responseTimestamp + retryAfter); - return; + + } catch (Exception e) { + logger.warn("Encountered unexpected exception.", e); + } + } + + /** + * Handles the CloudFlare rate-limit. + * + * @param headers The headers of the response. + * @param bucket The bucket the request belongs to. + */ + private void handleCloudFlare(HttpHeaders headers, RateLimitBucket bucket) { + logger.warn( + "Hit a CloudFlare API ban! {}", + "You were sending a very large amount of invalid requests."); + int retryAfter = + Integer.parseInt(getHeader(headers, RATE_LIMITED_HEADER_CLOUDFLARE, "10")) * 1000; + bucket.setRateLimitRemaining(retryAfter); + bucket.setRateLimitResetTimestamp(System.currentTimeMillis() + retryAfter); + } + + /** + * Handles the rate-limit information. + * + * @param request The request. + * @param result The result of the request. + * @param bucket The bucket the request belongs to. + * @param headers The headers of the response. + */ + private void handleRateLimit( + CompletableFuture> request, + RestRequestResult result, + RateLimitBucket bucket, + HttpHeaders headers) { + + // Check if we didn't already complete it exceptionally. + if (!request.isDone()) { + request.complete(result); } - long retryAfter = - result.getJsonBody().isNull() - ? 0 - : (long) (result.getJsonBody().get("retry_after").asDouble() * 1000); - logger.debug("Received a 429 response from Azure! Recalculating time offset..."); + String remaining = getHeader(headers, RATE_LIMITED_HEADER, "1"); + String reset = getHeader(headers, RATE_LIMIT_RESET_HEADER, "0"); + + // Update bucket information + bucket.setRateLimitRemaining(Integer.parseInt(remaining)); + bucket.setRateLimitResetTimestamp((long) (Double.parseDouble(reset) * 1000)); + } - // Update the bucket information - bucket.setRateLimitRemaining(0); - bucket.setRateLimitResetTimestamp(responseTimestamp + retryAfter); + /** + * Gets the header value from the given headers. + * + * @param headers The headers. + * @param header The header to get the value for. + * @param defaultValue The default value if the header is not present. + * @return The header value. + */ + public static String getHeader(HttpHeaders headers, String header, String defaultValue) { + return Objects.requireNonNull(headers.firstValue(header).orElse(defaultValue)); } } diff --git a/src/main/java/io/github/brenoepics/at4j/core/thread/AT4JThreadFactory.java b/src/main/java/io/github/brenoepics/at4j/core/thread/AT4JThreadFactory.java index 77fd3e21..b8979b60 100644 --- a/src/main/java/io/github/brenoepics/at4j/core/thread/AT4JThreadFactory.java +++ b/src/main/java/io/github/brenoepics/at4j/core/thread/AT4JThreadFactory.java @@ -1,7 +1,6 @@ package io.github.brenoepics.at4j.core.thread; import java.util.concurrent.atomic.AtomicInteger; -import org.jetbrains.annotations.NotNull; /** A thread factory that creates optionally numbered threads as daemon or non-daemon. */ public class AT4JThreadFactory implements java.util.concurrent.ThreadFactory { @@ -28,7 +27,7 @@ public AT4JThreadFactory(String namePattern, boolean daemon) { } @Override - public Thread newThread(@NotNull Runnable r) { + public Thread newThread(Runnable r) { Thread thread = new Thread(r, String.format(namePattern, counter.incrementAndGet())); thread.setDaemon(daemon); return thread; diff --git a/src/main/java/io/github/brenoepics/at4j/data/request/TranslateParams.java b/src/main/java/io/github/brenoepics/at4j/data/request/TranslateParams.java index 1b78e562..cae7c8d8 100644 --- a/src/main/java/io/github/brenoepics/at4j/data/request/TranslateParams.java +++ b/src/main/java/io/github/brenoepics/at4j/data/request/TranslateParams.java @@ -47,8 +47,6 @@ public TranslateParams(String text, Collection targetLanguages) { this.targetLanguages = targetLanguages; } - // Setter methods for the class fields - /** * Sets the text to be translated. * diff --git a/src/main/java/io/github/brenoepics/at4j/util/rest/RestEndpoint.java b/src/main/java/io/github/brenoepics/at4j/util/rest/RestEndpoint.java index 3edcb724..ceb92476 100644 --- a/src/main/java/io/github/brenoepics/at4j/util/rest/RestEndpoint.java +++ b/src/main/java/io/github/brenoepics/at4j/util/rest/RestEndpoint.java @@ -1,9 +1,13 @@ package io.github.brenoepics.at4j.util.rest; import io.github.brenoepics.at4j.azure.BaseURL; -import java.util.Optional; -import okhttp3.HttpUrl; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; /** * This enum contains all endpoints that we may use. Each endpoint is represented as an enum @@ -70,36 +74,26 @@ public String getEndpointUrl() { } /** - * Gets the full url of the endpoint. Parameters which are "too much" are added to the end. + * Gets the full {@link URI http url} of the endpoint. Parameters which are "too much" are added + * to the end. * * @param baseURL The base url of the endpoint. - * @param parameters The parameters of the url. E.g., for channel ids. - * @return The full url of the endpoint. + * @return The full http url of the endpoint. */ - public String getFullUrl(BaseURL baseURL, String... parameters) { - StringBuilder url = new StringBuilder("https://" + baseURL.getUrl() + getEndpointUrl()); + public URI getHttpUrl(BaseURL baseURL, Map> queryParams) + throws URISyntaxException { + String query = getQuery(queryParams); - url = new StringBuilder(String.format(url.toString(), (Object[]) parameters)); - int parameterAmount = - getEndpointUrl().split("%s").length - (getEndpointUrl().endsWith("%s") ? 0 : 1); - - if (parameters.length > parameterAmount) { - for (int i = parameterAmount; i < parameters.length; i++) { - url.append("/").append(parameters[i]); - } - } - return url.toString(); + return new URI("https", baseURL.getUrl(), endpointUrl, query, null); } - /** - * Gets the full {@link HttpUrl http url} of the endpoint. Parameters which are "too much" are - * added to the end. - * - * @param baseURL The base url of the endpoint. - * @param parameters The parameters of the url. E.g., for channel ids. - * @return The full http url of the endpoint. - */ - public HttpUrl getOkHttpUrl(BaseURL baseURL, String... parameters) { - return HttpUrl.parse(getFullUrl(baseURL, parameters)); + private String getQuery(Map> queryParams) { + return queryParams.entrySet().stream() + .map( + entry -> + entry.getValue().stream() + .map(value -> entry.getKey() + "=" + value) + .collect(Collectors.joining("&"))) + .collect(Collectors.joining("&")); } } diff --git a/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequest.java b/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequest.java index 42d8cedf..7e2088a0 100644 --- a/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequest.java +++ b/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequest.java @@ -8,13 +8,14 @@ import io.github.brenoepics.at4j.util.logging.LoggerUtil; import java.io.IOException; import java.net.MalformedURLException; -import java.net.URL; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.util.*; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.Function; -import okhttp3.*; import org.apache.logging.log4j.Logger; /** This class is used to wrap a rest request. */ @@ -28,19 +29,12 @@ public class RestRequest { private final RestEndpoint endpoint; private volatile boolean includeAuthorizationHeader = true; - private AtomicReferenceArray urlParameters = new AtomicReferenceArray<>(new String[0]); private final Map> queryParameters = new HashMap<>(); private final Map headers = new HashMap<>(); private volatile String body = null; private final CompletableFuture> result = new CompletableFuture<>(); - /** The multipart body of the request. */ - private MultipartBody multipartBody; - - /** The custom major parameter if it's not included in the url (e.g., for reactions) */ - private String customMajorParam = null; - /** The origin of the rest request. */ private final Exception origin; @@ -96,19 +90,6 @@ public Map> getQueryParameters() { return queryParameters; } - /** - * Gets an array with all used url parameters. - * - * @return An array with all used url parameters. - */ - public String[] getUrlParameters() { - String[] parameters = new String[urlParameters.length()]; - for (int i = 0; i < urlParameters.length(); i++) { - parameters[i] = urlParameters.get(i); - } - return parameters; - } - /** * Gets the body of this request. * @@ -118,29 +99,6 @@ public Optional getBody() { return Optional.ofNullable(body); } - /** - * Gets the major url parameter of this request. If a request has a major parameter, it means that - * the rate-limits for this request are based on this parameter. - * - * @return The major url parameter used for rate-limits. - */ - public Optional getMajorUrlParameter() { - if (customMajorParam != null) { - return Optional.of(customMajorParam); - } - - Optional majorParameterPosition = endpoint.getMajorParameterPosition(); - if (!majorParameterPosition.isPresent()) { - return Optional.empty(); - } - - if (majorParameterPosition.get() >= urlParameters.length()) { - return Optional.empty(); - } - - return Optional.of(urlParameters.get(majorParameterPosition.get())); - } - /** * Gets the origin of the rest request. * @@ -178,40 +136,6 @@ public Map getHeaders() { return headers; } - /** - * Sets the url parameters, e.g., a language parameter. - * - * @param parameters The parameters. - * @return The current instance to chain call methods. - */ - public RestRequest setUrlParameters(String... parameters) { - this.urlParameters = new AtomicReferenceArray<>(parameters); - return this; - } - - /** - * Sets the multipart body of the request. If a multipart body is set, the {@link - * #setBody(String)} method is ignored! - * - * @param multipartBody The multipart body of the request. - * @return The current instance to chain call methods. - */ - public RestRequest setMultipartBody(MultipartBody multipartBody) { - this.multipartBody = multipartBody; - return this; - } - - /** - * Sets a custom major parameter. - * - * @param customMajorParam The custom parameter to set. - * @return The current instance to chain call methods. - */ - public RestRequest setCustomMajorParam(String customMajorParam) { - this.customMajorParam = customMajorParam; - return this; - } - /** * Sets the body of the request. * @@ -281,20 +205,18 @@ public CompletableFuture> getResult() { * Gets the information for this rest request. * * @return The information for this rest request. + * @throws AssertionError Thrown if the url is malformed. */ public RestRequestInformation asRestRequestInformation() { try { - String[] parameters = new String[urlParameters.length()]; - for (int i = 0; i < urlParameters.length(); i++) { - parameters[i] = urlParameters.get(i); - } + return new RestRequestInformationImpl( api, - new URL(endpoint.getFullUrl(api.getBaseURL(), parameters)), + endpoint.getHttpUrl(api.getBaseURL(), queryParameters).toURL(), queryParameters, headers, body); - } catch (MalformedURLException e) { + } catch (URISyntaxException | MalformedURLException e) { throw new AssertionError(e); } } @@ -304,52 +226,53 @@ public RestRequestInformation asRestRequestInformation() { * * @return The result of the request. * @throws AzureException Thrown in case of an error while requesting azure. - * @throws IOException Thrown if OkHttp {@link OkHttpClient#newCall(Request)} throws an {@link - * IOException}. + * @throws IOException Thrown if an error occurs while reading the response. */ - public RestRequestResult executeBlocking() throws AzureException, IOException { - Request.Builder requestBuilder = new Request.Builder(); - String[] parameters = getUrlParameters(); - HttpUrl.Builder httpUrlBuilder = - endpoint.getOkHttpUrl(api.getBaseURL(), parameters).newBuilder(); - queryParameters.forEach( - (key, values) -> values.forEach(value -> httpUrlBuilder.addQueryParameter(key, value))); - requestBuilder.url(httpUrlBuilder.build()); + public RestRequestResult executeBlocking() + throws AzureException, IOException, URISyntaxException { + URI fullUrl = endpoint.getHttpUrl(api.getBaseURL(), queryParameters); + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder().uri(fullUrl); request(requestBuilder); if (includeAuthorizationHeader) { - requestBuilder.addHeader("Ocp-Apim-Subscription-Key", api.getSubscriptionKey()); + requestBuilder.setHeader("Ocp-Apim-Subscription-Key", api.getSubscriptionKey()); api.getSubscriptionRegion() - .ifPresent(region -> requestBuilder.addHeader("Ocp-Apim-Subscription-Region", region)); + .ifPresent(region -> requestBuilder.setHeader("Ocp-Apim-Subscription-Region", region)); } - String fullUrl = endpoint.getFullUrl(api.getBaseURL(), parameters); - headers.forEach(requestBuilder::addHeader); + headers.forEach(requestBuilder::setHeader); + logger.debug( "Trying to send {} request to {}{}", method.name(), - fullUrl, + requestBuilder, body != null ? " with body " + body : ""); - try (Response response = getApi().getHttpClient().newCall(requestBuilder.build()).execute()) { - RestRequestResult requestResult = new RestRequestResult<>(this, response); - - String bodyPresent = requestResult.getBody().map(b -> "").orElse("empty"); - String bodyString = requestResult.getStringBody().orElse(""); - logger.debug( - "Sent {} request to {} and received status code {} with {} body {}", - method.name(), - fullUrl, - response.code(), - bodyPresent, - bodyString); + CompletableFuture> response = + getApi() + .getHttpClient() + .sendAsync(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()); + RestRequestResult responseResult = handleResponse(fullUrl, response.join()); + result.complete(responseResult); + return responseResult; + } - if (response.code() >= 300 || response.code() < 200) { - return handleError(response.code(), requestResult); - } + private RestRequestResult handleResponse(URI fullUrl, HttpResponse response) + throws IOException, AzureException { + RestRequestResult requestResult = new RestRequestResult<>(this, response); + String bodyString = requestResult.getStringBody().orElse("empty"); + logger.debug( + "Sent {} request to {} and received status code {} with body {}", + method.name(), + fullUrl.toURL(), + response.statusCode(), + bodyString); - return requestResult; + if (response.statusCode() >= 300 || response.statusCode() < 200) { + return handleError(response.statusCode(), requestResult); } + + return requestResult; } private RestRequestResult handleError(int resultCode, RestRequestResult result) @@ -376,14 +299,14 @@ private RestRequestResult handleError(int resultCode, RestRequestResult re "Received a " + resultCode + " response from Azure with" - + (result.getBody().isPresent() ? "" : " empty") + + (result.getStringBody().isPresent() ? "" : " empty") + " body" + result.getStringBody().map(s -> " " + s).orElse("") + "!", requestInformation, responseInformation)); - String bodyPresent = result.getBody().isPresent() ? "" : " empty"; + String bodyPresent = result.getStringBody().isPresent() ? "" : " empty"; throw azureException.isPresent() ? azureException.get() : new AzureException( @@ -423,38 +346,14 @@ private void handleKnownError( } } - private RequestBody createMultipartBody() { - if (multipartBody != null) { - return multipartBody; - } - - if (body != null) { - return RequestBody.create(body, MediaType.parse("application/json")); - } - return RequestBody.create(new byte[0], null); - } - - private void request(okhttp3.Request.Builder requestBuilder) { - RequestBody requestBody = createMultipartBody(); - - switch (method) { - case GET: - requestBuilder.get(); - break; - case POST: - requestBuilder.post(requestBody); - break; - case PUT: - requestBuilder.put(requestBody); - break; - case DELETE: - requestBuilder.delete(requestBody); - break; - case PATCH: - requestBuilder.patch(requestBody); - break; - default: - throw new IllegalArgumentException("Unsupported http method!"); - } + private void request(HttpRequest.Builder requestBuilder) { + requestBuilder.setHeader("User-Agent", AT4J.USER_AGENT); + requestBuilder.setHeader("Accept", "application/json"); + requestBuilder.setHeader("Content-Type", "application/json"); + HttpRequest.BodyPublisher bodyPublisher = + body == null + ? HttpRequest.BodyPublishers.noBody() + : HttpRequest.BodyPublishers.ofString(body); + requestBuilder.method(method.name(), bodyPublisher); } } diff --git a/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequestResponseInformationImpl.java b/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequestResponseInformationImpl.java index 5f331d37..a6a86dd7 100644 --- a/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequestResponseInformationImpl.java +++ b/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequestResponseInformationImpl.java @@ -42,7 +42,7 @@ public RestRequestInformation getRequest() { @Override public int getCode() { - return restRequestResult.getResponse().code(); + return restRequestResult.getResponse().statusCode(); } @Override diff --git a/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequestResult.java b/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequestResult.java index 1d5b9d01..51f1b0d4 100644 --- a/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequestResult.java +++ b/src/main/java/io/github/brenoepics/at4j/util/rest/RestRequestResult.java @@ -6,9 +6,8 @@ import com.fasterxml.jackson.databind.node.NullNode; import io.github.brenoepics.at4j.util.logging.LoggerUtil; import java.io.IOException; +import java.net.http.HttpResponse; import java.util.Optional; -import okhttp3.Response; -import okhttp3.ResponseBody; import org.apache.logging.log4j.Logger; /** The result of a {@link RestRequest}. */ @@ -18,8 +17,7 @@ public class RestRequestResult { private static final Logger logger = LoggerUtil.getLogger(RestRequestResult.class); private final RestRequest request; - private final Response response; - private final ResponseBody body; + private final HttpResponse response; private final String stringBody; private final JsonNode jsonBody; @@ -28,28 +26,28 @@ public class RestRequestResult { * * @param request The request of the result. * @param response The response of the RestRequest. - * @throws IOException Passed on from {@link ResponseBody#string()}. + * @throws IOException Passed on from {@link HttpResponse#body()}. */ - public RestRequestResult(RestRequest request, Response response) throws IOException { + public RestRequestResult(RestRequest request, HttpResponse response) + throws IOException { this.request = request; this.response = response; - this.body = response.body(); - if (body == null) { - stringBody = null; + this.stringBody = response.body(); + if (stringBody == null) { jsonBody = NullNode.getInstance(); - } else { - stringBody = body.string(); - ObjectMapper mapper = request.getApi().getObjectMapper(); - JsonNode jsonNode; - try { - jsonNode = mapper.readTree(stringBody); - } catch (JsonParseException e) { - // This can happen if Azure sends garbage - logger.debug("Failed to parse json response", e); - jsonNode = null; - } - this.jsonBody = jsonNode == null ? NullNode.getInstance() : jsonNode; + return; } + + ObjectMapper mapper = request.getApi().getObjectMapper(); + JsonNode jsonNode; + try { + jsonNode = mapper.readTree(stringBody); + } catch (JsonParseException e) { + // This can happen if Azure sends garbage + logger.debug("Failed to parse json response", e); + jsonNode = null; + } + this.jsonBody = jsonNode == null ? NullNode.getInstance() : jsonNode; } /** @@ -66,19 +64,10 @@ public RestRequest getRequest() { * * @return The response of the RestRequest. */ - public Response getResponse() { + public HttpResponse getResponse() { return response; } - /** - * Gets the body of the response. - * - * @return The body of the response. - */ - public Optional getBody() { - return Optional.ofNullable(body); - } - /** * Gets the string body of the response. * diff --git a/src/test/java/io/github/brenoepics/at4j/AzureApiBuilderTest.java b/src/test/java/io/github/brenoepics/at4j/AzureApiBuilderTest.java index 57220a09..ae91601c 100644 --- a/src/test/java/io/github/brenoepics/at4j/AzureApiBuilderTest.java +++ b/src/test/java/io/github/brenoepics/at4j/AzureApiBuilderTest.java @@ -3,6 +3,7 @@ import org.junit.jupiter.api.Test; import io.github.brenoepics.at4j.azure.BaseURL; + import static org.junit.jupiter.api.Assertions.*; class AzureApiBuilderTest { diff --git a/src/test/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitManagerTest.java b/src/test/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitManagerTest.java deleted file mode 100644 index 34c22ac2..00000000 --- a/src/test/java/io/github/brenoepics/at4j/core/ratelimit/RateLimitManagerTest.java +++ /dev/null @@ -1,156 +0,0 @@ -package io.github.brenoepics.at4j.core.ratelimit; - -import static org.junit.jupiter.api.Assertions.*; - -import io.github.brenoepics.at4j.core.AzureApiImpl; -import io.github.brenoepics.at4j.core.exceptions.AzureException; -import io.github.brenoepics.at4j.core.thread.ThreadPool; -import io.github.brenoepics.at4j.util.rest.RestRequest; -import io.github.brenoepics.at4j.util.rest.RestRequestResult; -import okhttp3.Response; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.function.Executable; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -import java.io.IOException; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; - -import static org.mockito.Mockito.*; - -class RateLimitManagerTest { - - @Mock private AzureApiImpl api; - @Mock private RestRequest request; - @Mock private RestRequestResult result; - @Mock private Response response; - @Mock private RateLimitBucket bucket; - - private RateLimitManager rateLimitManager; - - @Mock private ThreadPool threadPool; - @Mock private ExecutorService executorService; - - @BeforeEach - public void setUp() { - MockitoAnnotations.openMocks(this); - rateLimitManager = new RateLimitManager<>(api); - - when(api.getThreadPool()).thenReturn(threadPool); - when(threadPool.getExecutorService()).thenReturn(executorService); - - when(response.header("X-RateLimit-Remaining", "1")).thenReturn("10"); - when(response.header("X-RateLimit-Reset", "0")).thenReturn("1000"); - } - - @Test - void queuesRequestWhenBucketIsPresent() { - when(request.getMajorUrlParameter()).thenReturn(Optional.empty()); - when(bucket.peekRequestFromQueue()).thenReturn(null); - - rateLimitManager.queueRequest(request); - - verify(executorService, times(1)).submit(any(Runnable.class)); - } - - @Test - void retriesRequestWhenBucketIsNotEmpty() { - when(bucket.peekRequestFromQueue()).thenReturn(request); - - RestRequest retriedRequest = rateLimitManager.retryRequest(bucket); - - assertEquals(request, retriedRequest); - } - - @Test - void doesNotRetryRequestWhenBucketIsEmpty() { - when(bucket.peekRequestFromQueue()).thenReturn(null); - - RestRequest retriedRequest = rateLimitManager.retryRequest(bucket); - - assertNull(retriedRequest); - } - - @Test - void doesNotHandleResponseWhenResultIsNull() { - assertDoesNotThrow( - () -> rateLimitManager.handleResponse(request, null, bucket, System.currentTimeMillis())); - } - - @Test - void queuesRequestWhenBucketIsNotPresent() { - when(request.getMajorUrlParameter()).thenReturn(Optional.empty()); - when(bucket.peekRequestFromQueue()).thenReturn(request); - - rateLimitManager.queueRequest(request); - assertDoesNotThrow( - () -> rateLimitManager.handleResponse(request, result, bucket, System.currentTimeMillis())); - } - - @Test - void handleResponseUpdatesBucketInformationWhenResponseCodeIsNot429() { - when(response.header("X-RateLimit-Remaining", "1")).thenReturn("10"); - when(response.header("X-RateLimit-Reset", "0")).thenReturn("1000"); - when(result.getResponse()).thenReturn(response); - when(response.code()).thenReturn(200); - when(request.getResult()).thenReturn(CompletableFuture.completedFuture(result)); - - rateLimitManager.handleResponse(request, result, bucket, System.currentTimeMillis()); - - verify(bucket, times(1)).setRateLimitRemaining(10); - verify(bucket, times(1)).setRateLimitResetTimestamp(anyLong()); - } - - @Test - void handleResponseDoesNotUpdateBucketInformationWhenResponseCodeIs429AndViaHeaderIsNull() { - when(result.getResponse()).thenReturn(response); - when(response.code()).thenReturn(429); - when(response.header("Via")).thenReturn(null); - when(response.header("Retry-after")).thenReturn("1000"); - when(response.header("X-RateLimit-Remaining", "1")).thenReturn("10"); - when(response.header("X-RateLimit-Reset", "0")).thenReturn("1000"); - - assertDoesNotThrow( - () -> rateLimitManager.handleResponse(request, result, bucket, System.currentTimeMillis())); - } - - @Test - void handleCurrentRequestThrowsException() throws AzureException, IOException { - // Arrange - RuntimeException expectedException = new RuntimeException(); - when(request.executeBlocking()).thenThrow(expectedException); - - // Act - Executable executable = - () -> - rateLimitManager.handleCurrentRequest( - result, request, bucket, System.currentTimeMillis()); - - // Assert - assertThrows(RuntimeException.class, executable); - } - - @Test - void - handleResponseDoesNotUpdateBucketInformationWhenResponseCodeIsNot429AndRequestResultIsDone() { - when(result.getResponse()).thenReturn(response); - when(response.code()).thenReturn(200); - when(request.getResult()).thenReturn(CompletableFuture.completedFuture(result)); - - assertDoesNotThrow( - () -> rateLimitManager.handleResponse(request, result, bucket, System.currentTimeMillis())); - } - - @Test - void searchBucketReturnsBucketWhenBucketIsPresentAndRequestQueueIsEmpty() { - when(request.getMajorUrlParameter()).thenReturn(Optional.empty()); - when(bucket.peekRequestFromQueue()).thenReturn(null); - - Optional> searchBucket = rateLimitManager.searchBucket(request); - - assertTrue(searchBucket.isPresent()); - } -} diff --git a/src/test/java/io/github/brenoepics/at4j/util/rest/RestRequestTest.java b/src/test/java/io/github/brenoepics/at4j/util/rest/RestRequestTest.java index 7d170ca3..e8615f75 100644 --- a/src/test/java/io/github/brenoepics/at4j/util/rest/RestRequestTest.java +++ b/src/test/java/io/github/brenoepics/at4j/util/rest/RestRequestTest.java @@ -43,15 +43,6 @@ void shouldAddHeader() { Assertions.assertEquals("headerValue", restRequest.getHeaders().get("headerName")); } - @Test - @DisplayName("Should set url parameters correctly") - void shouldSetUrlParameters() { - restRequest.setUrlParameters("param1", "param2"); - Assertions.assertEquals(2, restRequest.getUrlParameters().length); - Assertions.assertEquals("param1", restRequest.getUrlParameters()[0]); - Assertions.assertEquals("param2", restRequest.getUrlParameters()[1]); - } - @Test @DisplayName("Should set body correctly") void shouldSetBody() {