diff options
Diffstat (limited to 'src/Net')
-rw-r--r-- | src/Net/ClientConnection.cpp | 26 | ||||
-rw-r--r-- | src/Net/Connection.cpp | 69 | ||||
-rw-r--r-- | src/Net/Connection.h | 11 | ||||
-rw-r--r-- | src/Net/Packet.h | 2 |
4 files changed, 86 insertions, 22 deletions
diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index 66df2ea..ea2e10e 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -32,8 +32,6 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti if(connected) disconnect(); - peer = new IPAddress(address); - sock = socket(PF_INET, SOCK_STREAM, 0); if(sock < 0) throw ConnectionException("socket()", std::strerror(errno)); @@ -41,6 +39,10 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti if(::connect(sock, peer->getSockAddr(), peer->getSockAddrLength()) < 0) throw ConnectionException("connect()", std::strerror(errno)); + peer = new IPAddress(address); + + connected = true; + gnutls_anon_allocate_client_credentials(&anoncred); gnutls_init(&session, GNUTLS_CLIENT); @@ -52,10 +54,11 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(sock)); int ret = gnutls_handshake(session); - if(ret < 0) + if(ret < 0) { + disconnect(); + throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret)); - - connected = true; + } } void ClientConnection::disconnect() { @@ -64,19 +67,14 @@ void ClientConnection::disconnect() { gnutls_bye(session, GNUTLS_SHUT_RDWR); - if(sock >= 0) { - shutdown(sock, SHUT_RDWR); - close(sock); - sock = -1; - } + shutdown(sock, SHUT_RDWR); + close(sock); + sock = -1; gnutls_deinit(session); gnutls_anon_free_client_credentials(anoncred); - if(peer) { - delete peer; - peer = 0; - } + delete peer; connected = false; } diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 34b6ed6..bc38ed2 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -36,17 +36,74 @@ bool Connection::send(const Packet &packet) { if(ret < 0) { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) continue; - else - return false; - } - else { - data += ret; - dataLength -= ret; + + return false; } + + data += ret; + dataLength -= ret; } return true; } +bool Connection::recieve() { + unsigned char *headerData = reinterpret_cast<unsigned char*>(&header); + ssize_t ret; + + if(!isConnected()) + return false; + + while(true) { + if(!gnutls_record_check_pending(getSession())) + return false; + + if(read < sizeof(Packet::Data)) { + ret = gnutls_record_recv(getSession(), headerData+read, sizeof(Packet::Data)-read); + + if(ret < 0) { + if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + continue; + + return false; + } + + read += ret; + + if(read < sizeof(Packet::Data)) + continue; + + if(!header.length) { + Packet packet(header.type, header.requestId); + + return true; + } + + data = new unsigned char[header.length]; + } + + ret = gnutls_record_recv(getSession(), data+read-sizeof(Packet::Data), header.length+sizeof(Packet::Data)-read); + + if(ret < 0) { + if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + continue; + + return false; + } + + read += ret; + + if(read < header.length+sizeof(Packet::Data)) + continue; + + Packet packet(header.type, header.requestId, data, header.length); + + delete [] data; + data = 0; + + return true; + } +} + } } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index 6e8d8de..1879f9e 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -21,6 +21,7 @@ #define MAD_NET_CONNECTION_H_ #include <gnutls/gnutls.h> +#include "Packet.h" namespace Mad { namespace Net { @@ -29,16 +30,24 @@ class IPAddress; class Packet; class Connection { + private: + Packet::Data header; + unsigned char *data; + + unsigned long read; + protected: virtual gnutls_session_t& getSession() = 0; public: - virtual ~Connection() {} + Connection() : data(0), read(0) {} + virtual ~Connection() {if(data) delete [] data;} virtual bool isConnected() const = 0; virtual const IPAddress* getPeer() const = 0; bool send(const Packet &packet); + bool recieve(); static void init() { gnutls_global_init(); diff --git a/src/Net/Packet.h b/src/Net/Packet.h index 28dcb86..32aba18 100644 --- a/src/Net/Packet.h +++ b/src/Net/Packet.h @@ -39,7 +39,7 @@ class Packet { Data *rawData; public: - Packet(unsigned short type, unsigned short requestId, void *data = NULL, unsigned long length = 0) { + Packet(unsigned short type, unsigned short requestId, const void *data = NULL, unsigned long length = 0) { rawData = (Data*)std::malloc(sizeof(Data)+length); rawData->type = type; |