diff --git a/src/main/java/io/github/brenoepics/at4j/AzureApi.java b/src/main/java/io/github/brenoepics/at4j/AzureApi.java index c0b81a3a..ba36c99f 100644 --- a/src/main/java/io/github/brenoepics/at4j/AzureApi.java +++ b/src/main/java/io/github/brenoepics/at4j/AzureApi.java @@ -10,6 +10,7 @@ import io.github.brenoepics.at4j.data.response.TranslationResponse; import java.util.Collection; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -72,7 +73,7 @@ public interface AzureApi { * @param params The {@link TranslateParams} to translate. * @return The {@link TranslationResponse} containing the translation. */ - CompletableFuture> translate(TranslateParams params); + CompletableFuture>> translate(TranslateParams params); /** * Gets the available languages for translation. diff --git a/src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java b/src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java index d6239050..6fcdd7d3 100644 --- a/src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java +++ b/src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java @@ -18,10 +18,12 @@ import io.github.brenoepics.at4j.util.rest.RestEndpoint; import io.github.brenoepics.at4j.util.rest.RestMethod; import io.github.brenoepics.at4j.util.rest.RestRequest; +import io.github.brenoepics.at4j.util.rest.RestRequestResult; import java.net.http.HttpClient; import java.util.ArrayList; import java.util.Collection; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -89,42 +91,19 @@ public ThreadPool getThreadPool() { } @Override - public CompletableFuture> translate(TranslateParams params) { - if (params.getText() == null || params.getText().isEmpty()) { + public CompletableFuture>> translate(TranslateParams params) { + if (params.getTexts() == null || params.getTexts().isEmpty()) { return CompletableFuture.completedFuture(Optional.empty()); } - RestRequest> request = - new RestRequest>( + RestRequest>> request = + new RestRequest>>( this, RestMethod.POST, RestEndpoint.TRANSLATE) .setBody(params.getBody()); params.getQueryParameters().forEach(request::addQueryParameter); params.getTargetLanguages().forEach(lang -> request.addQueryParameter("to", lang)); - return request.execute( - response -> { - if (response.getJsonBody().isNull() - || !response.getJsonBody().has(0) - || !response.getJsonBody().get(0).has("translations")) return Optional.empty(); - - JsonNode jsonNode = response.getJsonBody().get(0); - Collection translations = new ArrayList<>(); - jsonNode - .get("translations") - .forEach(node -> translations.add(Translation.ofJSON((ObjectNode) node))); - - TranslationResponse translationResponse; - if (jsonNode.has("detectedLanguage")) { - JsonNode detectedLanguage = jsonNode.get("detectedLanguage"); - translationResponse = - new TranslationResponse( - DetectedLanguage.ofJSON((ObjectNode) detectedLanguage), translations); - } else { - translationResponse = new TranslationResponse(translations); - } - - return Optional.of(translationResponse); - }); + return request.execute(params::handleTranslations); } @Override diff --git a/src/main/java/io/github/brenoepics/at4j/data/request/TranslateParams.java b/src/main/java/io/github/brenoepics/at4j/data/request/TranslateParams.java index cae7c8d8..535d7c2a 100644 --- a/src/main/java/io/github/brenoepics/at4j/data/request/TranslateParams.java +++ b/src/main/java/io/github/brenoepics/at4j/data/request/TranslateParams.java @@ -5,9 +5,14 @@ import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import io.github.brenoepics.at4j.azure.lang.Language; +import io.github.brenoepics.at4j.data.DetectedLanguage; +import io.github.brenoepics.at4j.data.Translation; import io.github.brenoepics.at4j.data.request.optional.ProfanityAction; import io.github.brenoepics.at4j.data.request.optional.ProfanityMarker; import io.github.brenoepics.at4j.data.request.optional.TextType; +import io.github.brenoepics.at4j.data.response.TranslationResponse; +import io.github.brenoepics.at4j.util.rest.RestRequestResult; + import java.util.*; import java.util.stream.Collectors; @@ -18,7 +23,7 @@ */ public class TranslateParams { // The text to be translated - private String text; + private LinkedHashMap toTranslate; // The type of the text to be translated (plain or HTML) private TextType textType; // The action to be taken on profanities in the text @@ -43,18 +48,32 @@ public class TranslateParams { * @param targetLanguages The target languages for the translation. */ public TranslateParams(String text, Collection targetLanguages) { - this.text = text; + this.toTranslate = new LinkedHashMap<>(); + this.toTranslate.put(1, text); + this.targetLanguages = targetLanguages; + } + + /** + * Constructor that initializes the text to be translated. + * + * @param texts The text list to be translated. + * @param targetLanguages The target languages for the translation. + */ + public TranslateParams(Collection texts, Collection targetLanguages) { + this.toTranslate = new LinkedHashMap<>(); + texts.forEach(t -> this.toTranslate.put(this.toTranslate.size() + 1, t)); this.targetLanguages = targetLanguages; } /** * Sets the text to be translated. * - * @param text The text to be translated. + * @param texts The texts to be translated. * @return This instance. */ - public TranslateParams setText(String text) { - this.text = text; + public TranslateParams setTexts(Collection texts) { + this.toTranslate = new LinkedHashMap<>(); + texts.forEach(t -> this.toTranslate.put(this.toTranslate.size() + 1, t)); return this; } @@ -167,10 +186,7 @@ public TranslateParams setSourceLanguage(String sourceLanguage) { */ public TranslateParams setTargetLanguages(Collection targetLanguages) { this.targetLanguages = - Collections.unmodifiableCollection( - targetLanguages.stream() - .map(Language::getCode) - .collect(Collectors.toCollection(ArrayList::new))); + targetLanguages.stream().map(Language::getCode).collect(Collectors.toUnmodifiableList()); return this; } @@ -185,8 +201,8 @@ public TranslateParams setTargetLanguages(String... targetLanguages) { return this; } - public String getText() { - return text; + public Map getTexts() { + return toTranslate; } public Boolean getIncludeAlignment() { @@ -264,11 +280,45 @@ public Map getQueryParameters() { */ public JsonNode getBody() { ArrayNode body = JsonNodeFactory.instance.arrayNode(); - if (getText() != null && !getText().isEmpty()) { + + for (String text : getTexts().values()) { ObjectNode textNode = JsonNodeFactory.instance.objectNode(); - textNode.put("Text", getText()); + textNode.put("Text", text); body.add(textNode); } return body; } + + public Optional> handleTranslations( + RestRequestResult>> response) { + if (response.getJsonBody().isNull() || response.getJsonBody().isEmpty()) + return Optional.empty(); + + List responses = new ArrayList<>(); + getTexts() + .forEach( + (index, baseText) -> { + JsonNode jsonNode = response.getJsonBody().get(index - 1); + if (!jsonNode.has("translations")) return; + + Collection translations = new ArrayList<>(); + jsonNode + .get("translations") + .forEach(node -> translations.add(Translation.ofJSON((ObjectNode) node))); + + if (jsonNode.has("detectedLanguage")) { + JsonNode detectedLanguage = jsonNode.get("detectedLanguage"); + responses.add( + new TranslationResponse( + baseText, + DetectedLanguage.ofJSON((ObjectNode) detectedLanguage), + translations)); + return; + } + + responses.add(new TranslationResponse(baseText, translations)); + }); + + return Optional.of(responses); + } } diff --git a/src/main/java/io/github/brenoepics/at4j/data/response/TranslationResponse.java b/src/main/java/io/github/brenoepics/at4j/data/response/TranslationResponse.java index 86a03464..1907d501 100644 --- a/src/main/java/io/github/brenoepics/at4j/data/response/TranslationResponse.java +++ b/src/main/java/io/github/brenoepics/at4j/data/response/TranslationResponse.java @@ -18,6 +18,8 @@ public class TranslationResponse { */ private DetectedLanguage detectedLanguage = null; + private final String baseText; + // A collection of translations for the input text. private final Collection translations; @@ -28,7 +30,8 @@ public class TranslationResponse { * @param translations A collection of translations for the input text. */ public TranslationResponse( - DetectedLanguage detectedLanguage, Collection translations) { + String baseText, DetectedLanguage detectedLanguage, Collection translations) { + this.baseText = baseText; this.detectedLanguage = detectedLanguage; this.translations = translations; } @@ -39,7 +42,8 @@ public TranslationResponse( * * @param translations A collection of translations for the input text. */ - public TranslationResponse(Collection translations) { + public TranslationResponse(String baseText, Collection translations) { + this.baseText = baseText; this.translations = translations; } @@ -60,4 +64,13 @@ public DetectedLanguage getDetectedLanguage() { public Collection getTranslations() { return translations; } + + /** + * Returns the base texts that were translated + * + * @return the base text + */ + public String getBaseText() { + return baseText; + } } diff --git a/src/test/java/io/github/brenoepics/at4j/AzureApiTest.java b/src/test/java/io/github/brenoepics/at4j/AzureApiTest.java index 811e842e..1793bf4e 100644 --- a/src/test/java/io/github/brenoepics/at4j/AzureApiTest.java +++ b/src/test/java/io/github/brenoepics/at4j/AzureApiTest.java @@ -70,35 +70,31 @@ void translateEmptyKey() { AzureApi api = new AzureApiBuilder().baseURL(BaseURL.GLOBAL).setKey("").region("test").build(); TranslateParams params = new TranslateParams("test", List.of("pt")).setSourceLanguage("en"); - CompletableFuture> translation = api.translate(params); + CompletableFuture>> translation = api.translate(params); assertThrows(CompletionException.class, translation::join); api.disconnect(); } @Test - void translateEmptyText() { - AzureApi api = new AzureApiBuilder().baseURL(BaseURL.GLOBAL).setKey("test").build(); - - TranslateParams params = new TranslateParams("", List.of("pt")).setSourceLanguage("en"); - CompletableFuture> translation = api.translate(params); - Optional tr = translation.join(); - tr.ifPresent(translations -> assertEquals(0, translations.getTranslations().size())); - api.disconnect(); - } + void translateHelloWorld() { + String azureKey = System.getenv("AZURE_KEY"); + String region = System.getenv("AZURE_REGION"); + Assumptions.assumeTrue( + azureKey != null && region != null, "Azure Credentials are null, skipping the test"); + Assumptions.assumeTrue( + !azureKey.isEmpty() && !region.isEmpty(), "Azure Credentials are empty, skipping the test"); - @Test - void translateEmptySourceLanguage() { - AzureApi api = new AzureApiBuilder().baseURL(BaseURL.GLOBAL).setKey("test").build(); + AzureApiBuilder builder = new AzureApiBuilder().setKey(azureKey).region(region); + AzureApi api = builder.build(); - TranslateParams params = new TranslateParams("", List.of("pt")); - CompletableFuture> translation = api.translate(params); - Optional tr = translation.join(); - tr.ifPresent(translations -> assertEquals(0, translations.getTranslations().size())); - api.disconnect(); + TranslateParams params = new TranslateParams("Hello World!", List.of("pt", "es")); + Optional> translate = api.translate(params).join(); + assertTrue(translate.isPresent()); + assertEquals(2, translate.get().get(0).getTranslations().size()); } @Test - void translateHelloWorld() { + void translateMultiText() { String azureKey = System.getenv("AZURE_KEY"); String region = System.getenv("AZURE_REGION"); Assumptions.assumeTrue( @@ -109,10 +105,12 @@ void translateHelloWorld() { AzureApiBuilder builder = new AzureApiBuilder().setKey(azureKey).region(region); AzureApi api = builder.build(); - TranslateParams params = new TranslateParams("Hello World!", List.of("pt", "es")); - Optional translate = api.translate(params).join(); + TranslateParams params = + new TranslateParams(List.of("Hello World!", "How are you?"), List.of("pt", "es")); + Optional> translate = api.translate(params).join(); assertTrue(translate.isPresent()); - assertEquals(2, translate.get().getTranslations().size()); + + assertEquals(2, translate.get().size()); } @Test diff --git a/src/test/java/io/github/brenoepics/at4j/core/AzureApiImplTest.java b/src/test/java/io/github/brenoepics/at4j/core/AzureApiImplTest.java index 86eb82e1..45527178 100644 --- a/src/test/java/io/github/brenoepics/at4j/core/AzureApiImplTest.java +++ b/src/test/java/io/github/brenoepics/at4j/core/AzureApiImplTest.java @@ -8,7 +8,9 @@ import org.junit.jupiter.api.Test; import org.mockito.Mock; +import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -28,7 +30,7 @@ public void setup() { @Test void returnsEmptyOnInvalidInput() { - translateParams.setText(null); + translateParams.setTexts(Collections.emptyList()); azureApi .translate(translateParams) .whenComplete( @@ -39,7 +41,7 @@ void returnsEmptyOnInvalidInput() { } }); - CompletableFuture> response = azureApi.translate(translateParams); + CompletableFuture>> response = azureApi.translate(translateParams); assertFalse(response.join().isPresent()); } diff --git a/src/test/java/io/github/brenoepics/at4j/data/request/TranslateParamsTest.java b/src/test/java/io/github/brenoepics/at4j/data/request/TranslateParamsTest.java index f9510980..f6783819 100644 --- a/src/test/java/io/github/brenoepics/at4j/data/request/TranslateParamsTest.java +++ b/src/test/java/io/github/brenoepics/at4j/data/request/TranslateParamsTest.java @@ -6,10 +6,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import static org.junit.jupiter.api.Assertions.*; @@ -18,8 +15,8 @@ class TranslateParamsTest { @Test void shouldSetAndGetText() { TranslateParams params = new TranslateParams("Hello", List.of("fr")); - params.setText("Bonjour"); - assertEquals("Bonjour", params.getText()); + params.setTexts(Collections.singleton("Bonjour")); + assertEquals("Bonjour", params.getTexts().get(1)); } @Test diff --git a/src/test/java/io/github/brenoepics/at4j/data/response/TranslationResponseTest.java b/src/test/java/io/github/brenoepics/at4j/data/response/TranslationResponseTest.java index 26f811af..8acd03d3 100644 --- a/src/test/java/io/github/brenoepics/at4j/data/response/TranslationResponseTest.java +++ b/src/test/java/io/github/brenoepics/at4j/data/response/TranslationResponseTest.java @@ -16,7 +16,7 @@ class TranslationResponseTest { void createsTranslationResponseWithDetectedLanguageAndTranslations() { DetectedLanguage detectedLanguage = new DetectedLanguage("en", 1.0f); Translation translation = new Translation("pt", "Olá, mundo!"); - TranslationResponse response = new TranslationResponse(detectedLanguage, List.of(translation)); + TranslationResponse response = new TranslationResponse(translation.getText(), detectedLanguage, List.of(translation)); assertEquals(detectedLanguage, response.getDetectedLanguage()); assertEquals(1, response.getTranslations().size()); @@ -26,7 +26,7 @@ void createsTranslationResponseWithDetectedLanguageAndTranslations() { @Test void createsTranslationResponseWithTranslationsOnly() { Translation translation = new Translation("pt", "Olá, mundo!"); - TranslationResponse response = new TranslationResponse(List.of(translation)); + TranslationResponse response = new TranslationResponse(translation.getText(), List.of(translation)); assertNull(response.getDetectedLanguage()); assertEquals(1, response.getTranslations().size()); @@ -35,7 +35,7 @@ void createsTranslationResponseWithTranslationsOnly() { @Test void returnsEmptyTranslationsWhenNoneProvided() { - TranslationResponse response = new TranslationResponse(Collections.emptyList()); + TranslationResponse response = new TranslationResponse("", Collections.emptyList()); assertEquals(0, response.getTranslations().size()); } }