diff --git a/simpleclient_servlet/src/main/java/io/prometheus/client/filter/MetricsFilter.java b/simpleclient_servlet/src/main/java/io/prometheus/client/filter/MetricsFilter.java index 3b61a0ed3..a9bcb658e 100644 --- a/simpleclient_servlet/src/main/java/io/prometheus/client/filter/MetricsFilter.java +++ b/simpleclient_servlet/src/main/java/io/prometheus/client/filter/MetricsFilter.java @@ -3,6 +3,8 @@ import io.prometheus.client.Counter; import io.prometheus.client.Histogram; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; @@ -162,7 +164,7 @@ public void init(FilterConfig filterConfig) throws ServletException { } @Override - public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { + public void doFilter(ServletRequest servletRequest, final ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { if (!(servletRequest instanceof HttpServletRequest)) { filterChain.doFilter(servletRequest, servletResponse); return; @@ -172,17 +174,37 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo String path = request.getRequestURI(); - String components = getComponents(path); - String method = request.getMethod(); - Histogram.Timer timer = histogram + final String components = getComponents(path); + final String method = request.getMethod(); + final Histogram.Timer timer = histogram .labels(components, method) .startTimer(); + boolean isAsync = false; try { filterChain.doFilter(servletRequest, servletResponse); + isAsync = servletRequest.isAsyncStarted(); + if (isAsync) { + servletRequest.getAsyncContext().addListener(new AsyncListener() { + volatile boolean done = false; + private void meter() { + if (!done) { + done = true; + timer.observeDuration(); + statusCounter.labels(components, method, getStatusCode(servletResponse)).inc(); + } + } + @Override public void onStartAsync(AsyncEvent asyncEvent) { } + @Override public void onComplete(AsyncEvent asyncEvent) { meter(); } + @Override public void onError(AsyncEvent asyncEvent) { meter(); } + @Override public void onTimeout(AsyncEvent asyncEvent) { meter(); } + }); + } } finally { - timer.observeDuration(); - statusCounter.labels(components, method, getStatusCode(servletResponse)).inc(); + if (!isAsync) { + timer.observeDuration(); + statusCounter.labels(components, method, getStatusCode(servletResponse)).inc(); + } } }