optimize StreamUtil and fix the LimitingExecutor #550

Closed
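
This PR replaces the collect-then-split implementation of StreamUtil.toParallelStream with a ParallelStreamCollector that submits batches to the throttled executor as elements are accumulated, and fixes LimitingExecutor so that a semaphore permit is released only after the submitted task finishes rather than as soon as it has been queued. A minimal usage sketch of the collector, based on the method signature and the new test in this diff (the wrapper class and the parameter values are illustrative only, not part of the change):

import com.linkedin.avroutil1.builder.util.StreamUtil;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ToParallelStreamExample {
  public static void main(String[] args) {
    // Square 1..10 using at most 2 concurrent batches of up to 3 elements each.
    List<Integer> squares = IntStream.rangeClosed(1, 10)
        .boxed()
        .collect(StreamUtil.toParallelStream(x -> x * x, 2, 3))
        .collect(Collectors.toList());
    // The collector declares itself UNORDERED, so callers should not rely on element order.
    System.out.println(squares);
  }
}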
StreamUtil.java
@@ -6,18 +6,24 @@

package com.linkedin.avroutil1.builder.util;

import java.util.Collection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;


@@ -72,36 +78,10 @@ private StreamUtil() {
 */
  public static <T, R> Collector<T, ?, Stream<R>> toParallelStream(Function<T, R> mapper, int parallelism,
      int batchSize) {
    if (parallelism <= 0 || batchSize <= 0) {
      throw new IllegalArgumentException("Parallelism and batch size must be >= 1");
    }

    return Collectors.collectingAndThen(Collectors.toList(), list -> {
      if (list.isEmpty()) {
        return Stream.empty();
      }

      if (parallelism == 1 || list.size() <= batchSize) {
        return list.stream().map(mapper);
      }

      final Executor limitingExecutor = new LimitingExecutor(parallelism);
      final int batchCount = (list.size() - 1) / batchSize;
      return IntStream.rangeClosed(0, batchCount)
          .mapToObj(batch -> {
            int startIndex = batch * batchSize;
            int endIndex = (batch == batchCount) ? list.size() : (batch + 1) * batchSize;
            return list.subList(startIndex, endIndex);
          })
          .map(batch -> CompletableFuture.supplyAsync(() -> batch.stream().map(mapper).collect(Collectors.toList()),
              limitingExecutor))
          .map(CompletableFuture::join)
          .flatMap(Collection::stream);
    });
    return new ParallelStreamCollector<>(mapper, parallelism, batchSize);
  }

  private final static class LimitingExecutor implements Executor {

    private final Semaphore _limiter;

    private LimitingExecutor(int maxParallelism) {
@@ -112,12 +92,76 @@ private LimitingExecutor(int maxParallelism) {
    public void execute(Runnable command) {
      try {
        _limiter.acquire();
        WORK_EXECUTOR.execute(command);
        WORK_EXECUTOR.execute(() -> {
          try {
            command.run();
          } finally {
            _limiter.release();
          }
        });
      } catch (InterruptedException e) {
        throw new RuntimeException(e);
      } finally {
        _limiter.release();
      }
    }
  }

  private static final class ParallelStreamCollector<T, R> implements Collector<T, LinkedList<T>, Stream<R>> {
    private final int _batchSize;
    private final Function<T, R> _mapper;
    private final Executor _executor;
    private final List<CompletableFuture<List<R>>> _futures = new ArrayList<>();

    private ParallelStreamCollector(Function<T, R> mapper, int parallelism, int batchSize) {
      if (parallelism <= 0 || batchSize <= 0) {
        throw new IllegalArgumentException("Parallelism and batch size must be > 0");
      }
      _mapper = mapper;
      _batchSize = batchSize;
      _executor = new LimitingExecutor(parallelism);
    }

    @Override
    public Supplier<LinkedList<T>> supplier() {
      return LinkedList::new;
    }

    public BiConsumer<LinkedList<T>, T> accumulator() {
      return this::accumulate;
    }

    private void accumulate(LinkedList<T> list, T element) {
      if (list.size() >= _batchSize) {
        List<T> listCopy = new ArrayList<>(list);
        _futures.add(CompletableFuture.supplyAsync(() -> listCopy.stream().map(_mapper).collect(Collectors.toList()),
            _executor));
        list.clear();
      }
      list.add(element);
    }

    @Override
    public BinaryOperator<LinkedList<T>> combiner() {
      return (left, right) -> {
        left.addAll(right);
        return left;
      };
    }

    @Override
    public Function<LinkedList<T>, Stream<R>> finisher() {
      return list -> {
        if (!list.isEmpty()) {
          _futures.add(
              CompletableFuture.supplyAsync(() -> list.stream().map(_mapper).collect(Collectors.toList()), _executor));
        }

        return _futures.stream().flatMap(future -> future.join().stream());
      };
    }

    @Override
    public Set<Characteristics> characteristics() {
      return Collections.singleton(Characteristics.UNORDERED);
    }
  }
}
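
For reference, the core of the LimitingExecutor fix above is that a permit stays held for the full lifetime of a task instead of only for the submission call. A self-contained sketch of that throttling pattern, assuming a placeholder cached thread pool and a permit count of 2 (neither reflects the actual WORK_EXECUTOR configuration in StreamUtil):

import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;

public class ThrottledExecutorSketch {
  public static void main(String[] args) {
    ExecutorService pool = Executors.newCachedThreadPool();
    Semaphore limiter = new Semaphore(2); // at most 2 tasks in flight at a time

    Executor limited = command -> {
      try {
        limiter.acquire(); // block submitters until a permit frees up
        pool.execute(() -> {
          try {
            command.run();
          } finally {
            limiter.release(); // release only once the task has actually finished
          }
        });
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
    };

    for (int i = 0; i < 5; i++) {
      int id = i;
      limited.execute(() -> System.out.println("ran task " + id));
    }
    pool.shutdown();
  }
}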
StreamUtilTest.java
@@ -0,0 +1,29 @@
/*
* Copyright 2024 LinkedIn Corp.
* Licensed under the BSD 2-Clause License (the "License").
* See License in the project root for license information.
*/

package com.linkedin.avroutil1.builder.util;

import java.util.stream.IntStream;
import org.testng.Assert;
import org.testng.annotations.Test;


/**
* This is to test {@link StreamUtil}
*/
public class StreamUtilTest {
  @Test
  public void testParallelStreaming() throws Exception {
    int result = IntStream.rangeClosed(1, 100)
        .boxed()
        .collect(StreamUtil.toParallelStream(x -> x * x, 3, 4))
        .reduce(0, Integer::sum);

    int expected = IntStream.rangeClosed(1, 100).map(x -> x * x).sum();

    Assert.assertEquals(result, expected);
  }
}