开发者

ThreadSafeClientConnManager not multithreading

I've been asked to fix up a Servlet that sits between two applications. Its purpose is to convert SAML authorisation requests between SAML v2.0 and SAML v1.1. So it:

  • receives an HTTP SAML v2.0 authorisation request from one app
  • converts the request into SAML v1.1
  • sends the request to the second app
  • receives the SAML v1.1 response from the second app
  • converts the response into SAML v2.0
  • sends the response back to the first app

Don't worry about the SAML stuff; it's the HTTP stuff that's the problem. The code does its job, but it suffers greatly under load. I've found through testing that even though the code uses a ThreadSafeClientConnManager from Apache HttpComponents, each request that hits the servlet is being handled in a single-threaded manner. More precisely, as soon as the code reaches the HttpClient.execute() method, the first thread to create a connection runs through the entire rest of the process before any other thread begins working. For example:

  • 15 requests hit the servlet at the same time
  • servlet spawns 15 threads to service the requests
  • all 15 threads retrieve their respective request data
  • all 15 threads convert their respective data from SAML v2.0 to SAML v1.1
  • Thread 1 calls HTTPClient.execute()
    • Thread 1 sends the request on to the second app
    • Thread 1 receives the response from the second app
    • Thread 1 decodes the response and converts it from SAML v1.1 to SAML v2.0
    • Thread 1 sends the response back to the first app
  • Thread 2 calls HTTPClient.execute()
  • ... and so on ...

I've included the code below. From what I can see all the necessary items are present. Can anyone see anything wrong or missing that would prevent this servlet from servicing multiple requests at the same time?

public class MappingServlet extends HttpServlet {

private HttpClient client;
private String pdp_url;

public void init() throws ServletException {
    org.opensaml.Configuration.init();
    pdp_url = getInitParameter("pdp_url");

    ThreadSafeClientConnManager cm = new ThreadSafeClientConnManager();
    HttpRoute route = new HttpRoute(new HttpHost(pdp_url));
    cm.setDefaultMaxPerRoute(100);
    cm.setMaxForRoute(route, 100);
    cm.setMaxTotal(100);
    client = new DefaultHttpClient(cm);
}

protected void doPost(HttpServletRequest request, HttpServletResponse response)
    throws ServletException, IOException {

    long threadId = Thread.currentThread().getId();
    log.debug("[THREAD " + threadId + "] client request received");

    // Get the input entity (SAML2)
    InputStream in = null;
    byte[] query11 = null;
    try {
        in = request.getInputStream();
        query11 = Saml2Requester.convert(in);
        log.debug("[THREAD " + threadId + "] client request SAML11:\n" + query11);
    } catch (IOException ex) {
        log.error("[THREAD " + threadId + "]\n", ex);
        return;
    } finally {
        if (in != null) {
            try {
                in.close();
            } catch (IOException ioe) {
                log.error("[THREAD " + threadId + "]\n", ioe);
            }
        }
    }

    // Proxy the request to the PDP
    HttpPost httpPost = new HttpPost(pdp_url);
    ByteArrayEntity entity = new ByteArrayEntity(query11);
    httpPost.setEntity(entity);
    HttpResponse httpResponse = null;
    try {
        httpResponse = client.execute(httpPost);
    } catch (IOException ioe) {
        log.error("[THREAD " + threadId + "]\n", ioe);
        httpPost.abort();
        return;
    }

    int sc = httpResponse.getStatusLine().getStatusCode();
    if (sc != HttpStatus.SC_OK) {
        log.error("[THREAD " + threadId + "] Bad response from PDP: " + sc);
        httpPost.abort();
        return;
    }

    // Get the response back from the PDP
    InputStream in2 = null;
    byte[] resp = null;
    try {
        HttpEntity entity2 = httpResponse.getEntity();
        in2 = entity2.getContent();
        resp = Saml2Requester.consumeStream(in2);
        EntityUtils.consumeStream(in2);
        log.debug("[THREAD " + threadId + "] client response received, SAML11: " + resp);
    } catch (IOException ex) {
        log.error("[THREAD " + threadId + "]", ex);
        httpPost.abort();
        return;
    } finally {
        if (in2 != null) {
            try {
                in2.close();
            } catch (IOException ioe) {
                log.error("[THREAD " + threadId + "]", ioe);
            }
        }
    }

    // Convert the response from SAML1.1 to SAML2 and send back
    ByteArrayInputStream respStream = null;
    byte[] resp2 = null;
    try {
        respStream = new ByteArrayInputStream(resp);
        resp2 = Saml2Responder.convert(respStream);
    } finally {
        if (respStream != null) {
            try {
                respStream.close();
            } catch (IOException ioe) {
                log.error("[THREAD " + threadId + "]", ioe);
            }
        }
    }
    log.debug("[THREAD " + threadId + "] client response SAML2: " + resp2);

    OutputStream os2 = null;
    try {
        os2 = response.getOutputStream();
        os2.write(resp2.getBytes());
        log.debug("[THREAD " + threadId + "] client response forwarded");
    } catch (IOException ex) {
        log.error("[THREAD " + threadId + "]\n", ex);
        return;
    } finally {
        if (os2 != null) {
            try {
                os2.close();
            } catch (IOException ioe) {
                log.error("[THREAD " + threadId + "]\n", ioe);
            }
        }
    }
}

public void destroy() {
    client.getConnectionManager().shutdown();
    super.destroy();
}

}

Thanks in advance!


HttpClient.execute() does not return until the called server has sent all of the HTTP headers. Your code works fine; I think the called service is the real bottleneck. I've created a simple proof of concept (based on your snippet) to demonstrate it:

import java.io.IOException;

import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.conn.routing.HttpRoute;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.impl.conn.tsccm.ThreadSafeClientConnManager;

public class MyHttpClient {

    private static final String url = "http://localhost:8080/WaitServlet";

    private final DefaultHttpClient client;

    public MyHttpClient() {
        final ThreadSafeClientConnManager cm = 
                new ThreadSafeClientConnManager();
        final HttpRoute route = new HttpRoute(new HttpHost(url));
        cm.setDefaultMaxPerRoute(100);
        cm.setMaxForRoute(route, 100);
        cm.setMaxTotal(100);
        client = new DefaultHttpClient(cm);
    }

    public void doPost() {
        final HttpPost httpPost = new HttpPost(url);

        HttpResponse httpResponse;
        try {
            httpResponse = client.execute(httpPost);
        } catch (final IOException ioe) {
            ioe.printStackTrace();
            httpPost.abort();
            return;
        }

        final StatusLine statusLine = httpResponse.getStatusLine();
        System.out.println("status: " + statusLine);
        final int statusCode = statusLine.getStatusCode();
        if (statusCode != HttpStatus.SC_OK) {
            httpPost.abort();
            return;
        }
    }
}

And a test:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.junit.Test;

public class HttpClientTest {

    @Test
    public void test2() throws Exception {
        final ExecutorService executorService = 
                Executors.newFixedThreadPool(16);

        final MyHttpClient myHttpClient = new MyHttpClient();

        for (int i = 0; i < 8; i++) {
            final Runnable runnable = new Runnable() {

                @Override
                public void run() {
                    myHttpClient.doPost();
                }
            };
            executorService.execute(runnable);
        }

        executorService.shutdown();
        executorService.awaitTermination(150, TimeUnit.SECONDS);
    }
}

Finally, the called WaitServlet:

import java.io.IOException;
import java.io.PrintWriter;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/** Test servlet that simulates a slow back end by holding every request before answering. */
public class WaitServlet extends HttpServlet {
    private static final long serialVersionUID = 1L;

    /** How long each request is held before the response is written (30 seconds). */
    private static final long WAIT_MILLIS = 30 * 1000L;

    /** Handles GET exactly like POST. */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        doPost(req, resp);
    }

    /** Sleeps for {@link #WAIT_MILLIS}, then writes {@code "wait end"} to the response. */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        try {
            Thread.sleep(WAIT_MILLIS);
        } catch (InterruptedException e) {
            // Restore the flag so the container can still see that this thread was interrupted.
            Thread.currentThread().interrupt();
            log("interrupted while waiting", e);
        }
        final PrintWriter writer = resp.getWriter();
        writer.println("wait end");
    }
}
0

上一篇:

下一篇:

精彩评论

暂无评论...
验证码 换一张
取 消

最新问答

问答排行榜