diff options
-rw-r--r-- | src/Net/Connection.cpp | 96 | ||||
-rw-r--r-- | src/Net/Connection.h | 92 | ||||
-rw-r--r-- | src/madc.cpp | 26 |
3 files changed, 104 insertions, 110 deletions
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 8af02c6..5d221fb 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -22,25 +22,23 @@ #include <cstring> #include <sys/socket.h> -#include <iostream> - namespace Mad { namespace Net { void Connection::doHandshake() { if(state != HANDSHAKE) return; - + int ret = gnutls_handshake(session); if(ret < 0) { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; - + // TODO: Error doDisconnect(); return; } - + state = CONNECTION_HEADER; connectionHeader(); } @@ -48,37 +46,35 @@ void Connection::doHandshake() { void Connection::doBye() { if(state != BYE) return; - - int ret = gnutls_bye(session, GNUTLS_SHUT_WR); + + int ret = gnutls_bye(session, GNUTLS_SHUT_RDWR); if(ret < 0) { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; - + // TODO: Error doDisconnect(); return; } - - std::cout << "Bye!" << std::endl; - + doDisconnect(); } void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) { if(state != PACKET_HEADER) return; - + if(length != sizeof(Packet::Data)) { // TODO: Error doDisconnect(); return; } - + header = *reinterpret_cast<const Packet::Data*>(data); - + if(header.length == 0) { signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId)); - + enterReceiveLoop(); } else { @@ -90,45 +86,45 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng void Connection::packetDataReceiveHandler(const void *data, unsigned long length) { if(state != PACKET_DATA) return; - + if(length != header.length) { // TODO: Error doDisconnect(); return; } - + signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId, data, length)); - + enterReceiveLoop(); } void Connection::doReceive() { if(!isConnected()) return; - + if(receiveComplete()) return; - + ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted); - + if(ret < 0) { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; - + // TODO: Error doDisconnect(); return; } - + transR.transmitted += ret; - + if(receiveComplete()) { // Save data pointer, as transR.notify might start a new reception unsigned char *data = transR.data; transR.data = 0; - + transR.notify(data, transR.length); - + delete [] data; } } @@ -138,37 +134,37 @@ bool Connection::rawReceive(unsigned long length, { if(!isConnected()) return false; - + if(!receiveComplete()) return false; - + transR.data = new unsigned char[length]; transR.length = length; transR.transmitted = 0; transR.notify = notify; - + return true; } void Connection::doSend() { if(!isConnected()) return; - + while(!sendQueueEmpty()) { - ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted, + ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted, transS.front().length-transS.front().transmitted); - + if(ret < 0) { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; - + // TODO: Error doDisconnect(); return; } - + transS.front().transmitted += ret; - + if(transS.front().transmitted == transS.front().length) { delete [] transS.front().data; transS.pop(); @@ -179,11 +175,11 @@ void Connection::doSend() { bool Connection::rawSend(const unsigned char *data, unsigned long length) { if(!isConnected()) return false; - + Transmission trans = {length, 0, new unsigned char[length], sigc::slot<void,const void*,unsigned long>()}; std::memcpy(trans.data, data, length); transS.push(trans); - + return true; } @@ -192,49 +188,49 @@ void Connection::sendReceive(short events) { doDisconnect(); return; } - + if(state == HANDSHAKE) { doHandshake(); return; } - + if(state == BYE) { doBye(); return; } - + if(events & POLLIN) doReceive(); - + if(events & POLLOUT) doSend(); - - if(state == DISCONNECT && sendQueueEmpty()) + + if(state == DISCONNECT && sendQueueEmpty()) bye(); } void Connection::doDisconnect() { if(!isConnected()) return; - + shutdown(sock, SHUT_RDWR); close(sock); - + gnutls_deinit(session); - + if(peer) delete peer; peer = 0; - + state = DISCONNECTED; } struct pollfd Connection::getPollfd() const { struct pollfd fd = {sock, (receiveComplete() ? 0 : POLLIN) | (sendQueueEmpty() ? 0 : POLLOUT), 0}; - - if(state == HANDSHAKE) + + if(state == HANDSHAKE || state == BYE) fd.events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT); - + return fd; } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index 01926e1..0949ec4 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -37,152 +37,152 @@ class Connection { struct Transmission { unsigned long length; unsigned long transmitted; - + unsigned char *data; - + sigc::slot<void,const void*,unsigned long> notify; }; - + enum State { DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE } state; - + Transmission transR; std::queue<Transmission> transS; - + Packet::Data header; - + sigc::signal<void,Connection*,const Packet&> signal; - + void doHandshake(); - + void packetHeaderReceiveHandler(const void *data, unsigned long length); void packetDataReceiveHandler(const void *data, unsigned long length); - + void doReceive(); void doSend(); - + void doBye(); - + void doDisconnect(); - + bool receiveComplete() const { return (transR.length == transR.transmitted); } - + void bye() { if(state != DISCONNECT) return; - + state = BYE; - + doBye(); } - + // Prevent shallow copy Connection(const Connection &o); Connection& operator=(const Connection &o); - + protected: struct ConnectionHeader { unsigned char m; unsigned char a; unsigned char d; unsigned char type; - + unsigned char versionMajor; unsigned char versionMinor; unsigned char protVerMin; unsigned char protVerMax; }; - + int sock; gnutls_session_t session; - + IPAddress *peer; - + void handshake() { if(isConnected()) return; - + state = HANDSHAKE; - + doHandshake(); } - + virtual void connectionHeader() = 0; - + bool rawReceive(unsigned long length, const sigc::slot<void,const void*,unsigned long> ¬ify); bool rawSend(const unsigned char *data, unsigned long length); - + bool enterReceiveLoop() { if(!isConnected() || isDisconnecting()) return false; - + state = PACKET_HEADER; - + return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler)); } - + public: Connection() : state(DISCONNECTED), peer(0) { transR.length = transR.transmitted = 0; transR.data = 0; } - + virtual ~Connection() { if(isConnected()) doDisconnect(); - + if(transR.data) delete [] transR.data; - + while(!sendQueueEmpty()) { delete [] transS.front().data; transS.pop(); } } - + bool isConnected() const {return (state != DISCONNECTED);} bool isConnecting() const { return (state == HANDSHAKE || state == CONNECTION_HEADER); } - + bool isDisconnecting() const { return (state == DISCONNECT || state == BYE); } - + const IPAddress* getPeer() {return peer;} int getSocket() const {return sock;} - + void disconnect() { if(isConnected() && !isDisconnecting()) { state = DISCONNECT; - + if(sendQueueEmpty()) - doDisconnect(); + bye(); } } - + struct pollfd getPollfd() const; - + bool send(const Packet &packet) { if(!isConnected() || isConnecting() || isDisconnecting()) return false; - + return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength()); } - + void sendReceive(short events = POLLIN|POLLOUT); - + bool sendQueueEmpty() const {return transS.empty();} - + sigc::signal<void,Connection*,const Packet&> signalReceive() const {return signal;} - + static void init() { gnutls_global_init(); } - + static void deinit() { gnutls_global_deinit(); } diff --git a/src/madc.cpp b/src/madc.cpp index 08fb3b3..790baf5 100644 --- a/src/madc.cpp +++ b/src/madc.cpp @@ -24,34 +24,34 @@ int main() { Mad::Net::Connection::init(); - + Mad::Net::ClientConnection connection; - + try { connection.connect(Mad::Net::IPAddress("127.0.0.1", 6666)); - + while(connection.isConnecting()) { struct pollfd fd = connection.getPollfd(); - + if(poll(&fd, 1, 10000) > 0) connection.sendReceive(fd.revents); } - + connection.send(Mad::Net::Packet(Mad::Net::Packet::TYPE_DEBUG, 0x1234)); connection.send(Mad::Net::Packet(Mad::Net::Packet::TYPE_DISCONNECT_REQ, 0xABCD)); - + while(!connection.sendQueueEmpty()) { struct pollfd fd = connection.getPollfd(); - + if(poll(&fd, 1, 10000) > 0) connection.sendReceive(fd.revents); } - + connection.disconnect(); - + while(connection.isConnected()) { struct pollfd fd = connection.getPollfd(); - + if(poll(&fd, 1, 10000) > 0) connection.sendReceive(fd.revents); } @@ -59,10 +59,8 @@ int main() { catch(Mad::Net::Exception &e) { std::cerr << "Connection error: " << e.what() << std::endl; } - - connection.disconnect(); - + Mad::Net::Connection::deinit(); - + return 0; } |