summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Net/ClientConnection.cpp31
-rw-r--r--src/Net/ClientConnection.h11
-rw-r--r--src/Net/Connection.cpp9
-rw-r--r--src/Net/Connection.h17
-rw-r--r--src/Net/ServerConnection.cpp35
-rw-r--r--src/Net/ServerConnection.h12
-rw-r--r--src/madc.cpp4
7 files changed, 110 insertions, 9 deletions
diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp
index d49054d..1523b44 100644
--- a/src/Net/ClientConnection.cpp
+++ b/src/Net/ClientConnection.cpp
@@ -27,7 +27,33 @@
namespace Mad {
namespace Net {
-void ClientConnection::connect(const IPAddress &address) throw(ConnectionException) {
+void ClientConnection::sendConnectionHeader(bool daemon) {
+ ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1};
+
+ rawSend(reinterpret_cast<unsigned char*>(&header), sizeof(header));
+ rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ClientConnection::connectionHeaderReceiveHandler));
+}
+
+void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) {
+ if(length != sizeof(ConnectionHeader))
+ // Error... disconnect
+ return;
+
+ const ConnectionHeader *header = reinterpret_cast<const ConnectionHeader*>(data);
+
+ if(header->m != 'M' || header->a != 'A' || header->d != 'D')
+ // Error... disconnect
+ return;
+
+ if(header->protVerMin != 1)
+ // Unsupported protocol... disconnect
+ return;
+
+ connecting = false;
+ enterReceiveLoop();
+}
+
+void ClientConnection::connect(const IPAddress &address, bool daemon) throw(ConnectionException) {
const int kx_list[] = {GNUTLS_KX_ANON_DH, 0};
if(connected)
@@ -46,6 +72,7 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti
}
connected = true;
+ connecting = true;
gnutls_init(&session, GNUTLS_CLIENT);
@@ -63,7 +90,7 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti
throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret));
}
- enterReceiveLoop();
+ sendConnectionHeader(daemon);
}
void ClientConnection::disconnect() {
diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h
index e3797a5..c738b1e 100644
--- a/src/Net/ClientConnection.h
+++ b/src/Net/ClientConnection.h
@@ -33,17 +33,22 @@ class ClientConnection : public Connection {
bool connected;
IPAddress *peer;
+ bool connecting;
+
int sock;
gnutls_session_t session;
gnutls_anon_client_credentials_t anoncred;
+ void sendConnectionHeader(bool daemon);
+ void connectionHeaderReceiveHandler(const void *data, unsigned long length);
+
protected:
virtual gnutls_session_t& getSession() {
return session;
}
public:
- ClientConnection() : connected(false) {
+ ClientConnection() : connected(false), connecting(false) {
gnutls_anon_allocate_client_credentials(&anoncred);
}
@@ -54,13 +59,15 @@ class ClientConnection : public Connection {
gnutls_anon_free_client_credentials(anoncred);
}
- void connect(const IPAddress &address) throw(ConnectionException);
+ void connect(const IPAddress &address, bool daemon = false) throw(ConnectionException);
void disconnect();
virtual bool dataPending() const;
virtual bool isConnected() const {return connected;}
virtual const IPAddress* getPeer() const {return peer;}
+
+ virtual bool isConnecting() const {return connecting;}
};
}
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp
index b0f505b..eb7a55f 100644
--- a/src/Net/Connection.cpp
+++ b/src/Net/Connection.cpp
@@ -71,10 +71,13 @@ void Connection::doReceive() {
transR.transmitted += ret;
if(receiveComplete()) {
- transR.notify(transR.data, transR.length);
-
- delete [] transR.data;
+ // Save data pointer, as transR.notify might start a new reception
+ unsigned char *data = transR.data;
transR.data = 0;
+
+ transR.notify(data, transR.length);
+
+ delete [] data;
}
}
diff --git a/src/Net/Connection.h b/src/Net/Connection.h
index 42e0c8e..381ed17 100644
--- a/src/Net/Connection.h
+++ b/src/Net/Connection.h
@@ -60,6 +60,18 @@ class Connection {
}
protected:
+ struct ConnectionHeader {
+ unsigned char m;
+ unsigned char a;
+ unsigned char d;
+ unsigned char type;
+
+ unsigned char versionMajor;
+ unsigned char versionMinor;
+ unsigned char protVerMin;
+ unsigned char protVerMax;
+ };
+
virtual gnutls_session_t& getSession() = 0;
bool rawReceive(unsigned long length,
@@ -90,9 +102,14 @@ class Connection {
virtual bool isConnected() const = 0;
virtual const IPAddress* getPeer() const = 0;
+ virtual bool isConnecting() const = 0;
+
virtual bool dataPending() const = 0;
bool send(const Packet &packet) {
+ if(isConnecting())
+ return false;
+
return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength());
}
diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp
index afbfcbb..dd25af1 100644
--- a/src/Net/ServerConnection.cpp
+++ b/src/Net/ServerConnection.cpp
@@ -27,6 +27,37 @@
namespace Mad {
namespace Net {
+void ServerConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) {
+ if(length != sizeof(ConnectionHeader))
+ // Error... disconnect
+ return;
+
+ const ConnectionHeader *header = reinterpret_cast<const ConnectionHeader*>(data);
+
+ if(header->m != 'M' || header->a != 'A' || header->d != 'D')
+ // Error... disconnect
+ return;
+
+ if(header->protVerMin > 1 || header->protVerMax < 1)
+ // Unsupported protocol... disconnect
+ return;
+
+ if(header->type == 'C')
+ daemon = false;
+ else if(header->type == 'D')
+ daemon = true;
+ else
+ // Error... disconnect
+ return;
+
+ ConnectionHeader header2 = {'M', 'A', 'D', 0, 0, 1, 1, 0};
+
+ rawSend(reinterpret_cast<unsigned char*>(&header2), sizeof(header2));
+
+ connecting = false;
+ enterReceiveLoop();
+}
+
void ServerConnection::listen(const IPAddress &address) throw(ConnectionException) {
const int kx_list[] = {GNUTLS_KX_ANON_DH, 0};
@@ -69,6 +100,7 @@ void ServerConnection::listen(const IPAddress &address) throw(ConnectionExceptio
*peer = IPAddress(address);
connected = true;
+ connecting = true;
gnutls_init(&session, GNUTLS_SERVER);
@@ -86,7 +118,7 @@ void ServerConnection::listen(const IPAddress &address) throw(ConnectionExceptio
throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret));
}
- enterReceiveLoop();
+ rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ServerConnection::connectionHeaderReceiveHandler));
}
void ServerConnection::disconnect() {
@@ -103,6 +135,7 @@ void ServerConnection::disconnect() {
delete peer;
connected = false;
+ connecting = false;
}
bool ServerConnection::dataPending() const {
diff --git a/src/Net/ServerConnection.h b/src/Net/ServerConnection.h
index ebddf8a..cd983e0 100644
--- a/src/Net/ServerConnection.h
+++ b/src/Net/ServerConnection.h
@@ -32,18 +32,24 @@ class ServerConnection : public Connection {
bool connected;
IPAddress *peer;
+ bool connecting;
+
+ bool daemon;
+
int sock;
gnutls_session_t session;
gnutls_anon_server_credentials_t anoncred;
gnutls_dh_params_t dh_params;
+ void connectionHeaderReceiveHandler(const void *data, unsigned long length);
+
protected:
virtual gnutls_session_t& getSession() {
return session;
}
public:
- ServerConnection() : connected(false) {
+ ServerConnection() : connected(false), connecting(false), daemon(false) {
gnutls_anon_allocate_server_credentials(&anoncred);
gnutls_dh_params_init(&dh_params);
@@ -66,6 +72,10 @@ class ServerConnection : public Connection {
virtual bool isConnected() const {return connected;}
virtual const IPAddress* getPeer() const {return peer;}
+
+ virtual bool isConnecting() const {return connecting;}
+
+ bool isDaemonConnection() const {return daemon;}
};
}
diff --git a/src/madc.cpp b/src/madc.cpp
index 0c28793..762429e 100644
--- a/src/madc.cpp
+++ b/src/madc.cpp
@@ -29,6 +29,10 @@ int main() {
try {
connection.connect(Mad::Net::IPAddress("127.0.0.1", 6666));
+
+ while(connection.isConnecting())
+ connection.sendReceive();
+
connection.send(Mad::Net::Packet(0x0001, 0xABCD));
while(!connection.sendQueueEmpty())