diff options
-rw-r--r-- | src/Core/ConnectionManager.cpp | 33 | ||||
-rw-r--r-- | src/Core/ConnectionManager.h | 2 | ||||
-rw-r--r-- | src/Net/Connection.h | 15 | ||||
-rw-r--r-- | src/Net/Listener.cpp | 13 | ||||
-rw-r--r-- | src/Net/Listener.h | 3 | ||||
-rw-r--r-- | src/madc.cpp | 16 |
6 files changed, 65 insertions, 17 deletions
diff --git a/src/Core/ConnectionManager.cpp b/src/Core/ConnectionManager.cpp index fe9f639..ba78a98 100644 --- a/src/Core/ConnectionManager.cpp +++ b/src/Core/ConnectionManager.cpp @@ -32,17 +32,26 @@ void ConnectionManager::refreshPollfds() { // TODO: refreshPollfds() pollfds.clear(); + pollfdMap.clear(); for(std::list<Net::Listener*>::iterator listener = listeners.begin(); listener != listeners.end(); ++listener) { std::vector<struct pollfd> fds = (*listener)->getPollfds(); - pollfds.insert(pollfds.end(), fds.begin(), fds.end()); + + for(std::vector<struct pollfd>::iterator fd = fds.begin(); fd != fds.end(); ++fd) { + pollfds.push_back(*fd); + pollfdMap.insert(std::make_pair(fd->fd, &pollfds.back().revents)); + } } - for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end(); ++con) + for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end(); ++con) { pollfds.push_back((*con)->getPollfd()); + pollfdMap.insert(std::make_pair(pollfds.back().fd, &pollfds.back().revents)); + } - for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end(); ++con) + for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end(); ++con) { pollfds.push_back((*con)->getPollfd()); + pollfdMap.insert(std::make_pair(pollfds.back().fd, &pollfds.back().revents)); + } } void ConnectionManager::daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) { @@ -96,7 +105,13 @@ void ConnectionManager::run() { for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end();) { if((*con)->isConnected()) { - (*con)->sendReceive(); + std::map<int,const short*>::iterator events = pollfdMap.find((*con)->getSocket()); + + if(events != pollfdMap.end()) + (*con)->sendReceive(*events->second); + else + (*con)->sendReceive(); + ++con; } else { @@ -107,7 +122,13 @@ void ConnectionManager::run() { for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end();) { if((*con)->isConnected()) { - (*con)->sendReceive(); + std::map<int,const short*>::iterator events = pollfdMap.find((*con)->getSocket()); + + if(events != pollfdMap.end()) + (*con)->sendReceive(*events->second); + else + (*con)->sendReceive(); + ++con; } else { @@ -119,7 +140,7 @@ void ConnectionManager::run() { for(std::list<Net::Listener*>::iterator listener = listeners.begin(); listener != listeners.end(); ++listener) { Net::ServerConnection *con; - while((con = (*listener)->getConnection()) != 0) { + while((con = (*listener)->getConnection(pollfdMap)) != 0) { if(con->isDaemonConnection()) { daemonConnections.push_back(con); con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::daemonReceiveHandler)); diff --git a/src/Core/ConnectionManager.h b/src/Core/ConnectionManager.h index b68919c..b252910 100644 --- a/src/Core/ConnectionManager.h +++ b/src/Core/ConnectionManager.h @@ -22,6 +22,7 @@ #include <list> #include <vector> +#include <map> #include <poll.h> namespace Mad { @@ -47,6 +48,7 @@ class ConnectionManager { std::list<Net::ServerConnection*> clientConnections; std::vector<struct pollfd> pollfds; + std::map<int,const short*> pollfdMap; void refreshPollfds(); diff --git a/src/Net/Connection.h b/src/Net/Connection.h index e147ad2..adb677a 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -136,6 +136,7 @@ class Connection { } const IPAddress* getPeer() {return peer;} + int getSocket() const {return sock;} void disconnect(); @@ -148,14 +149,22 @@ class Connection { return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength()); } - void sendReceive() { + void sendReceive(short events = POLLIN|POLLOUT) { + if(events & POLLHUP || events & POLLERR) { + disconnect(); + return; + } + if(state == HANDSHAKE) { doHandshake(); return; } - doReceive(); - doSend(); + if(events & POLLIN) + doReceive(); + + if(events & POLLOUT) + doSend(); } bool sendQueueEmpty() const {return transS.empty();} diff --git a/src/Net/Listener.cpp b/src/Net/Listener.cpp index 8386389..981b3c7 100644 --- a/src/Net/Listener.cpp +++ b/src/Net/Listener.cpp @@ -90,21 +90,28 @@ std::vector<struct pollfd> Listener::getPollfds() const { return pollfds; } -ServerConnection* Listener::getConnection() { +ServerConnection* Listener::getConnection(const std::map<int,const short*> &pollfdMap) { // TODO: Logging int sd; struct sockaddr_in sa; socklen_t addrlen = sizeof(sa); + while((sd = accept(sock, reinterpret_cast<struct sockaddr*>(&sa), &addrlen)) >= 0) { connections.push_back(new ServerConnection(sd, IPAddress(sa), dh_params)); addrlen = sizeof(sa); } - for(std::list<ServerConnection*>::iterator con = connections.begin(); con != connections.end(); ++con) - (*con)->sendReceive(); + for(std::list<ServerConnection*>::iterator con = connections.begin(); con != connections.end(); ++con) { + std::map<int,const short*>::const_iterator events = pollfdMap.find((*con)->getSocket()); + + if(events != pollfdMap.end()) + (*con)->sendReceive(*events->second); + else + (*con)->sendReceive(); + } for(std::list<ServerConnection*>::iterator con = connections.begin(); con != connections.end();) { if(!(*con)->isConnected()) { diff --git a/src/Net/Listener.h b/src/Net/Listener.h index 9952268..30fde63 100644 --- a/src/Net/Listener.h +++ b/src/Net/Listener.h @@ -26,6 +26,7 @@ #include <poll.h> #include <list> #include <vector> +#include <map> namespace Mad { namespace Net { @@ -51,7 +52,7 @@ class Listener { std::vector<struct pollfd> getPollfds() const; - ServerConnection* getConnection(); + ServerConnection* getConnection(const std::map<int,const short*> &pollfdMap); }; } diff --git a/src/madc.cpp b/src/madc.cpp index 762429e..8104d0b 100644 --- a/src/madc.cpp +++ b/src/madc.cpp @@ -30,13 +30,21 @@ int main() { try { connection.connect(Mad::Net::IPAddress("127.0.0.1", 6666)); - while(connection.isConnecting()) - connection.sendReceive(); + while(connection.isConnecting()) { + struct pollfd fd = connection.getPollfd(); + + poll(&fd, 1, 10000); + connection.sendReceive(fd.revents); + } connection.send(Mad::Net::Packet(0x0001, 0xABCD)); - while(!connection.sendQueueEmpty()) - connection.sendReceive(); + while(!connection.sendQueueEmpty()) { + struct pollfd fd = connection.getPollfd(); + + poll(&fd, 1, 10000); + connection.sendReceive(fd.revents); + } } catch(Mad::Net::Exception &e) { std::cerr << "Connection error: " << e.what() << std::endl; |