package p2pmpi.mpd;

import p2pmpi.message.*;
import p2pmpi.p2p.message.*;
import p2pmpi.common.*;
import p2pmpi.common.ProcessTable;
import p2pmpi.common.ProcessInfo;
import java.net.*;
import java.io.*;
import java.util.*;
import org.apache.log4j.*;

/**
 * Local MPD control endpoint.
 *
 * Listens on {@code parent.mpdPort} and handles every incoming connection in its
 * own {@link MPD_InterfaceThread}. Each connection carries exactly one serialized
 * request object (notify, quit, app (un)registration, stat/host-cache/ping queries,
 * or an {@code MPDMessage} command).
 */
public class MPD_Interface extends Thread {
	/** Separator between an application ID and a rank in keys of the form "id--rank". */
	private static final String KEY_SEPARATOR = "--";
	/** Loopback address used to reach local services (RS, running MPI processes). */
	private static final String LOCALHOST = "127.0.0.1";
	/** Number of attempts made to deliver an MPD_ACCEPT reply to rank 0. */
	private static final int ACCEPT_SEND_ATTEMPTS = 10;
	/** Delay between MPD_ACCEPT delivery attempts, in milliseconds. */
	private static final long ACCEPT_RETRY_DELAY_MS = 500;
	/** Grace period before deleting an unregistered application's run directory, in milliseconds. */
	private static final long UNREGISTER_CLEANUP_DELAY_MS = 1000;

	P2PMPI_MPD parent;
	Logger log = null;
	/** Applications accepted via MPD_REQPEER and waiting for FT_DONE, keyed by "id--rank". */
	protected Hashtable<String, AppInfo> hashTable;


	MPD_Interface(P2PMPI_MPD parent) {
		this.parent = parent;
		hashTable = new Hashtable<String, AppInfo>();
		log = Logger.getLogger("MPD");
	}

	/**
	 * Deletes {@code path} and, if it is a directory, everything below it.
	 *
	 * NOTE(review): File.isDirectory() follows symbolic links, so a link to a
	 * directory found below {@code path} has its target's contents removed too.
	 *
	 * @param path file or directory to delete
	 * @return true if {@code path} itself was deleted; false if that deletion
	 *         failed or an unexpected exception occurred
	 */
	public boolean recursiveRemove(File path) {
		if (path == null) {
			return false;
		}
		try {
			// listFiles() returns null for a regular file or on an I/O error;
			// in both cases there are no children to visit.
			File[] children = path.listFiles();
			if (children != null) {
				for (File child : children) {
					if (child.isDirectory()) {
						recursiveRemove(child);
					} else if (!child.delete()) {
						log.error("Failed: removing " + child + " failed");
					}
				}
			}
			if (!path.delete()) {
				log.error("Failed: removing " + path + " failed");
				return false;
			}
		} catch (Exception e) {
			log.debug(e.toString());
			return false;
		}
		return true;
	}

	/**
	 * Accept loop: serves connections on {@code parent.mpdPort} until accept() fails
	 * (for instance because the server socket was closed), then closes the socket.
	 */
	@Override
	public void run() {
		ServerSocket interfaceSocket;
		try {
			log.info("Listen MPD Socket at " + parent.mpdPort);
			interfaceSocket = new ServerSocket(parent.mpdPort);
		} catch (Exception e) {
			log.error("Cannot listen on MPD port " + parent.mpdPort, e);
			return;
		}

		try {
			while (true) {
				try {
					Socket conn = interfaceSocket.accept();
					new MPD_InterfaceThread(interfaceSocket, conn).start();
				} catch (Exception e) {
					// accept() fails once the server socket is closed: stop serving.
					log.error("MPD accept loop terminated", e);
					break;
				}
				System.gc();
			}
		} finally {
			try {
				interfaceSocket.close();
			} catch (IOException e) {
				log.debug("Closing MPD server socket failed: " + e);
			}
		}
		// TODO: Add clean-up process ?
	}


	/**
	 * Handles a single accepted connection: reads one request object and
	 * dispatches it by type. The connection is closed when handling ends.
	 */
	public class MPD_InterfaceThread extends Thread {
		ServerSocket serverSocket = null;
		Socket socket = null;
		Object oMsg;
		MPDMessage msg;
		AppRegisterMessage appRegMsg;
		AppInfo appInfo;

		MPD_InterfaceThread(ServerSocket serverSocket, Socket socket) {
			this.serverSocket = serverSocket;
			this.socket = socket;
		}

		/**
		 * Asks the local reservation service to cancel the reservation whose
		 * hash ID is the part of {@code key} before the first "--" (or the whole
		 * key when it contains no separator). Failures are logged, not thrown.
		 *
		 * @param key process key of the form "hashID--rank"
		 */
		public void unRegisterRsService(String key) {
			String hashID = hashIdOf(key);
			if (hashID.length() == 0) {
				log.warn("Cannot cancel reservation: no hash ID in key '" + key + "'");
				return;
			}
			try {
				RSMessage rsMsg = new RSMessage(RSMessage.RS_CANCEL_RESERVATION_TO_RS, hashID, null, null);
				sendOneShot(LOCALHOST, parent.rsPort, rsMsg);
			} catch (Exception e) {
				log.debug("Cancelling reservation " + hashID + " failed: " + e);
			}
		}

		/**
		 * Sends MPD_UPDATE_STATUS_TO_RS for {@code id} to the local reservation
		 * service and returns the {@code isReserved()} flag of its reply.
		 * Called for its side effect as well (the status update it requests).
		 *
		 * @param id reservation hash ID
		 * @return the RS answer, or false if the RS could not be queried
		 */
		public boolean isReserved(String id) {
			Socket sock = null;
			try {
				RSMessage ckMsg = new RSMessage(RSMessage.MPD_UPDATE_STATUS_TO_RS, id);
				sock = new Socket(LOCALHOST, parent.rsPort);
				ObjectOutputStream oos = new ObjectOutputStream(sock.getOutputStream());
				oos.writeObject(ckMsg);
				oos.flush();
				ObjectInputStream ois = new ObjectInputStream(sock.getInputStream());
				RSMessage reply = (RSMessage) ois.readObject();
				return reply.isReserved();
			} catch (Exception e) {
				log.debug("Reservation check for " + id + " failed: " + e);
				return false;
			} finally {
				closeQuietly(sock);
			}
		}

		/** Reads one request from the connection, handles it, then closes the connection. */
		@Override
		public void run() {
			InputStream in = null;
			OutputStream out = null;
			ObjectInputStream ois = null;
			try {
				in = socket.getInputStream();
				out = socket.getOutputStream();
				ois = new ObjectInputStream(in);
				oMsg = ois.readObject();
			} catch (Exception e) {
				log.debug("Cannot read MPD request: " + e);
				closeQuietly(socket);
				return;
			}

			try {
				dispatch(oMsg, in, out, ois);
			} finally {
				closeQuietly(ois);
				closeQuietly(in);
				closeQuietly(socket);
			}
			//Call garbage collection
			System.gc();
		}

		/** Routes a request to its handler; the instanceof order is significant. */
		private void dispatch(Object request, InputStream in, OutputStream out, ObjectInputStream ois) {
			if (request instanceof NotifyMessage) {
				handleNotify((NotifyMessage) request);
			} else if (request instanceof RequestQuitMessage) {
				handleQuit();
			} else if (request instanceof AppUnregisterMessage) {
				// Release the caller's connection before the slow cleanup below.
				closeQuietly(ois);
				closeQuietly(in);
				closeQuietly(socket);
				handleAppUnregister((AppUnregisterMessage) request);
			} else if (request instanceof AppRegisterMessage) {
				handleAppRegister((AppRegisterMessage) request, out);
			} else if (request instanceof StatQueryMessage) {
				handleStatQuery(out);
			} else if (request instanceof RequestHostCacheMessage) {
				handleHostCacheRequest(out);
			} else if (request instanceof PingRequestMessage) {
				reply(out, new PingReplyMessage(), "ping reply");
			} else if (request instanceof MPDMessage) {
				msg = (MPDMessage) request;
				handleMpdCommand(msg, out);
			} else {
				log.debug("***************** Unknown  Messsage *****************");
			}
		}

		/** Marks the notified rank as dead in the application's rank table. */
		private void handleNotify(NotifyMessage noMsg) {
			log.debug("Got a notify message");
			String key = noMsg.getAppID();
			int deadNode = noMsg.getRankInList();
			log.debug("Key = " + key);
			log.debug("DEADNODE = " + deadNode);
			ProcessInfo procInfo = parent.procTab.getInfo(key);
			RankTable ranktb = (procInfo == null) ? null : procInfo.getRankTable();
			if (ranktb == null) {
				log.warn("Notify message for unknown application key " + key + " ignored");
				return;
			}
			ranktb.setAlive(deadNode, false);
			log.info("MPD got notify message deadnode =" + deadNode);
		}

		/**
		 * Shuts this MPD down: unregisters from the supernode, asks every running
		 * MPI process to quit, removes run directories, then exits the JVM.
		 */
		private void handleQuit() {
			log.info("----->> Request shutdown MPD ---------<<<");
			RequestQuitMessage quitMsg = new RequestQuitMessage();
			//send quit message to supernode
			UnregisterMessage unregisMsg = new UnregisterMessage(parent.myHost);
			try {
				log.info("Connect to " + parent.superNode.getHost() + ":" + parent.superNode.getPort());
				sendOneShot(parent.superNode.getHost(), parent.superNode.getPort(), unregisMsg);
			} catch (Exception e) {
				log.info("ERROR: " + e.toString());
			}

			//send RequestQuitMessage to all running mpi apps
			Enumeration mpiApps = parent.procTab.elements();
			while (mpiApps.hasMoreElements()) {
				ProcessInfo proInfo = (ProcessInfo) mpiApps.nextElement();
				try {
					sendOneShot(LOCALHOST, proInfo.getPort(), quitMsg);
				} catch (Exception e) {
					log.debug("Sending quit to local port " + proInfo.getPort() + " failed: " + e);
				}
				// Run directories of non-zero ranks are removed here.
				if (proInfo.getRank() != 0) {
					recursiveRemove(new File(proInfo.getRunDir()));
				}
			}

			System.gc();
			recursiveRemove(parent.runPath);
			System.exit(0);
		}

		/**
		 * Forgets a finished application, removes its run directory (ranks other
		 * than 0 only) and cancels its reservation with the RS.
		 */
		private void handleAppUnregister(AppUnregisterMessage unregMsg) {
			String key = unregMsg.getKey();
			String runDir = parent.procTab.getRunDir(key);
			int doneRank = parent.procTab.getRank(key);

			log.debug("[MPD]: recieve unregister application of key = " + key);

			parent.procTab.removeProcess(key);
			if (doneRank != 0 && runDir != null) {
				try {
					// Heuristic grace period so the application has exited before its files go.
					Thread.sleep(UNREGISTER_CLEANUP_DELAY_MS);
				} catch (InterruptedException e) {
					Thread.currentThread().interrupt();
				}
				recursiveRemove(new File(runDir));
			}
			// Always cancel the reservation, even when no directory was removed.
			unRegisterRsService(key);
		}

		/**
		 * Records a newly started MPI process in the process table, then replies
		 * with a PingReplyMessage so the caller knows registration has completed
		 * (the original intent: an unregister must not race ahead of it).
		 */
		private void handleAppRegister(AppRegisterMessage request, OutputStream out) {
			appRegMsg = request;
			ProcessInfo procInfo = new ProcessInfo();
			procInfo.setID(appRegMsg.getID());
			procInfo.setRunCmd(appRegMsg.getRunCmd());
			procInfo.setIPRank0(appRegMsg.getIpRank0());
			procInfo.setRunDir(appRegMsg.getRunDir());
			procInfo.setRank(appRegMsg.getRank());
			procInfo.setPort(appRegMsg.getPort());
			procInfo.setMPISize(appRegMsg.getMPISize());
			procInfo.setRealSize(appRegMsg.getRealSize());
			procInfo.setRankTable(appRegMsg.getRankTable());

			parent.procTab.addProcess(procInfo);

			reply(out, new PingReplyMessage(), "register acknowledgement");
		}

		/** Replies with a StatInfoMessage listing every entry of the process table. */
		private void handleStatQuery(OutputStream out) {
			StatInfoMessage statMsg = new StatInfoMessage(parent.usingIP, parent.mpdPort);
			Enumeration element = parent.procTab.elements();
			while (element.hasMoreElements()) {
				statMsg.addInfo((ProcessInfo) element.nextElement());
			}
			reply(out, statMsg, "stat info");
		}

		/** Replies with the host cache entries sorted by ascending RTT. */
		private void handleHostCacheRequest(OutputStream out) {
			ArrayList entries = new ArrayList(parent.hostCache.entrySet());
			Collections.sort(entries, new Comparator() {
				public int compare(Object o1, Object o2) {
					int first = ((HostEntry) ((Map.Entry) o1).getValue()).getRtt();
					int second = ((HostEntry) ((Map.Entry) o2).getValue()).getRtt();
					// Explicit comparison: "first - second" can overflow for extreme values.
					return (first < second) ? -1 : ((first == second) ? 0 : 1);
				}
			});
			Vector<HostEntry> hc = new Vector<HostEntry>();
			for (Object entry : entries) {
				hc.add((HostEntry) ((Map.Entry) entry).getValue());
			}
			reply(out, new ReplyHostCacheMessage(hc), "host cache");
		}

		/** Dispatches an MPDMessage by its command code. */
		private void handleMpdCommand(MPDMessage request, OutputStream out) {
			switch (request.getCmd()) {
			case MessageCmd.MPI_REQPEER:
				handleMpiPeerRequest(request, out);
				break;
			case MessageCmd.MPD_REQPEER:
				handleMpdPeerRequest(request);
				break;
			case MessageCmd.FT_DONE:
				handleFileTransferDone(request);
				break;
			default:
				log.debug("Unknown Command");
				break;
			}
		}

		/**
		 * MPI_REQPEER: reserves peers through parent.searchMPD, reports the outcome
		 * to the requester and, on success, sends rank assignments to the peers.
		 */
		private void handleMpiPeerRequest(MPDMessage request, OutputStream out) {
			int optionN = request.getOptionN();
			int optionR = request.getOptionR();
			int maxRequire = optionN * optionR;
			int waitTime = request.getWaitTime();
			URI mpiCom = request.getIp();
			String hashID = request.getID();
			String runCmd = request.getRunCmd();
			String allocationMode = request.getAllocationMode();

			ReservationResult rsResult = parent.searchMPD(hashID, optionN, optionR, waitTime, allocationMode);
			switch (rsResult.getStatus()) {
			case ReservationResult.NOT_ENOUGH_NODE:
				log.info("[JOB] Can not Running " + (maxRequire + 1) + " processes on "
						+ rsResult.getNumNodeFound() + " machines");
				reply(out, new FoundNodeMessage(rsResult.getNumNodeFound(),
						rsResult.getNumSlotAvailable(), false), "found-node reply");
				break;
			case ReservationResult.NOT_ENOUGH_SLOT:
				log.info("[JOB] Can not Running " + (maxRequire + 1)
						+ " processes on " + rsResult.getNumNodeFound() + " machines that provide only "
						+ (rsResult.getNumSlotAvailable() + 1) + " slots");
				reply(out, new FoundNodeMessage(rsResult.getNumNodeFound(),
						rsResult.getNumSlotAvailable() + 1, false), "found-node reply");
				break;
			case ReservationResult.SUCCESS:
				log.info("[JOB] Reservation successful");
				reply(out, new FoundNodeMessage(rsResult.getNumNodeFound(),
						rsResult.getNumSlotAvailable(), true), "found-node reply");
				assignRanks(rsResult, hashID, mpiCom, runCmd, optionN);
				break;
			default:
				log.debug("Unexpected reservation status " + rsResult.getStatus());
				break;
			}
		}

		/**
		 * Sends one MPD_REQPEER per used slot of every reserved host, numbering
		 * ranks from 1 and wrapping back to 1 after {@code optionN}. The host whose
		 * IP equals the local one gets one slot fewer (presumably taken by rank 0).
		 */
		private void assignRanks(ReservationResult rsResult, String hashID, URI mpiCom,
				String runCmd, int optionN) {
			Vector<ReservedHost> hostList = rsResult.getReservedHosts();
			HostEntry myHost = rsResult.getMyHost();
			int hostListSize = hostList.size();

			MPDMessage mpdMsg = new MPDMessage(MessageCmd.MPD_REQPEER, hashID);
			mpdMsg.setIP(mpiCom.getHost(), mpiCom.getPort());
			mpdMsg.setRunCmd(runCmd);

			int assignRank = 1;
			isReserved(hashID); //change status to running
			log.info("->> Try to assign rank to " + hostListSize + " nodes <---");
			for (int i = 0; i < hostListSize; i++) {
				ReservedHost reserved = hostList.elementAt(i);
				HostEntry host = reserved.getHostEntry();
				int assignNumProc = reserved.getNumUsedSlot();
				if (host.getIp().equals(myHost.getIp())) {
					assignNumProc--;
				}
				for (int j = 0; j < assignNumProc; j++) {
					if (assignRank > optionN) {
						assignRank = 1;
					}
					log.info("Request assign rank " + assignRank + " to " + host.getIp() + ":" + host.getMpdPort());
					mpdMsg.setRank(assignRank);
					parent.requestPeer(host.getIp(), host.getMpdPort(), mpdMsg);
					assignRank++;
				}
			}
		}

		/**
		 * MPD_REQPEER: if the ID is reserved locally, remembers the application
		 * under "id--rank" and sends an MPD_ACCEPT back to rank 0, retrying a few times.
		 */
		private void handleMpdPeerRequest(MPDMessage request) {
			URI ip = request.getIp();
			String id = request.getID();
			String runCmd = request.getRunCmd();
			int myRank = request.getRank();

			log.info("======= I have a peer request message =======");
			log.info(" ID     = " + id);
			log.info(" myRank = " + myRank);
			log.info(" URI    = " + ip.getHost() + ":" + ip.getPort());
			log.info(" runCmd = " + runCmd);
			log.info("==============================================");

			if (!isReserved(id)) {
				log.info("This request has never been reserved before, we can not let it run");
				return;
			}

			appInfo = new AppInfo();
			appInfo.setRunCmd(runCmd);
			appInfo.setRank0(ip);
			hashTable.put(id + KEY_SEPARATOR + myRank, appInfo);

			MPDMessage answerMsg = new MPDMessage(MessageCmd.MPD_ACCEPT, id);
			answerMsg.setFTPort(parent.ftPort);
			answerMsg.setFDPort(parent.fdPort);
			answerMsg.setRank(myRank);
			answerMsg.setMyHost(parent.usingIP);

			for (int attempt = 1; attempt <= ACCEPT_SEND_ATTEMPTS; attempt++) {
				try {
					sendOneShot(ip.getHost(), ip.getPort(), answerMsg);
					return;
				} catch (Exception e) {
					log.warn("Attempt " + attempt + " to send MPD_ACCEPT to "
							+ ip.getHost() + ":" + ip.getPort() + " failed: " + e);
				}
				if (attempt < ACCEPT_SEND_ATTEMPTS) {
					try {
						Thread.sleep(ACCEPT_RETRY_DELAY_MS);
					} catch (InterruptedException e) {
						Thread.currentThread().interrupt();
						break;
					}
				}
			}
			log.error("Could not deliver MPD_ACCEPT for " + id + " rank " + myRank);
		}

		/**
		 * FT_DONE: takes the pending application for "id--key", spawns the
		 * p2pclient launcher for it in the message's run directory, forwards the
		 * process's stdout/stderr to rank 0 and waits for it to terminate.
		 */
		private void handleFileTransferDone(MPDMessage request) {
			String myKey = request.getID() + KEY_SEPARATOR + request.getKey();
			String jarDep = request.getJars();
			appInfo = hashTable.remove(myKey);
			if (appInfo == null) {
				log.error("No pending application for key " + myKey + "; FT_DONE ignored");
				return;
			}
			URI rank0 = appInfo.getRank0();
			log.info("Running MPI application");
			log.info("Rank 0 : " + rank0.toString());
			log.info("RunCmd : " + appInfo.getRunCmd());

			String osname = OsInfo.getName();
			try {
				String commonArgs = jarDep + " "
						+ rank0.getHost() + " "
						+ rank0.getPort() + " "
						+ request.getID() + " "
						+ request.getKey() + " "
						+ appInfo.getRunCmd();

				// TODO: consider alternative ways of spawning the job in the p2pclient
				// script (e.g. the submission command of a configured batch system).
				String launcher;
				if (OsInfo.isLinux(osname) || OsInfo.isMacosx(osname)) {
					launcher = "p2pclient ";
				} else if (OsInfo.isWindows(osname)) {
					launcher = "p2pclient.bat ";
				} else {
					launcher = "p2pclient ";
				}
				String execCmd = launcher + commonArgs;
				log.debug("ExeCmd = " + execCmd);
				log.debug("RunDir = " + request.getRunDir());

				// NOTE(review): jars/runCmd come from network messages. exec(String) splits
				// on whitespace (no shell), but the arguments themselves are not validated.
				Process p = Runtime.getRuntime().exec(execCmd, null, request.getRunDir());

				InputStream err = p.getErrorStream();
				InputStream inn = p.getInputStream();

				// Forward the process output to rank 0 (true = interactive mode, assumed for now).
				StreamGobbler errGlobber = new StreamGobbler(err, true, rank0.getHost(), rank0.getPort());
				errGlobber.start();
				StreamGobbler innGlobber = new StreamGobbler(inn, true, rank0.getHost(), rank0.getPort());
				innGlobber.start();

				p.waitFor();

				err.close();
				inn.close();
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				log.warn("Interrupted while waiting for MPI application " + myKey);
			} catch (Exception e) {
				log.error("Failed to run MPI application " + myKey, e);
			}
		}

		/**
		 * Writes {@code response} as one serialized object on {@code out} and closes
		 * the stream (which also closes the underlying socket). Failures are logged.
		 */
		private void reply(OutputStream out, Object response, String what) {
			try {
				ObjectOutputStream oos = new ObjectOutputStream(out);
				oos.writeObject(response);
				oos.flush();
				oos.close();
				out.close();
			} catch (Exception e) {
				log.warn("Failed to send " + what + ": " + e);
			}
		}

		/** Opens a connection to host:port, writes one serialized object and closes it. */
		private void sendOneShot(String host, int port, Object message) throws IOException {
			Socket s = new Socket(host, port);
			try {
				ObjectOutputStream oos = new ObjectOutputStream(s.getOutputStream());
				oos.writeObject(message);
				oos.flush();
				oos.close();
			} finally {
				closeQuietly(s);
			}
		}

		/** Returns the part of {@code key} before the first "--", or the whole key. */
		private String hashIdOf(String key) {
			int sep = key.indexOf(KEY_SEPARATOR);
			return (sep < 0) ? key : key.substring(0, sep);
		}

		/** Closes {@code c} if non-null, ignoring any exception (best-effort cleanup). */
		private void closeQuietly(Closeable c) {
			if (c == null) {
				return;
			}
			try {
				c.close();
			} catch (Exception ignored) {
				// best-effort close
			}
		}

		/** Closes {@code s} if non-null, ignoring any exception (best-effort cleanup). */
		private void closeQuietly(Socket s) {
			if (s == null) {
				return;
			}
			try {
				s.close();
			} catch (Exception ignored) {
				// best-effort close
			}
		}
	}


}
