summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Net/ClientConnection.cpp76
-rw-r--r--src/Net/ClientConnection.h29
-rw-r--r--src/Net/Connection.cpp44
-rw-r--r--src/Net/Connection.h40
-rw-r--r--src/Net/ServerConnection.cpp58
-rw-r--r--src/Net/ServerConnection.h22
6 files changed, 116 insertions, 153 deletions
diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp
index 1523b44..ca12f9f 100644
--- a/src/Net/ClientConnection.cpp
+++ b/src/Net/ClientConnection.cpp
@@ -22,18 +22,11 @@
#include <cstring>
#include <cerrno>
#include <sys/socket.h>
-#include <sys/select.h>
+#include <fcntl.h>
namespace Mad {
namespace Net {
-void ClientConnection::sendConnectionHeader(bool daemon) {
- ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1};
-
- rawSend(reinterpret_cast<unsigned char*>(&header), sizeof(header));
- rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ClientConnection::connectionHeaderReceiveHandler));
-}
-
void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) {
if(length != sizeof(ConnectionHeader))
// Error... disconnect
@@ -49,14 +42,20 @@ void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned
// Unsupported protocol... disconnect
return;
- connecting = false;
enterReceiveLoop();
}
-void ClientConnection::connect(const IPAddress &address, bool daemon) throw(ConnectionException) {
- const int kx_list[] = {GNUTLS_KX_ANON_DH, 0};
+void ClientConnection::connectionHeader() {
+ ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1};
- if(connected)
+ rawSend(reinterpret_cast<unsigned char*>(&header), sizeof(header));
+ rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ClientConnection::connectionHeaderReceiveHandler));
+}
+
+void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(ConnectionException) {
+ daemon = daemon0;
+
+ if(isConnected())
disconnect();
sock = socket(PF_INET, SOCK_STREAM, 0);
@@ -68,58 +67,33 @@ void ClientConnection::connect(const IPAddress &address, bool daemon) throw(Conn
if(::connect(sock, peer->getSockAddr(), peer->getSockAddrLength()) < 0) {
close(sock);
delete peer;
+ peer = 0;
throw ConnectionException("connect()", std::strerror(errno));
}
- connected = true;
- connecting = true;
-
- gnutls_init(&session, GNUTLS_CLIENT);
+ // Set non-blocking flag
+ int flags = fcntl(sock, F_GETFL, 0);
- gnutls_set_default_priority(session);
- gnutls_kx_set_priority(session, kx_list);
-
- gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
-
- gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(sock));
-
- int ret = gnutls_handshake(session);
- if(ret < 0) {
- disconnect();
+ if(flags < 0) {
+ close(sock);
- throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret));
+ throw ConnectionException("fcntl()", std::strerror(errno));
}
- sendConnectionHeader(daemon);
-}
-
-void ClientConnection::disconnect() {
- if(!connected)
- return;
-
- gnutls_bye(session, GNUTLS_SHUT_RDWR);
+ fcntl(sock, F_SETFL, flags | O_NONBLOCK);
- shutdown(sock, SHUT_RDWR);
- close(sock);
-
- gnutls_deinit(session);
+ gnutls_init(&session, GNUTLS_CLIENT);
- delete peer;
+ gnutls_set_default_priority(session);
- connected = false;
-}
-
-bool ClientConnection::dataPending() const {
- if(!connected)
- return false;
+ const int kx_list[] = {GNUTLS_KX_ANON_DH, 0};
+ gnutls_kx_set_priority(session, kx_list);
- fd_set fds;
- FD_ZERO(&fds);
- FD_SET(sock, &fds);
+ gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
- struct timeval timeout = {0, 0};
+ gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(sock));
- return (select(sock + 1, &fds, NULL, NULL, &timeout) == 1);
+ handshake();
}
}
diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h
index c738b1e..18b1a02 100644
--- a/src/Net/ClientConnection.h
+++ b/src/Net/ClientConnection.h
@@ -30,44 +30,25 @@ class IPAddress;
class ClientConnection : public Connection {
private:
- bool connected;
- IPAddress *peer;
-
- bool connecting;
-
- int sock;
- gnutls_session_t session;
gnutls_anon_client_credentials_t anoncred;
- void sendConnectionHeader(bool daemon);
+ bool daemon;
+
void connectionHeaderReceiveHandler(const void *data, unsigned long length);
protected:
- virtual gnutls_session_t& getSession() {
- return session;
- }
+ virtual void connectionHeader();
public:
- ClientConnection() : connected(false), connecting(false) {
+ ClientConnection() : daemon(0) {
gnutls_anon_allocate_client_credentials(&anoncred);
}
virtual ~ClientConnection() {
- if(connected)
- disconnect();
-
gnutls_anon_free_client_credentials(anoncred);
}
- void connect(const IPAddress &address, bool daemon = false) throw(ConnectionException);
- void disconnect();
-
- virtual bool dataPending() const;
-
- virtual bool isConnected() const {return connected;}
- virtual const IPAddress* getPeer() const {return peer;}
-
- virtual bool isConnecting() const {return connecting;}
+ void connect(const IPAddress &address, bool daemon0 = false) throw(ConnectionException);
};
}
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp
index eb7a55f..e7be313 100644
--- a/src/Net/Connection.cpp
+++ b/src/Net/Connection.cpp
@@ -18,11 +18,29 @@
*/
#include "Connection.h"
+#include "IPAddress.h"
#include <cstring>
+#include <sys/socket.h>
namespace Mad {
namespace Net {
+void Connection::handshake() {
+ state = HANDSHAKE;
+
+ int ret = gnutls_handshake(session);
+ if(ret < 0) {
+ if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
+ return;
+
+ // Error... disconnect
+ return;
+ }
+
+ state = CONNECTION_HEADER;
+ connectionHeader();
+}
+
void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) {
if(length != sizeof(Packet::Data))
return; // Error... disconnect?
@@ -35,6 +53,7 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng
enterReceiveLoop();
}
else {
+ state = PACKET_DATA;
rawReceive(header.length, sigc::mem_fun(this, &Connection::packetDataReceiveHandler));
}
}
@@ -52,13 +71,10 @@ void Connection::doReceive() {
if(!isConnected())
return;
- if(!dataPending())
- return;
-
if(receiveComplete())
return;
- ssize_t ret = gnutls_record_recv(getSession(), transR.data+transR.transmitted, transR.length-transR.transmitted);
+ ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted);
if(ret < 0) {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
@@ -103,7 +119,7 @@ void Connection::doSend() {
return;
while(!sendQueueEmpty()) {
- ssize_t ret = gnutls_record_send(getSession(), transS.front().data+transS.front().transmitted,
+ ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted,
transS.front().length-transS.front().transmitted);
if(ret < 0) {
@@ -134,5 +150,23 @@ bool Connection::rawSend(const unsigned char *data, unsigned long length) {
return true;
}
+void Connection::disconnect() {
+ if(!isConnected())
+ return;
+
+ gnutls_bye(session, GNUTLS_SHUT_RDWR);
+
+ shutdown(sock, SHUT_RDWR);
+ close(sock);
+
+ gnutls_deinit(session);
+
+ if(peer)
+ delete peer;
+ peer = 0;
+
+ state = DISCONNECTED;
+}
+
}
}
diff --git a/src/Net/Connection.h b/src/Net/Connection.h
index 381ed17..a5bda5a 100644
--- a/src/Net/Connection.h
+++ b/src/Net/Connection.h
@@ -42,6 +42,10 @@ class Connection {
sigc::slot<void,const void*,unsigned long> notify;
};
+ enum State {
+ DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA
+ } state;
+
Transmission transR;
std::queue<Transmission> transS;
@@ -72,23 +76,33 @@ class Connection {
unsigned char protVerMax;
};
- virtual gnutls_session_t& getSession() = 0;
+ int sock;
+ gnutls_session_t session;
+
+ IPAddress *peer;
- bool rawReceive(unsigned long length,
- const sigc::slot<void,const void*,unsigned long> &notify);
+ void handshake();
+ virtual void connectionHeader() = 0;
+
+ bool rawReceive(unsigned long length, const sigc::slot<void,const void*,unsigned long> &notify);
bool rawSend(const unsigned char *data, unsigned long length);
bool enterReceiveLoop() {
+ state = PACKET_HEADER;
+
return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler));
}
public:
- Connection() {
+ Connection() : state(DISCONNECTED), peer(0) {
transR.length = transR.transmitted = 0;
transR.data = 0;
}
virtual ~Connection() {
+ if(isConnected())
+ disconnect();
+
if(transR.data)
delete [] transR.data;
@@ -96,24 +110,30 @@ class Connection {
delete [] transS.front().data;
transS.pop();
}
-
}
- virtual bool isConnected() const = 0;
- virtual const IPAddress* getPeer() const = 0;
+ bool isConnected() const {return (state != DISCONNECTED);}
+ bool isConnecting() {
+ return (state == HANDSHAKE || state == CONNECTION_HEADER);
+ }
- virtual bool isConnecting() const = 0;
+ const IPAddress* getPeer() {return peer;}
- virtual bool dataPending() const = 0;
+ void disconnect();
bool send(const Packet &packet) {
- if(isConnecting())
+ if(!isConnected() || isConnecting())
return false;
return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength());
}
void sendReceive() {
+ if(state == HANDSHAKE) {
+ handshake();
+ return;
+ }
+
doReceive();
doSend();
}
diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp
index dd25af1..4ad6215 100644
--- a/src/Net/ServerConnection.cpp
+++ b/src/Net/ServerConnection.cpp
@@ -22,7 +22,7 @@
#include <cstring>
#include <cerrno>
#include <sys/socket.h>
-#include <sys/select.h>
+#include <fcntl.h>
namespace Mad {
namespace Net {
@@ -54,14 +54,13 @@ void ServerConnection::connectionHeaderReceiveHandler(const void *data, unsigned
rawSend(reinterpret_cast<unsigned char*>(&header2), sizeof(header2));
- connecting = false;
enterReceiveLoop();
}
void ServerConnection::listen(const IPAddress &address) throw(ConnectionException) {
const int kx_list[] = {GNUTLS_KX_ANON_DH, 0};
- if(connected)
+ if(isConnected())
disconnect();
int listen_sock = socket(PF_INET, SOCK_STREAM, 0);
@@ -97,10 +96,18 @@ void ServerConnection::listen(const IPAddress &address) throw(ConnectionExceptio
close(listen_sock);
- *peer = IPAddress(address);
+ // Set non-blocking flag
+ int flags = fcntl(sock, F_GETFL, 0);
+
+ if(flags < 0) {
+ close(sock);
+
+ throw ConnectionException("fcntl()", std::strerror(errno));
+ }
+
+ fcntl(sock, F_SETFL, flags | O_NONBLOCK);
- connected = true;
- connecting = true;
+ *peer = IPAddress(address);
gnutls_init(&session, GNUTLS_SERVER);
@@ -110,45 +117,8 @@ void ServerConnection::listen(const IPAddress &address) throw(ConnectionExceptio
gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred);
gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(sock));
-
- int ret = gnutls_handshake(session);
- if(ret < 0) {
- disconnect();
-
- throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret));
- }
-
- rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ServerConnection::connectionHeaderReceiveHandler));
-}
-
-void ServerConnection::disconnect() {
- if(!connected)
- return;
-
- gnutls_bye(session, GNUTLS_SHUT_RDWR);
-
- shutdown(sock, SHUT_RDWR);
- close(sock);
-
- gnutls_deinit(session);
-
- delete peer;
-
- connected = false;
- connecting = false;
-}
-
-bool ServerConnection::dataPending() const {
- if(!connected)
- return false;
-
- fd_set fds;
- FD_ZERO(&fds);
- FD_SET(sock, &fds);
-
- struct timeval timeout = {0, 0};
- return (select(sock + 1, &fds, NULL, NULL, &timeout) == 1);
+ handshake();
}
}
diff --git a/src/Net/ServerConnection.h b/src/Net/ServerConnection.h
index cd983e0..0fd5f0b 100644
--- a/src/Net/ServerConnection.h
+++ b/src/Net/ServerConnection.h
@@ -29,27 +29,22 @@ namespace Net {
class ServerConnection : public Connection {
private:
- bool connected;
IPAddress *peer;
- bool connecting;
-
bool daemon;
- int sock;
- gnutls_session_t session;
gnutls_anon_server_credentials_t anoncred;
gnutls_dh_params_t dh_params;
void connectionHeaderReceiveHandler(const void *data, unsigned long length);
protected:
- virtual gnutls_session_t& getSession() {
- return session;
+ virtual void connectionHeader() {
+ rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ServerConnection::connectionHeaderReceiveHandler));
}
public:
- ServerConnection() : connected(false), connecting(false), daemon(false) {
+ ServerConnection() : daemon(false) {
gnutls_anon_allocate_server_credentials(&anoncred);
gnutls_dh_params_init(&dh_params);
@@ -58,22 +53,11 @@ class ServerConnection : public Connection {
}
virtual ~ServerConnection() {
- if(connected)
- disconnect();
-
gnutls_dh_params_deinit(dh_params);
gnutls_anon_free_server_credentials(anoncred);
}
void listen(const IPAddress &address) throw(ConnectionException);
- void disconnect();
-
- virtual bool dataPending() const;
-
- virtual bool isConnected() const {return connected;}
- virtual const IPAddress* getPeer() const {return peer;}
-
- virtual bool isConnecting() const {return connecting;}
bool isDaemonConnection() const {return daemon;}
};