diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Core/ConnectionManager.cpp | 39 | ||||
-rw-r--r-- | src/Core/ConnectionManager.h | 9 | ||||
-rw-r--r-- | src/Net/Connection.cpp | 37 | ||||
-rw-r--r-- | src/Net/Connection.h | 17 | ||||
-rw-r--r-- | src/mad-core.cpp | 3 |
5 files changed, 81 insertions, 24 deletions
diff --git a/src/Core/ConnectionManager.cpp b/src/Core/ConnectionManager.cpp index 3b99875..fe9f639 100644 --- a/src/Core/ConnectionManager.cpp +++ b/src/Core/ConnectionManager.cpp @@ -29,21 +29,48 @@ namespace Mad { namespace Core { void ConnectionManager::refreshPollfds() { + // TODO: refreshPollfds() + pollfds.clear(); + + for(std::list<Net::Listener*>::iterator listener = listeners.begin(); listener != listeners.end(); ++listener) { + std::vector<struct pollfd> fds = (*listener)->getPollfds(); + pollfds.insert(pollfds.end(), fds.begin(), fds.end()); + } + + for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end(); ++con) + pollfds.push_back((*con)->getPollfd()); + + for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end(); ++con) + pollfds.push_back((*con)->getPollfd()); } -void ConnectionManager::daemonReceiveHandler(const Net::Connection*, const Net::Packet &packet) const { +void ConnectionManager::daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) { std::cout << "Received daemon packet:" << std::endl; std::cout << " Type: " << packet.getType() << std::endl; std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl; std::cout << " Length: " << packet.getLength() << std::endl; + + for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end(); ++con) { + if(*con == connection) { + (*con)->disconnect(); + break; + } + } } -void ConnectionManager::clientReceiveHandler(const Net::Connection*, const Net::Packet &packet) const { +void ConnectionManager::clientReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) { std::cout << "Received client packet:" << std::endl; std::cout << " Type: " << packet.getType() << std::endl; std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl; std::cout << " Length: " << packet.getLength() << std::endl; + + for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end(); ++con) { + if(*con == connection) { + (*con)->disconnect(); + break; + } + } } ConnectionManager::ConnectionManager() { @@ -64,12 +91,6 @@ ConnectionManager::~ConnectionManager() { delete *con; } -void ConnectionManager::wait(int timeout) { - // TODO: wait() - - usleep(timeout); -} - void ConnectionManager::run() { // TODO: Logging @@ -109,6 +130,8 @@ void ConnectionManager::run() { } } } + + refreshPollfds(); } } diff --git a/src/Core/ConnectionManager.h b/src/Core/ConnectionManager.h index 49c9398..b68919c 100644 --- a/src/Core/ConnectionManager.h +++ b/src/Core/ConnectionManager.h @@ -50,14 +50,17 @@ class ConnectionManager { void refreshPollfds(); - void daemonReceiveHandler(const Net::Connection*, const Net::Packet &packet) const; - void clientReceiveHandler(const Net::Connection*, const Net::Packet &packet) const; + void daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet); + void clientReceiveHandler(const Net::Connection *connection, const Net::Packet &packet); public: ConnectionManager(); virtual ~ConnectionManager(); - void wait(int timeout); + void wait(int timeout) { + poll(pollfds.data(), pollfds.size(), timeout); + } + void run(); }; diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 4e3fee4..19e7bf1 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -25,15 +25,17 @@ namespace Mad { namespace Net { -void Connection::handshake() { - state = HANDSHAKE; +void Connection::doHandshake() { + if(state != HANDSHAKE) + return; int ret = gnutls_handshake(session); if(ret < 0) { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; - // Error... disconnect + // TODO: Error + disconnect(); return; } @@ -42,8 +44,14 @@ void Connection::handshake() { } void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) { - if(length != sizeof(Packet::Data)) - return; // Error... disconnect? + if(state != PACKET_HEADER) + return; + + if(length != sizeof(Packet::Data)) { + // TODO: Error + disconnect(); + return; + } header = *reinterpret_cast<const Packet::Data*>(data); @@ -59,8 +67,14 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng } void Connection::packetDataReceiveHandler(const void *data, unsigned long length) { - if(length != header.length) - return; // Error... disconnect? + if(state != PACKET_DATA) + return; + + if(length != header.length) { + // TODO: Error + disconnect(); + return; + } signal(this, Packet(header.type, header.requestId, data, length)); @@ -80,7 +94,8 @@ void Connection::doReceive() { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; - // Error... disconnect? + // TODO: Error + disconnect(); return; } @@ -126,7 +141,8 @@ void Connection::doSend() { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; - // Error... disconnect? + // TODO: Error + disconnect(); return; } @@ -171,6 +187,9 @@ void Connection::disconnect() { struct pollfd Connection::getPollfd() const { struct pollfd fd = {sock, (receiveComplete() ? 0 : POLLIN) | (sendQueueEmpty() ? 0 : POLLOUT), 0}; + if(state == HANDSHAKE) + fd.events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT); + return fd; } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index 0880036..e147ad2 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -54,6 +54,8 @@ class Connection { sigc::signal<void,const Connection*,const Packet&> signal; + void doHandshake(); + void packetHeaderReceiveHandler(const void *data, unsigned long length); void packetDataReceiveHandler(const void *data, unsigned long length); @@ -86,13 +88,24 @@ class Connection { IPAddress *peer; - void handshake(); + void handshake() { + if(isConnected()) + return; + + state = HANDSHAKE; + + doHandshake(); + } + virtual void connectionHeader() = 0; bool rawReceive(unsigned long length, const sigc::slot<void,const void*,unsigned long> ¬ify); bool rawSend(const unsigned char *data, unsigned long length); bool enterReceiveLoop() { + if(!isConnected()) + return false; + state = PACKET_HEADER; return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler)); @@ -137,7 +150,7 @@ class Connection { void sendReceive() { if(state == HANDSHAKE) { - handshake(); + doHandshake(); return; } diff --git a/src/mad-core.cpp b/src/mad-core.cpp index ea0eded..461f595 100644 --- a/src/mad-core.cpp +++ b/src/mad-core.cpp @@ -20,7 +20,6 @@ #include "Net/Connection.h" #include "Core/ConnectionManager.h" - int main() { Mad::Net::Connection::init(); @@ -28,7 +27,7 @@ int main() { while(true) { connectionManager.run(); - connectionManager.wait(100000); + connectionManager.wait(10000); } Mad::Net::Connection::deinit(); |