package foo.bar.rest; import java.io.*; import javax.servlet.http.*; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.UnsupportedEncodingException; import java.net.Authenticator; import java.net.HttpURLConnection; import java.net.URL; import java.net.URLEncoder; import java.text.SimpleDateFormat; import java.util.Date; import java.util.logging.Level; import java.util.logging.Logger; /** * RestProxyServlet * * Use this servlet for javascript cross-domain access to rest services * * When using this, please add servlet mapping to /WEB-INF/web.xml as follows * RestProxy foo.bar.RestProxyServlet RestProxy /restproxy * */ public class RestProxyServlet extends HttpServlet { @Override public void doGet(HttpServletRequest request, HttpServletResponse response) { String[][] queryParams = new String[][]{}; RestConnection conn = new RestConnection(request.getParameter("url"), queryParams); String[][] headers = new String[][]{{"Accept", "application/json"}}; try { String result = conn.get(headers).getDataAsString(); response.getWriter().println(result); } catch (IOException ex) { Logger.getLogger(RestProxyServlet.class.getName()).log(Level.SEVERE, null, ex); } } @Override public void doPost(HttpServletRequest request, HttpServletResponse response) { StringBuffer bodyContent = null; try { String method = request.getParameter("method"); if(method == null) throw new IOException("Method parameter missing in the request."); BufferedReader in = request.getReader(); String line = null; while ((line = in.readLine()) != null) { if (bodyContent == null) { bodyContent = new StringBuffer(); } bodyContent.append(line); } RestConnection conn = new RestConnection(request.getParameter("url")); if(method.equalsIgnoreCase("POST") || method.equalsIgnoreCase("PUT")) { String[][] headers = new String[][]{{"Content-Type", "application/json"}}; RestResponse result = null; if(method.equalsIgnoreCase("POST")) result = conn.post(headers, bodyContent.toString()); else result = conn.put(headers, bodyContent.toString()); response.getWriter().println(result); } else if(method.equalsIgnoreCase("DELETE")) { RestResponse result = conn.delete(); response.getWriter().println(result); } else { throw new IOException("Undefined method parameter in the request: "+method); } } catch (Exception e) { Logger.getLogger(RestProxyServlet.class.getName()).log(Level.SEVERE, null, e); } } public class RestConnection { private HttpURLConnection conn; private String date; public RestConnection(String baseUrl) { this(baseUrl, null); } /** Creates a new instance of RestConnection */ public RestConnection(String baseUrl, String[][] params) { this(baseUrl, null, params); } /** Creates a new instance of RestConnection */ public RestConnection(String baseUrl, String[][] pathParams, String[][] params) { try { String urlStr = baseUrl; if (pathParams != null && pathParams.length > 0) { urlStr = replaceTemplateParameters(baseUrl, pathParams); } URL url = new URL(encodeUrl(urlStr, params)); conn = (HttpURLConnection) url.openConnection(); conn.setDoInput(true); conn.setDoOutput(true); conn.setUseCaches(false); conn.setDefaultUseCaches(false); conn.setAllowUserInteraction(true); SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z"); date = format.format(new Date()); conn.setRequestProperty("Date", date); } catch (Exception ex) { Logger.getLogger(RestConnection.class.getName()).log(Level.SEVERE, null, ex); } } public void setAuthenticator(Authenticator authenticator) { Authenticator.setDefault(authenticator); } public String getDate() { return date; } public RestResponse get() throws IOException { return get(null); } public RestResponse get(String[][] headers) throws IOException { conn.setRequestMethod("GET"); return connect(headers, null); } public RestResponse put(String[][] headers) throws IOException { return put(headers, (InputStream) null); } public RestResponse put(String data) throws IOException { return put(null, data); } public RestResponse put(InputStream data) throws IOException { return put(null, data); } public RestResponse put(String[][] headers, String data) throws IOException { conn.setRequestMethod("PUT"); return connect(headers, new ByteArrayInputStream(data.getBytes("UTF-8"))); } public RestResponse put(String[][] headers, InputStream data) throws IOException { conn.setRequestMethod("PUT"); return connect(headers, data); } public RestResponse post(String data) throws IOException { return post(null, data); } public RestResponse post(InputStream data) throws IOException { return post(null, data); } public RestResponse post(String[][] headers, String data) throws IOException { conn.setRequestMethod("POST"); return connect(headers, new ByteArrayInputStream(data.getBytes("UTF-8"))); } public RestResponse post(String[][] headers, InputStream data) throws IOException { conn.setRequestMethod("POST"); return connect(headers, data); } /** * Used by post method whose contents are like form input */ public RestResponse post(String[][] params) throws IOException { return post(null, params); } /** * Used by post method whose contents are like form input */ public RestResponse post(String[][] headers, String[][] params) throws IOException { conn.setRequestMethod("POST"); conn.setRequestProperty("ContentType", "application/x-www-form-urlencoded"); String data = encodeParams(params); return connect(headers, new ByteArrayInputStream(data.getBytes("UTF-8"))); } public RestResponse delete() throws IOException { return delete(null); } public RestResponse delete(String[][] headers) throws IOException { conn.setRequestMethod("DELETE"); return connect(headers, null); } /** * @param baseUrl * @param params * @return response */ private RestResponse connect(String[][] headers, InputStream data) throws IOException { try { // Send data setHeaders(headers); String method = conn.getRequestMethod(); byte[] buffer = new byte[1024]; int count = 0; if (method.equals("PUT") || method.equals("POST")) { if (data != null) { conn.setDoOutput(true); OutputStream os = conn.getOutputStream(); while ((count = data.read(buffer)) != -1) { os.write(buffer, 0, count); } os.flush(); } } RestResponse response = new RestResponse(); InputStream is = conn.getInputStream(); while ((count = is.read(buffer)) != -1) { response.write(buffer, 0, count); } response.setResponseCode(conn.getResponseCode()); response.setResponseMessage(conn.getResponseMessage()); response.setContentType(conn.getContentType()); response.setContentEncoding(conn.getContentEncoding()); response.setLastModified(conn.getLastModified()); return response; } catch (Exception e) { String errMsg = "Cannot connect to :" + conn.getURL(); try { BufferedReader rd = new BufferedReader(new InputStreamReader(conn.getErrorStream())); String line; StringBuffer buf = new StringBuffer(); while ((line = rd.readLine()) != null) { buf.append(line); System.out.print(line); } errMsg = buf.toString(); } finally { throw new IOException(errMsg); } } } private String replaceTemplateParameters(String baseUrl, String[][] pathParams) { String url = baseUrl; if (pathParams != null) { for (int i = 0; i < pathParams.length; i++) { String key = pathParams[i][0]; String value = pathParams[i][1]; if (value == null) { value = ""; } url = url.replace(key, value); } } return url; } private String encodeUrl(String baseUrl, String[][] params) { return baseUrl + encodeParams(params); } private String encodeParams(String[][] params) { String p = ""; if (params != null) { for (int i = 0; i < params.length; i++) { String key = params[i][0]; String value = params[i][1]; if (value != null) { try { p += key + "=" + URLEncoder.encode(value, "UTF-8") + "&"; } catch (UnsupportedEncodingException ex) { Logger.getLogger(RestConnection.class.getName()).log(Level.SEVERE, null, ex); } } } if (p.length() > 0) { p = "?" + p.substring(0, p.length() - 1); } } return p; } private void setHeaders(String[][] headers) { if (headers != null) { for (int i = 0; i < headers.length; i++) { conn.setRequestProperty(headers[i][0], headers[i][1]); } } } } public class RestResponse { private ByteArrayOutputStream os; private String contentType = "text/plain"; private String contentEncoding; private int responseCode; private String responseMsg; private long lastModified; public RestResponse() { os = new ByteArrayOutputStream(); } public RestResponse(byte[] bytes) throws IOException { this(); byte[] buffer = new byte[1024]; int count = 0; ByteArrayInputStream bis = new ByteArrayInputStream(bytes); while ((count = bis.read(buffer)) != -1) { write(buffer, 0, count); } } public void setContentType(String contentType) { this.contentType = contentType; } public String getContentType() { return contentType; } public void setContentEncoding(String contentEncoding) { this.contentEncoding = contentEncoding; } public void setResponseMessage(String msg) { this.responseMsg = msg; } public String getResponseMessage() { return responseMsg; } public void setResponseCode(int code) { this.responseCode = code; } public int getResponseCode() { return responseCode; } public void setLastModified(long lastModified) { this.lastModified = lastModified; } public long getLastModified() { return lastModified; } public void write(byte[] bytes, int start, int length) { os.write(bytes, start, length); } public byte[] getDataAsByteArray() { return os.toByteArray(); } public String getDataAsString() { try { return os.toString("UTF-8"); } catch (Exception ex) { Logger.getLogger(RestConnection.class.getName()).log(Level.SEVERE, null, ex); } return null; } public OutputStream getOutputStream() { return os; } } }