diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Core/ConnectionManager.cpp | 49 | ||||
-rw-r--r-- | src/Core/ConnectionManager.h | 3 | ||||
-rw-r--r-- | src/Net/Connection.cpp | 18 | ||||
-rw-r--r-- | src/Net/Connection.h | 34 | ||||
-rw-r--r-- | src/Net/Packet.h | 11 | ||||
-rw-r--r-- | src/mad-core.cpp | 7 | ||||
-rw-r--r-- | src/madc.cpp | 12 |
7 files changed, 84 insertions, 50 deletions
diff --git a/src/Core/ConnectionManager.cpp b/src/Core/ConnectionManager.cpp index b338f62..5ff5e29 100644 --- a/src/Core/ConnectionManager.cpp +++ b/src/Core/ConnectionManager.cpp @@ -29,8 +29,6 @@ namespace Mad { namespace Core { void ConnectionManager::refreshPollfds() { - // TODO: refreshPollfds() - pollfds.clear(); pollfdMap.clear(); @@ -54,31 +52,26 @@ void ConnectionManager::refreshPollfds() { } } -void ConnectionManager::daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) { - std::cout << "Received daemon packet:" << std::endl; - std::cout << " Type: " << packet.getType() << std::endl; - std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl; - std::cout << " Length: " << packet.getLength() << std::endl; - - for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end(); ++con) { - if(*con == connection) { - (*con)->disconnect(); +void ConnectionManager::receiveHandler(Net::Connection *connection, const Net::Packet &packet) { + switch(packet.getType()) { + case Net::Packet::TYPE_UNKNOWN: break; - } - } -} - -void ConnectionManager::clientReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) { - std::cout << "Received client packet:" << std::endl; - std::cout << " Type: " << packet.getType() << std::endl; - std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl; - std::cout << " Length: " << packet.getLength() << std::endl; - - for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end(); ++con) { - if(*con == connection) { - (*con)->disconnect(); + case Net::Packet::TYPE_DEBUG: + std::cout << "Received debug packet." << std::endl; + std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl; break; - } + case Net::Packet::TYPE_PING: + connection->send(Net::Packet(Net::Packet::TYPE_PONG, packet.getRequestId(), packet.getData(), packet.getLength())); + break; + case Net::Packet::TYPE_PONG: + // TODO: Pong! + break; + case Net::Packet::TYPE_DISCONNECT_REQ: + connection->send(Net::Packet(Net::Packet::TYPE_DISCONNECT_REP, packet.getRequestId())); + connection->disconnect(); + break; + case Net::Packet::TYPE_DISCONNECT_REP: + connection->disconnect(); } } @@ -117,6 +110,7 @@ void ConnectionManager::run() { ++con; } else { + std::cout << "A daemon connection was dropped." << std::endl; delete *con; daemonConnections.erase(con++); } @@ -134,6 +128,7 @@ void ConnectionManager::run() { ++con; } else { + std::cout << "A client connection was dropped." << std::endl; delete *con; clientConnections.erase(con++); } @@ -145,11 +140,11 @@ void ConnectionManager::run() { while((con = (*listener)->getConnection(pollfdMap)) != 0) { if(con->isDaemonConnection()) { daemonConnections.push_back(con); - con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::daemonReceiveHandler)); + con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::receiveHandler)); } else { clientConnections.push_back(con); - con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::clientReceiveHandler)); + con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::receiveHandler)); } } } diff --git a/src/Core/ConnectionManager.h b/src/Core/ConnectionManager.h index cbfc812..608ec27 100644 --- a/src/Core/ConnectionManager.h +++ b/src/Core/ConnectionManager.h @@ -52,8 +52,7 @@ class ConnectionManager { void refreshPollfds(); - void daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet); - void clientReceiveHandler(const Net::Connection *connection, const Net::Packet &packet); + void receiveHandler(Net::Connection *connection, const Net::Packet &packet); public: ConnectionManager(); diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 19e7bf1..d069fc9 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -22,6 +22,8 @@ #include <cstring> #include <sys/socket.h> +#include <iostream> + namespace Mad { namespace Net { @@ -49,14 +51,14 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng if(length != sizeof(Packet::Data)) { // TODO: Error - disconnect(); + doDisconnect(); return; } header = *reinterpret_cast<const Packet::Data*>(data); if(header.length == 0) { - signal(this, Packet(header.type, header.requestId)); + signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId)); enterReceiveLoop(); } @@ -72,11 +74,11 @@ void Connection::packetDataReceiveHandler(const void *data, unsigned long length if(length != header.length) { // TODO: Error - disconnect(); + doDisconnect(); return; } - signal(this, Packet(header.type, header.requestId, data, length)); + signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId, data, length)); enterReceiveLoop(); } @@ -95,7 +97,7 @@ void Connection::doReceive() { return; // TODO: Error - disconnect(); + doDisconnect(); return; } @@ -142,7 +144,7 @@ void Connection::doSend() { return; // TODO: Error - disconnect(); + doDisconnect(); return; } @@ -166,11 +168,11 @@ bool Connection::rawSend(const unsigned char *data, unsigned long length) { return true; } -void Connection::disconnect() { +void Connection::doDisconnect() { if(!isConnected()) return; - gnutls_bye(session, GNUTLS_SHUT_RDWR); + gnutls_bye(session, GNUTLS_SHUT_WR); shutdown(sock, SHUT_RDWR); close(sock); diff --git a/src/Net/Connection.h b/src/Net/Connection.h index adb677a..a3670b0 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -44,7 +44,7 @@ class Connection { }; enum State { - DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA + DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE } state; Transmission transR; @@ -52,7 +52,7 @@ class Connection { Packet::Data header; - sigc::signal<void,const Connection*,const Packet&> signal; + sigc::signal<void,Connection*,const Packet&> signal; void doHandshake(); @@ -62,6 +62,8 @@ class Connection { void doReceive(); void doSend(); + void doDisconnect(); + bool receiveComplete() const { return (transR.length == transR.transmitted); } @@ -103,7 +105,7 @@ class Connection { bool rawSend(const unsigned char *data, unsigned long length); bool enterReceiveLoop() { - if(!isConnected()) + if(!isConnected() || isDisconnecting()) return false; state = PACKET_HEADER; @@ -119,7 +121,7 @@ class Connection { virtual ~Connection() { if(isConnected()) - disconnect(); + doDisconnect(); if(transR.data) delete [] transR.data; @@ -131,19 +133,30 @@ class Connection { } bool isConnected() const {return (state != DISCONNECTED);} - bool isConnecting() { + bool isConnecting() const { return (state == HANDSHAKE || state == CONNECTION_HEADER); } + bool isDisconnecting() const { + return (state == DISCONNECT || state == BYE); + } + const IPAddress* getPeer() {return peer;} int getSocket() const {return sock;} - void disconnect(); + void disconnect() { + if(isConnected() && !isDisconnecting()) { + state = DISCONNECT; + + if(sendQueueEmpty()) + doDisconnect(); + } + } struct pollfd getPollfd() const; bool send(const Packet &packet) { - if(!isConnected() || isConnecting()) + if(!isConnected() || isConnecting() || isDisconnecting()) return false; return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength()); @@ -151,7 +164,7 @@ class Connection { void sendReceive(short events = POLLIN|POLLOUT) { if(events & POLLHUP || events & POLLERR) { - disconnect(); + doDisconnect(); return; } @@ -165,11 +178,14 @@ class Connection { if(events & POLLOUT) doSend(); + + if(state == DISCONNECT && sendQueueEmpty()) + doDisconnect(); } bool sendQueueEmpty() const {return transS.empty();} - sigc::signal<void,const Connection*,const Packet&> signalReceive() const {return signal;} + sigc::signal<void,Connection*,const Packet&> signalReceive() const {return signal;} static void init() { gnutls_global_init(); diff --git a/src/Net/Packet.h b/src/Net/Packet.h index 32aba18..c2bd4d1 100644 --- a/src/Net/Packet.h +++ b/src/Net/Packet.h @@ -28,6 +28,11 @@ namespace Net { class Packet { public: + enum Type { + TYPE_UNKNOWN = 0x0000, TYPE_DEBUG = 0x0001, TYPE_PING = 0x0002, TYPE_PONG = 0x0003, + TYPE_DISCONNECT_REQ = 0x0010, TYPE_DISCONNECT_REP = 0x0011 + }; + struct Data { unsigned short type; unsigned short requestId; @@ -39,7 +44,7 @@ class Packet { Data *rawData; public: - Packet(unsigned short type, unsigned short requestId, const void *data = NULL, unsigned long length = 0) { + Packet(Type type, unsigned short requestId, const void *data = NULL, unsigned long length = 0) { rawData = (Data*)std::malloc(sizeof(Data)+length); rawData->type = type; @@ -71,8 +76,8 @@ class Packet { std::free(rawData); } - unsigned short getType() const { - return rawData->type; + Type getType() const { + return (Type)rawData->type; } unsigned short getRequestId() const { diff --git a/src/mad-core.cpp b/src/mad-core.cpp index c46b932..36e889b 100644 --- a/src/mad-core.cpp +++ b/src/mad-core.cpp @@ -19,8 +19,15 @@ #include "Net/Connection.h" #include "Core/ConnectionManager.h" +#include <signal.h> int main() { + sigset_t signals; + + sigemptyset(&signals); + sigaddset(&signals, SIGPIPE); + sigprocmask(SIG_BLOCK, &signals, 0); + Mad::Net::Connection::init(); Mad::Core::ConnectionManager connectionManager; diff --git a/src/madc.cpp b/src/madc.cpp index aaf7197..08fb3b3 100644 --- a/src/madc.cpp +++ b/src/madc.cpp @@ -37,7 +37,8 @@ int main() { connection.sendReceive(fd.revents); } - connection.send(Mad::Net::Packet(0x0001, 0xABCD)); + connection.send(Mad::Net::Packet(Mad::Net::Packet::TYPE_DEBUG, 0x1234)); + connection.send(Mad::Net::Packet(Mad::Net::Packet::TYPE_DISCONNECT_REQ, 0xABCD)); while(!connection.sendQueueEmpty()) { struct pollfd fd = connection.getPollfd(); @@ -45,6 +46,15 @@ int main() { if(poll(&fd, 1, 10000) > 0) connection.sendReceive(fd.revents); } + + connection.disconnect(); + + while(connection.isConnected()) { + struct pollfd fd = connection.getPollfd(); + + if(poll(&fd, 1, 10000) > 0) + connection.sendReceive(fd.revents); + } } catch(Mad::Net::Exception &e) { std::cerr << "Connection error: " << e.what() << std::endl; |