summaryrefslogtreecommitdiffstats
path: root/src/Net
diff options
context:
space:
mode:
Diffstat (limited to 'src/Net')
-rw-r--r--src/Net/ClientConnection.cpp26
-rw-r--r--src/Net/Connection.cpp69
-rw-r--r--src/Net/Connection.h11
-rw-r--r--src/Net/Packet.h2
4 files changed, 86 insertions, 22 deletions
diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp
index 66df2ea..ea2e10e 100644
--- a/src/Net/ClientConnection.cpp
+++ b/src/Net/ClientConnection.cpp
@@ -32,8 +32,6 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti
if(connected)
disconnect();
- peer = new IPAddress(address);
-
sock = socket(PF_INET, SOCK_STREAM, 0);
if(sock < 0)
throw ConnectionException("socket()", std::strerror(errno));
@@ -41,6 +39,10 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti
if(::connect(sock, peer->getSockAddr(), peer->getSockAddrLength()) < 0)
throw ConnectionException("connect()", std::strerror(errno));
+ peer = new IPAddress(address);
+
+ connected = true;
+
gnutls_anon_allocate_client_credentials(&anoncred);
gnutls_init(&session, GNUTLS_CLIENT);
@@ -52,10 +54,11 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti
gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(sock));
int ret = gnutls_handshake(session);
- if(ret < 0)
+ if(ret < 0) {
+ disconnect();
+
throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret));
-
- connected = true;
+ }
}
void ClientConnection::disconnect() {
@@ -64,19 +67,14 @@ void ClientConnection::disconnect() {
gnutls_bye(session, GNUTLS_SHUT_RDWR);
- if(sock >= 0) {
- shutdown(sock, SHUT_RDWR);
- close(sock);
- sock = -1;
- }
+ shutdown(sock, SHUT_RDWR);
+ close(sock);
+ sock = -1;
gnutls_deinit(session);
gnutls_anon_free_client_credentials(anoncred);
- if(peer) {
- delete peer;
- peer = 0;
- }
+ delete peer;
connected = false;
}
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp
index 34b6ed6..bc38ed2 100644
--- a/src/Net/Connection.cpp
+++ b/src/Net/Connection.cpp
@@ -36,17 +36,74 @@ bool Connection::send(const Packet &packet) {
if(ret < 0) {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
continue;
- else
- return false;
- }
- else {
- data += ret;
- dataLength -= ret;
+
+ return false;
}
+
+ data += ret;
+ dataLength -= ret;
}
return true;
}
+bool Connection::recieve() {
+ unsigned char *headerData = reinterpret_cast<unsigned char*>(&header);
+ ssize_t ret;
+
+ if(!isConnected())
+ return false;
+
+ while(true) {
+ if(!gnutls_record_check_pending(getSession()))
+ return false;
+
+ if(read < sizeof(Packet::Data)) {
+ ret = gnutls_record_recv(getSession(), headerData+read, sizeof(Packet::Data)-read);
+
+ if(ret < 0) {
+ if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
+ continue;
+
+ return false;
+ }
+
+ read += ret;
+
+ if(read < sizeof(Packet::Data))
+ continue;
+
+ if(!header.length) {
+ Packet packet(header.type, header.requestId);
+
+ return true;
+ }
+
+ data = new unsigned char[header.length];
+ }
+
+ ret = gnutls_record_recv(getSession(), data+read-sizeof(Packet::Data), header.length+sizeof(Packet::Data)-read);
+
+ if(ret < 0) {
+ if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
+ continue;
+
+ return false;
+ }
+
+ read += ret;
+
+ if(read < header.length+sizeof(Packet::Data))
+ continue;
+
+ Packet packet(header.type, header.requestId, data, header.length);
+
+ delete [] data;
+ data = 0;
+
+ return true;
+ }
+}
+
}
}
diff --git a/src/Net/Connection.h b/src/Net/Connection.h
index 6e8d8de..1879f9e 100644
--- a/src/Net/Connection.h
+++ b/src/Net/Connection.h
@@ -21,6 +21,7 @@
#define MAD_NET_CONNECTION_H_
#include <gnutls/gnutls.h>
+#include "Packet.h"
namespace Mad {
namespace Net {
@@ -29,16 +30,24 @@ class IPAddress;
class Packet;
class Connection {
+ private:
+ Packet::Data header;
+ unsigned char *data;
+
+ unsigned long read;
+
protected:
virtual gnutls_session_t& getSession() = 0;
public:
- virtual ~Connection() {}
+ Connection() : data(0), read(0) {}
+ virtual ~Connection() {if(data) delete [] data;}
virtual bool isConnected() const = 0;
virtual const IPAddress* getPeer() const = 0;
bool send(const Packet &packet);
+ bool recieve();
static void init() {
gnutls_global_init();
diff --git a/src/Net/Packet.h b/src/Net/Packet.h
index 28dcb86..32aba18 100644
--- a/src/Net/Packet.h
+++ b/src/Net/Packet.h
@@ -39,7 +39,7 @@ class Packet {
Data *rawData;
public:
- Packet(unsigned short type, unsigned short requestId, void *data = NULL, unsigned long length = 0) {
+ Packet(unsigned short type, unsigned short requestId, const void *data = NULL, unsigned long length = 0) {
rawData = (Data*)std::malloc(sizeof(Data)+length);
rawData->type = type;