A Unique Method of Authenticating Against an App-Managed User List

I have a project that uses Amazon’s SimpleDB service for data storage. Being a Java programmer, I have become fond of JPA (Java Persistence API) implementations. In some cases I’ve used EclipseLink, but more recently I’ve been playing with SimpleJPA, a partial JPA implementation on top of SimpleDB. Its benefits include writing value objects with only minimal annotations to indicate relationships.

Anyway, enough about why I do it. Since my user list is also stored in JPA entities, I’d like to tie it into container-managed authentication. The web app I’m writing is deployed to Tomcat, which uses realms to define an authentication provider. Tomcat provides several realms that hook into a JDBC database, a JNDI DataSource, JAAS and more. In my case, I wanted to rely on data access via JPA.

Before discussing the challenges, I should point out that a Java web app container has several class loaders to contend with. The container has its own class loader, and each web application has its own. My application contains all of the supporting jars for SimpleJPA and my value objects. Since authentication is handled by the container, the realm doesn’t have access to my app’s class loader. So, I’d need to deploy about 12 jar files into the tomcat/lib directory to make them visible to the container. One of those contains my value objects and could change in the future. I don’t think that’s a very nice deployment strategy (deploying a war, and then a separate jar for each software update).

To solve this problem, I had to write my own Realm with as few dependencies on my application as possible. What I came up with is a socket listener, running on a dedicated port within my web application. It only accepts connections from localhost, so only processes on the same machine can query it. The socket listener receives a username and returns username,password,role1,role2,… as a string (or nothing, if there is no such user). That is the contract between my web application and the authentication realm. The realm connects to the socket listener to get information about the user trying to authenticate, which it converts to the object format used by Tomcat realms.
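To make the contract concrete, here is a minimal sketch of encoding and decoding that string in plain Java. The class and method names (AuthWireFormat, encode, decode) are hypothetical and not part of the project. It assumes, as the realm code below does, that no username, password or role ever contains a comma, since the format has no escaping.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical helper illustrating the "username,password,role1,role2,..."
 * contract between the socket listener and the realm.
 * Assumes no field contains a comma (there is no escaping).
 */
public class AuthWireFormat {

    /** Parsed form of one listener response. */
    public static final class UserRecord {
        public final String username;
        public final String password;
        public final List<String> roles;

        UserRecord(String username, String password, List<String> roles) {
            this.username = username;
            this.password = password;
            this.roles = roles;
        }
    }

    /** Builds the string the listener would send back for a known user. */
    public static String encode(String username, String password, List<String> roles) {
        StringBuilder sb = new StringBuilder(username).append(',').append(password);
        for (String role : roles) {
            sb.append(',').append(role);
        }
        return sb.toString();
    }

    /** Parses a response; an empty response means "no such user" and yields null. */
    public static UserRecord decode(String response) {
        if (response == null || response.isEmpty()) {
            return null;
        }
        String[] fields = response.split(",");
        if (fields.length < 2) {
            throw new IllegalArgumentException("expected at least username,password");
        }
        // everything after the password is a role
        List<String> roles = new ArrayList<>(Arrays.asList(fields).subList(2, fields.length));
        return new UserRecord(fields[0], fields[1], roles);
    }

    public static void main(String[] args) {
        String wire = encode("alice", "secret", Arrays.asList("admin", "editor"));
        System.out.println(wire);               // alice,secret,admin,editor
        System.out.println(decode(wire).roles); // [admin, editor]
    }
}
```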

The code for the socket listener is fairly simple:

package org.scalabletype.util;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;

import javax.persistence.EntityManager;
import javax.persistence.NoResultException;
import javax.persistence.Query;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.scalabletype.data.DataHelper;
import org.scalabletype.data.User;

/**
 * This class listens on a port, receives a username, looks up the user record, then responds
 * with "username,password,groups" (or nothing if there is no such user) and closes the connection.
 */
public class AuthServer extends Thread {
	private static Log logger = LogFactory.getLog(AuthServer.class);
	public static final int AUTH_SOCKET = 2000;

	private volatile ServerSocket ss;

	public AuthServer() {
		super("AuthServer");
		setDaemon(true); // don't keep the JVM alive if the app forgets to stop us
	}

	public void run() {
		try {
			// bind to the loopback interface only, so remote hosts can't connect at all
			ss = new ServerSocket(AUTH_SOCKET, 50, InetAddress.getByName("localhost"));
		} catch (IOException ex) {
			logger.error("could not open auth socket " + AUTH_SOCKET, ex);
			return;
		}
		while (!isInterrupted()) {
			Socket sock;
			try {
				sock = ss.accept();
			} catch (IOException ex) {
				if (ss.isClosed()) break; // closed by interrupt(): shut down
				logger.error("problem accepting connection. will keep going.", ex);
				continue;
			}
			try {
				// double-check that the connection comes from localhost
				if (!sock.getInetAddress().isLoopbackAddress()) continue;

				// get user to authenticate (the realm sends it in a single write)
				InputStream iStr = sock.getInputStream();
				byte [] buf = new byte[1024];
				int bytesRead = iStr.read(buf);
				if (bytesRead <= 0) continue;
				String username = new String(buf, 0, bytesRead);
				logger.info("username to authenticate:"+username);

				// fetch user from JPA; getSingleResult() throws instead of returning null
				User usr = null;
				try {
					EntityManager em = DataHelper.getEntityManager();
					Query query = em.createQuery("select object(o) from User o where o.username = :name");
					query.setParameter("name", username);
					usr = (User)query.getSingleResult();
				} catch (NoResultException ex) {
					logger.info("no such user: "+username);
				}

				// return user data, or nothing
				OutputStream oStr = sock.getOutputStream();
				if (usr != null) {
					String ret = usr.getUsername() + "," + usr.getPassword() + "," + usr.getAuthGroups();
					oStr.write(ret.getBytes());
				}
				oStr.flush();
			} catch (Exception ex) {
				logger.error("Some problem handling the request", ex);
			} finally {
				// always close, so the realm sees end-of-stream and never hangs
				try { sock.close(); } catch (IOException ignored) { }
			}
		}
		try { ss.close(); } catch (IOException ignored) { }
	}

	/** Interrupting alone doesn't unblock accept(); closing the server socket does. */
	public void interrupt() {
		super.interrupt();
		ServerSocket s = ss;
		if (s != null) {
			try { s.close(); } catch (IOException ignored) { }
		}
	}
}
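When debugging, it helps to query the listener without going through Tomcat at all. The following is a hypothetical probe client (AuthProbe is an illustrative name, not part of the project), assuming the listener is on localhost:2000 and closes the connection after replying. It sends the username in one write, reads until end-of-stream rather than relying on a single read(), and sets a read timeout so it can’t hang forever.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;

/** Hypothetical command-line probe for the auth socket listener. */
public class AuthProbe {

    /** Sends a username and returns the raw reply ("" means unknown user). */
    public static String lookup(String host, int port, String username) throws IOException {
        try (Socket sock = new Socket(host, port)) {
            sock.setSoTimeout(5000); // fail instead of blocking forever
            OutputStream out = sock.getOutputStream();
            out.write(username.getBytes(StandardCharsets.UTF_8));
            out.flush();
            // tell the listener we're done sending; its single read() still sees the name
            sock.shutdownOutput();

            InputStream in = sock.getInputStream();
            ByteArrayOutputStream reply = new ByteArrayOutputStream();
            byte[] buf = new byte[1024];
            int n;
            while ((n = in.read(buf)) != -1) { // -1: listener closed the connection
                reply.write(buf, 0, n);
            }
            return new String(reply.toByteArray(), StandardCharsets.UTF_8);
        }
    }

    public static void main(String[] args) throws IOException {
        // assumes the listener runs on localhost:2000, as in AuthServer.AUTH_SOCKET
        String reply = lookup("localhost", 2000, args.length > 0 ? args[0] : "alice");
        System.out.println(reply.isEmpty() ? "(no such user)" : reply);
    }
}
```

Running `java AuthProbe someuser` on the server machine prints whatever the realm would receive for that user.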

The socket listener needs to be started when the web application is initialized, and a ServletContextListener (registered with a <listener> element in web.xml) is a good place to do that:

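import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener;

import org.scalabletype.data.DataHelper;
import org.scalabletype.util.AuthServer;

public class ScalableTypeStarter implements ServletContextListener {
	private AuthServer auth;

	public void contextInitialized(ServletContextEvent evt) {
		// init data persistence layer
		DataHelper.initDataHelper(evt.getServletContext());

		// start the authentication socket listener
		auth = new AuthServer();
		auth.start();
	}

	public void contextDestroyed(ServletContextEvent evt) {
		// signal the listener thread to stop when the app is undeployed
		if (auth != null) {
			auth.interrupt();
			auth = null;
		}
	}
}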
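One subtlety with stopping a listener this way: Thread.interrupt() sets the thread’s interrupt flag, but it does not wake a thread blocked in ServerSocket.accept(), so the port can stay bound after undeploy. Closing the ServerSocket does wake it, because accept() then throws. The following self-contained sketch (AcceptShutdownDemo is a hypothetical name) overrides interrupt() to do both, so a caller like contextDestroyed() can keep calling interrupt().

```java
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;

/**
 * Hypothetical demo: a listener thread that stops promptly when interrupted,
 * because interrupt() also closes the ServerSocket and unblocks accept().
 */
public class AcceptShutdownDemo extends Thread {
    private final ServerSocket server;

    public AcceptShutdownDemo() throws IOException {
        // port 0 = any free port; loopback only, like a localhost-only listener
        server = new ServerSocket(0, 50, InetAddress.getByName("127.0.0.1"));
    }

    @Override
    public void run() {
        while (!isInterrupted()) {
            try (Socket s = server.accept()) {
                // a real listener would handle the request here
            } catch (IOException ex) {
                if (server.isClosed()) {
                    break; // closed by interrupt(): leave the loop
                }
                // otherwise a problem with one connection: keep listening
            }
        }
    }

    /** Set the interrupt flag AND close the socket so a blocked accept() returns. */
    @Override
    public void interrupt() {
        super.interrupt();
        try {
            server.close();
        } catch (IOException ignored) {
        }
    }

    public static void main(String[] args) throws Exception {
        AcceptShutdownDemo t = new AcceptShutdownDemo();
        t.start();
        Thread.sleep(200); // give the thread time to block in accept()
        t.interrupt();
        t.join(2000);
        System.out.println("stopped: " + !t.isAlive());
    }
}
```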

Here is the code for my realm, which is packaged by itself into a jar, and deployed (once) into the tomcat/lib directory.

package org.scalabletype.util;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.net.UnknownHostException;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.catalina.Group;
import org.apache.catalina.Role;
import org.apache.catalina.User;
import org.apache.catalina.UserDatabase;
import org.apache.catalina.realm.GenericPrincipal;
import org.apache.catalina.realm.RealmBase;

/**
 * This realm authenticates against user data via the socket listener.
 *
 */
public class UserRealm extends RealmBase {
	public static final int AUTH_SOCKET = 2000;

    protected final String info = "org.scalabletype.util.UserRealm/1.0";
    protected static final String name = "UserRealm";

    /**
     * Return descriptive information about this Realm implementation and
     * the corresponding version number, in the format
     * <code>&lt;description&gt;/&lt;version&gt;</code>.
     */
    public String getInfo() {
        return info;
    }

    /**
     * Return <code>true</code> if the specified Principal has the specified
     * security role, within the context of this Realm; otherwise return
     * <code>false</code>. This implementation returns <code>true</code>
     * if the <code>User</code> has the role, or if any <code>Group</code>
     * that the <code>User</code> is a member of has the role. 
     *
     * @param principal Principal for whom the role is to be checked
     * @param role Security role to be checked
     */
    public boolean hasRole(Principal principal, String role) {
        if (principal instanceof GenericPrincipal) {
            GenericPrincipal gp = (GenericPrincipal)principal;
            if(gp.getUserPrincipal() instanceof User) {
                principal = gp.getUserPrincipal();
            }
        }
        if (!(principal instanceof User) ) {
            //Play nice with SSO and mixed Realms
            return super.hasRole(principal, role);
        }
        if ("*".equals(role)) {
            return true;
        } else if(role == null) {
            return false;
        }
        User user = (User)principal;
        UserInfo usr = findUser(user.getUsername());
        if (usr == null) {
            return false;
        } 
        for (String group : usr.groups) {
			if (role.equals(group)) return true;
		}
        return false;
    }
		
    /**
     * Return a short name for this Realm implementation.
     */
    protected String getName() {
        return name;
    }

    /**
     * Return the password associated with the given principal's user name.
     */
    protected String getPassword(String username) {
        UserInfo user = findUser(username);

        if (user == null) {
            return null;
        } 

        return (user.password);
    }

    /**
     * Return the Principal associated with the given user name.
     */
    protected Principal getPrincipal(String username) {
        UserInfo user = findUser(username);
        if(user == null) {
            return null;
        }

        List<String> roles = new ArrayList<String>(user.groups);
        return new GenericPrincipal(this, username, user.password, roles);
    }

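	/**
	 * Asks the web app's socket listener for a user record.
	 * Returns null if the user is unknown or the listener can't be reached.
	 */
	private UserInfo findUser(String username) {
		Socket sock = null;
		try {
			sock = new Socket("localhost", AUTH_SOCKET);
			sock.setSoTimeout(5000); // don't hang a login if the web app is stuck
			OutputStream oStr = sock.getOutputStream();
			oStr.write(username.getBytes());
			oStr.flush();
			// read until the listener closes the connection (read() returns -1, not 0, at EOF)
			InputStream iStr = sock.getInputStream();
			java.io.ByteArrayOutputStream reply = new java.io.ByteArrayOutputStream();
			byte [] buf = new byte[4096];
			int len;
			while ((len = iStr.read(buf)) != -1) {
				reply.write(buf, 0, len);
			}
			if (reply.size() == 0) {
				return null; // unknown user
			}
			String [] data = reply.toString().split(",");
			if (data.length < 2) {
				return null; // malformed reply
			}
			UserInfo user = new UserInfo();
			user.username = data[0];
			user.password = data[1];
			ArrayList<String> groups = new ArrayList<String>();
			for (int i=2; i<data.length; i++) {
				groups.add(data[i]);
			}
			user.groups = groups;
			return user;
		} catch (IOException ex) {
			// includes UnknownHostException and read timeouts
			ex.printStackTrace();
			return null;
		} finally {
			if (sock != null) {
				try { sock.close(); } catch (IOException ignored) { }
			}
		}
	}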

	class UserInfo {
		String username;
		String password;
		List<String> groups;
	}
}

The web app’s context.xml contains this line to configure the realm:

<Realm className="org.scalabletype.util.UserRealm"/>

Flash Socket Code and crossdomain Policy Serving

I’ve just spent the past day trying to get my Flash app talking to another device on my network via port 23. I found some sample telnet code (telnet operates on port 23) that allowed me to “talk” to the RFID reader. It worked fine as a new project in Flex Builder, served from a local file. The moment I served the application from a web server (Tomcat) on my laptop, I got cross-domain errors. Flash won’t open a socket to a host or port other than the one that served your application unless that endpoint authorizes it. I will spare you the details that took many hours of my day. If you’re trying to talk to another web server on a different port, no problem: just put a crossdomain.xml file on that server that authorizes the connection. In this case, I was trying to connect to another host and another port (which runs telnet, not HTTP). The RFID reader can’t be modified to serve up a crossdomain.xml file, so I had to get creative.

My solution was to run a TCP proxy on my web server machine that forwards connections to the RFID reader. I made it listen on port 8023 and forward to port 23 on the RFID reader. That was only a start, because I still got errors saying that localhost:8023 was not authorized. It turns out that when you attempt the connection, Flash connects to the socket and sends 23 bytes: the string “<policy-file-request/>” (22 characters) followed by a null byte. Flash expects whatever is running on that port to respond with the policy string (what would otherwise have been in the crossdomain.xml file). So, I modified a little proxy class I found on the internet to recognize the policy request and respond with the policy string, null-terminated (that is very important!). Once I had this set up right, I was able to communicate from my Flash app to my RFID reader. Not the most elegant solution, but this is something temporary for a demo.
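The handshake itself can be sketched as a standalone responder, separate from the proxy. This is a minimal illustration, assuming the same permissive policy I used (domain="*", port 8023) and one request per connection; PolicyResponder is a hypothetical name. It reads up to the null byte, compares against the request string, and replies with the policy followed by its own terminating null byte.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;

/** Hypothetical standalone responder for Flash socket policy requests. */
public class PolicyResponder {
    // 22 characters; Flash appends a NUL byte, for 23 bytes on the wire
    static final String REQUEST = "<policy-file-request/>";
    // assumed policy: any domain may connect to port 8023
    static final String POLICY =
          "<cross-domain-policy>\n"
        + "<allow-access-from domain=\"*\" to-ports=\"8023\"/>\n"
        + "</cross-domain-policy>";

    /** Reads bytes up to the first NUL (or end of stream). */
    static String readUntilNul(InputStream in) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        int b;
        while ((b = in.read()) != -1 && b != 0) {
            buf.write(b);
        }
        return new String(buf.toByteArray(), StandardCharsets.UTF_8);
    }

    /** Handles one connection: replies with the policy if asked, then closes. */
    static void handle(Socket sock) throws IOException {
        try (Socket s = sock) {
            if (REQUEST.equals(readUntilNul(s.getInputStream()))) {
                OutputStream out = s.getOutputStream();
                out.write(POLICY.getBytes(StandardCharsets.UTF_8));
                out.write(0); // the terminating NUL that Flash waits for
                out.flush();
            }
        }
    }

    public static void main(String[] args) throws IOException {
        int port = Integer.parseInt(args[0]); // port to listen on
        try (ServerSocket server = new ServerSocket(port)) {
            while (true) {
                handle(server.accept());
            }
        }
    }
}
```

Without the trailing null byte, Flash keeps waiting for the end of the policy, which matches the failure I saw before adding it.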


To run the code below, compile it with javac and invoke “java -classpath <class.file.location> ProxyThread 8023 192.168.1.39 23” (the arguments are the local port, the destination address and the destination port). Those values are what I used to talk to my RFID reader, but you’ll likely use different ones.

import java.net.*;
import java.io.*;
 
/*
  Java Transparent Proxy
  Copyright (C) 1999 by Didier Frick (http://www.dfr.ch/)
  This software is provided under the GNU general public license (http://www.gnu.org/copyleft/gpl.html).
*/

public class ProxyThread extends Thread {
     protected class StreamCopyThread extends Thread {
	private Socket inSock;
	private Socket outSock;
	private boolean done=false;
	private StreamCopyThread peer;
	private boolean inFromLocal;	// in from local port
	private OutputStream out;
 
	// Flash requires the policy to be terminated by a NUL byte, hence the trailing \0
	private String policy = "<cross-domain-policy>\n<allow-access-from domain=\"*\" to-ports=\"8023\"/>\n</cross-domain-policy>\0";
 
	public StreamCopyThread(Socket inSock, Socket outSock, boolean in) {
	    this.inSock=inSock;
	    this.outSock=outSock;
	    this.inFromLocal = in;
	}
 
	// Write the policy straight back to the socket the request came in on
	// (the Flash client), instead of through the peer thread's stream,
	// which may not have been opened yet.
	public void sendPolicy() {
		try {
			OutputStream reply = inSock.getOutputStream();
			reply.write(policy.getBytes("UTF-8"));
			reply.flush();
			System.err.println("Sent policy");
		} catch (IOException ex) {
			System.err.println("Error sending policy file: " + ex);
		}
	}
 
	public void run() {
	    byte[] buf=new byte[bufSize];
	    int count=-1;
	    try {
		InputStream in=inSock.getInputStream();
		out=outSock.getOutputStream();
		try {
		    while(((count=in.read(buf))>0)&&!isInterrupted()) {
			// Flash sends "<policy-file-request/>" plus a NUL byte: 23 bytes
		    	if (inFromLocal && count==23 && new String(buf,0,count).startsWith("<policy-file-request/>")) {
				// answer with the policy; don't forward this to the other port
				System.err.println("Got policy request");
				sendPolicy();
			}
			else {
				out.write(buf,0,count);
				//System.err.println(count+" bytes "+(inFromLocal?"sent":"received"));
			}
		    }
		} catch(Exception xc) {
		    if(debug) {
			// FIXME
			// It's very difficult to sort out between "normal"
			// exceptions (occuring when one end closes the connection
			// normally), and "exceptional" exceptions (when something
			// really goes wrong)
			// Therefore we only log exceptions occuring here if the debug flag
			// is true, in order to avoid cluttering up the log.
			err.println(header+":"+xc);
			xc.printStackTrace();
		    }
		} finally {
		    // The input and output streams will be closed when the sockets themselves
		    // are closed.
		    out.flush();
		}
	    } catch(Exception xc) {
		err.println(header+":"+xc);
		xc.printStackTrace();
	    }
	    synchronized(lock) {
		done=true;
		try {
		    if((peer==null)||peer.isDone()) {
			// Cleanup if there is only one peer OR
			// if _both_ peers are done
			inSock.close();
			outSock.close();
		    }
		    else 
			// Signal the peer (if any) that we're done on this side of the connection
			peer.interrupt();
		} catch(Exception xc) {
		    err.println(header+":"+xc);
		    xc.printStackTrace();
		} finally {
		    connections.removeElement(this);
		}
	    }
	}
 
	public boolean isDone() {
	    return done;
	}
    
	public void setPeer(StreamCopyThread peer) {
	    this.peer=peer;
	}
     }

    // Holds all the currently active StreamCopyThreads
    private java.util.Vector connections=new java.util.Vector();
    // Used to synchronize the connection-handling threads with this thread
    private Object lock=new Object();
    // The address to forward connections to
    private InetAddress dstAddr;
    // The port to forward connections to
    private int dstPort;
    // Backlog parameter used when creating the ServerSocket
    protected static final int backLog=100;
    // Timeout waiting for a StreamCopyThread to finish
    public static final int threadTimeout=2000; //ms
    // Linger time
    public static final int lingerTime=180; // seconds (SO_LINGER timeout)
    // Size of receive buffer
    public static final int bufSize=2048;
    // Header to prepend to log messages
    private String header;
    // This proxy's server socket
    private ServerSocket srvSock;
    // Debug flag
    private boolean debug=false;
 
    // Log streams for output and error messages
    private PrintStream out;
    private PrintStream err;
 
    private static final String 
	argsMessage="Arguments: ( [source_address] source_port dest_address dest_port ) | config_file";
    private static final String 
	propertyPrefix="proxy";

 
    public ProxyThread(InetAddress srcAddr,int srcPort,
		       InetAddress dstAddr,int dstPort, PrintStream out, PrintStream err) 
	throws IOException {
	this.out=out;
	this.err=err;
	this.srvSock=(srcAddr==null) ? new ServerSocket(srcPort,backLog) :  
	    new ServerSocket(srcPort,backLog,srcAddr);
	this.dstAddr=dstAddr;
	this.dstPort=dstPort;
	this.header=(srcAddr==null ? "" : srcAddr.toString())+":"+srcPort+" <-> "+dstAddr+":"+dstPort;
	start();
    }
 
    public void run() {
	out.println(header+" : starting");
	try {
	    while(!isInterrupted()) {
		Socket serverSocket=srvSock.accept();
		try {
		    serverSocket.setSoLinger(true,lingerTime);
		    Socket clientSocket=new Socket(dstAddr,dstPort);
		    clientSocket.setSoLinger(true,lingerTime);
		    StreamCopyThread sToC=new StreamCopyThread(serverSocket,clientSocket, true);
		    StreamCopyThread cToS=new StreamCopyThread(clientSocket,serverSocket, false);
		    sToC.setPeer(cToS);
		    cToS.setPeer(sToC);
		    synchronized(lock) {
			connections.addElement(cToS);
			connections.addElement(sToC);
			sToC.start();
			cToS.start();
		    }
		} catch(Exception xc) {
		    err.println(header+":"+xc);
		    if(debug)
			xc.printStackTrace();
		}
	    }
	    srvSock.close();
	} catch(IOException xc) {
	    err.println(header+":"+xc);
	    if(debug)
		xc.printStackTrace();
	} finally {
	    cleanup();
	    out.println(header+" : stopped");
	}
    }
 
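    private void cleanup() {
	// Snapshot the list under the lock, then join outside it: each
	// StreamCopyThread needs the lock to remove itself from the list,
	// so joining while holding the lock would loop forever.
	java.util.Vector snapshot;
	synchronized(lock) {
	    snapshot=new java.util.Vector(connections);
	}
	try {
	    for(int i=0;i<snapshot.size();i++) {
		StreamCopyThread sct=(StreamCopyThread)snapshot.elementAt(i);
		sct.interrupt();
		sct.join(threadTimeout);
	    }
	} catch(InterruptedException xc) {
	}
    }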
 
    private static ProxyThread addProxy(String src,String srcPort, String dst, String dstPort,
					PrintStream out, PrintStream err) throws
					UnknownHostException, IOException
    {
	InetAddress srcAddr=(src==null) ? null : InetAddress.getByName(src);
	return new ProxyThread(srcAddr,Integer.parseInt(srcPort),
			       InetAddress.getByName(dst),Integer.parseInt(dstPort),out,err);
    }
 
    private static java.util.Vector parseConfigFile(String fileName,PrintStream out,PrintStream err) throws 
        FileNotFoundException, IOException, UnknownHostException
    {
	java.util.Vector result=new java.util.Vector();
	FileInputStream in=new FileInputStream(fileName);
	java.util.Properties props= new java.util.Properties();
	props.load(in);
	in.close();
	for(int i=0;;i++) {
	    String srcAddr=props.getProperty(propertyPrefix+"."+i+".sourceAddr");
	    String srcPort=props.getProperty(propertyPrefix+"."+i+".sourcePort");
	    if(srcPort==null)
		break;
	    String dstAddr=props.getProperty(propertyPrefix+"."+i+".destAddr");
	    String dstPort=props.getProperty(propertyPrefix+"."+i+".destPort");
	    if(dstAddr==null) {
		throw new IllegalArgumentException("Missing destination address for proxy "+i);
	    }
	    if(dstPort==null) {
		throw new IllegalArgumentException("Missing destination port for proxy "+i);
	    }
	    result.addElement(addProxy(srcAddr,srcPort,dstAddr,dstPort,out,err));
	}
	return result;
    }
 
    static java.util.Vector parseArguments(String[] argv,PrintStream out,PrintStream err) throws
        FileNotFoundException, IOException, UnknownHostException
    {
	java.util.Vector result=null;
	int argBase=0;
	String src=null;
	if(argv.length>1) {
	    if(argv.length>3) {
		argBase=1;
		src=argv[0];
	    }
	    result=new java.util.Vector();
	    result.addElement(addProxy(src,argv[argBase++],argv[argBase++],argv[argBase++],out,err));
	} else if(argv.length==1) {
	    result=parseConfigFile(argv[0],out,err);
	} else {
	    throw new IllegalArgumentException(argsMessage);
	}
	return result;
    }
 
    public static void main(String[] argv) throws Exception {
	System.out.println("Java Transparent Proxy");
	System.out.println("Copyright (C) 1999 by Didier Frick (http://www.dfr.ch/)");
	System.out.println("This software is provided under the GNU general public license"+
			   " (http://www.gnu.org/copyleft/gpl.html)");
	try {
	    parseArguments(argv,System.out,System.err);
	} catch(IllegalArgumentException xc) {
	    System.err.println(xc.getMessage());
	    System.exit(1);
	}
    }
}
The initial ProxyThread code came from here: http://www.dfr.ch/en/proxy.html