summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Net/ClientConnection.cpp4
-rw-r--r--src/Net/ClientConnection.h2
-rw-r--r--src/Net/Connection.cpp128
-rw-r--r--src/Net/Connection.h50
-rw-r--r--src/Net/ServerConnection.cpp4
-rw-r--r--src/Net/ServerConnection.h2
-rw-r--r--src/mad-core.cpp7
7 files changed, 125 insertions, 72 deletions
diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp
index 1897b50..d49054d 100644
--- a/src/Net/ClientConnection.cpp
+++ b/src/Net/ClientConnection.cpp
@@ -62,6 +62,8 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti
throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret));
}
+
+ enterReceiveLoop();
}
void ClientConnection::disconnect() {
@@ -80,7 +82,7 @@ void ClientConnection::disconnect() {
connected = false;
}
-bool ClientConnection::dataPending() {
+bool ClientConnection::dataPending() const {
if(!connected)
return false;
diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h
index 72a0a92..e3797a5 100644
--- a/src/Net/ClientConnection.h
+++ b/src/Net/ClientConnection.h
@@ -57,7 +57,7 @@ class ClientConnection : public Connection {
void connect(const IPAddress &address) throw(ConnectionException);
void disconnect();
- virtual bool dataPending();
+ virtual bool dataPending() const;
virtual bool isConnected() const {return connected;}
virtual const IPAddress* getPeer() const {return peer;}
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp
index 31821ac..768d827 100644
--- a/src/Net/Connection.cpp
+++ b/src/Net/Connection.cpp
@@ -18,71 +18,88 @@
*/
#include "Connection.h"
-#include "Packet.h"
namespace Mad {
namespace Net {
-bool Connection::send(const Packet &packet) {
+void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) {
+ if(length != sizeof(Packet::Data))
+ return; // Error... disconnect?
+
+ header = *reinterpret_cast<const Packet::Data*>(data);
+
+ if(header.length == 0) {
+ signal(this, Packet(header.type, header.requestId));
+
+ enterReceiveLoop();
+ }
+ else {
+ rawReceive(header.length, sigc::mem_fun(this, &Connection::packetDataReceiveHandler));
+ }
+}
+
+void Connection::packetDataReceiveHandler(const void *data, unsigned long length) {
+ if(length != header.length)
+ return; // Error... disconnect?
+
+ signal(this, Packet(header.type, header.requestId, data, length));
+
+ enterReceiveLoop();
+}
+
+void Connection::doReceive() {
if(!isConnected())
- return false;
+ return;
+
+ if(!dataPending())
+ return;
- const unsigned char *data = reinterpret_cast<const unsigned char*>(packet.getRawData());
- unsigned long dataLength = packet.getRawDataLength();
+ if(receiveComplete())
+ return;
- while(dataLength > 0) {
- ssize_t ret = gnutls_record_send(getSession(), data, dataLength);
+ ssize_t ret = gnutls_record_recv(getSession(), transR.data+transR.transmitted, transR.length-transR.transmitted);
+
+ if(ret < 0) {
+ if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
+ return;
- if(ret < 0) {
- if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
- continue;
-
- return false;
- }
+ // Error... disconnect?
+ return;
+ }
+
+ transR.transmitted += ret;
+
+ if(receiveComplete()) {
+ transR.notify(transR.data, transR.length);
- data += ret;
- dataLength -= ret;
+ delete [] transR.data;
+ transR.data = 0;
}
+}
+
+bool Connection::rawReceive(unsigned long length,
+ const sigc::slot<void,const void*,unsigned long> &notify)
+{
+ if(!isConnected())
+ return false;
+
+ if(!receiveComplete())
+ return false;
+
+ transR.data = new unsigned char[length];
+ transR.length = length;
+ transR.transmitted = 0;
+ transR.notify = notify;
return true;
}
-bool Connection::receive() {
- unsigned char *headerData = reinterpret_cast<unsigned char*>(&header);
- ssize_t ret;
-
+bool Connection::rawSend(const unsigned char *data, unsigned long length) {
if(!isConnected())
return false;
- while(true) {
- if(!dataPending())
- return false;
-
- if(read < sizeof(Packet::Data)) {
- ret = gnutls_record_recv(getSession(), headerData+read, sizeof(Packet::Data)-read);
-
- if(ret < 0) {
- if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
- continue;
-
- return false;
- }
-
- read += ret;
-
- if(read < sizeof(Packet::Data))
- continue;
-
- if(!header.length) {
- signal(this, Packet(header.type, header.requestId));
-
- return true;
- }
-
- data = new unsigned char[header.length];
- }
-
- ret = gnutls_record_recv(getSession(), data+read-sizeof(Packet::Data), header.length+sizeof(Packet::Data)-read);
+ while(length > 0) {
+ ssize_t ret = gnutls_record_send(getSession(), data, length);
if(ret < 0) {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
@@ -91,18 +108,11 @@ bool Connection::receive() {
return false;
}
- read += ret;
-
- if(read < header.length+sizeof(Packet::Data))
- continue;
-
- signal(this, Packet(header.type, header.requestId, data, header.length));
-
- delete [] data;
- data = 0;
-
- return true;
+ data += ret;
+ length -= ret;
}
+
+ return true;
}
}
diff --git a/src/Net/Connection.h b/src/Net/Connection.h
index 063ffb7..7e73955 100644
--- a/src/Net/Connection.h
+++ b/src/Net/Connection.h
@@ -20,6 +20,7 @@
#ifndef MAD_NET_CONNECTION_H_
#define MAD_NET_CONNECTION_H_
+#include <queue>
#include <gnutls/gnutls.h>
#include <sigc++/signal.h>
#include "Packet.h"
@@ -32,27 +33,60 @@ class Packet;
class Connection {
private:
- Packet::Data header;
- unsigned char *data;
+ struct Transmission {
+ unsigned long length;
+ unsigned long transmitted;
+
+ unsigned char *data;
+
+ sigc::slot<void,const void*,unsigned long> notify;
+ };
+
+ Transmission transR;
- unsigned long read;
+ Packet::Data header;
sigc::signal<void,const Connection*,const Packet&> signal;
+ void packetHeaderReceiveHandler(const void *data, unsigned long length);
+ void packetDataReceiveHandler(const void *data, unsigned long length);
+
+ void doReceive();
+
+ bool receiveComplete() const {
+ return (transR.length == transR.transmitted);
+ }
+
protected:
virtual gnutls_session_t& getSession() = 0;
+ bool rawReceive(unsigned long length,
+ const sigc::slot<void,const void*,unsigned long> &notify);
+ bool rawSend(const unsigned char *data, unsigned long length);
+
+ bool enterReceiveLoop() {
+ return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler));
+ }
+
public:
- Connection() : data(0), read(0) {}
- virtual ~Connection() {if(data) delete [] data;}
+ Connection() {
+ transR.length = transR.transmitted = 0;
+ }
+
+ virtual ~Connection() {}
virtual bool isConnected() const = 0;
virtual const IPAddress* getPeer() const = 0;
- virtual bool dataPending() = 0;
+ virtual bool dataPending() const = 0;
- bool send(const Packet &packet);
- bool receive();
+ bool send(const Packet &packet) {
+ return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength());
+ }
+
+ void sendreceive() {
+ doReceive();
+ }
sigc::signal<void,const Connection*,const Packet&> signalReceive() const {return signal;}
diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp
index 7c9f9b8..66274b4 100644
--- a/src/Net/ServerConnection.cpp
+++ b/src/Net/ServerConnection.cpp
@@ -85,6 +85,8 @@ void ServerConnection::listen(const IPAddress &address) throw(ConnectionExceptio
throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret));
}
+
+ enterReceiveLoop();
}
void ServerConnection::disconnect() {
@@ -103,7 +105,7 @@ void ServerConnection::disconnect() {
connected = false;
}
-bool ServerConnection::dataPending() {
+bool ServerConnection::dataPending() const {
if(!connected)
return false;
diff --git a/src/Net/ServerConnection.h b/src/Net/ServerConnection.h
index 09f9d6b..ebddf8a 100644
--- a/src/Net/ServerConnection.h
+++ b/src/Net/ServerConnection.h
@@ -62,7 +62,7 @@ class ServerConnection : public Connection {
void listen(const IPAddress &address) throw(ConnectionException);
void disconnect();
- virtual bool dataPending();
+ virtual bool dataPending() const;
virtual bool isConnected() const {return connected;}
virtual const IPAddress* getPeer() const {return peer;}
diff --git a/src/mad-core.cpp b/src/mad-core.cpp
index 5ac0794..04b18cc 100644
--- a/src/mad-core.cpp
+++ b/src/mad-core.cpp
@@ -21,11 +21,16 @@
#include "Net/IPAddress.h"
#include <iostream>
+
+bool running = true;
+
void receiveHandler(const Mad::Net::Connection*, const Mad::Net::Packet &packet) {
std::cout << "Received packet:" << std::endl;
std::cout << " Type: " << packet.getType() << std::endl;
std::cout << " Request ID: " << packet.getRequestId() << std::endl;
std::cout << " Length: " << packet.getLength() << std::endl;
+
+ running = false;
}
int main() {
@@ -37,7 +42,7 @@ int main() {
try {
connection.listen(Mad::Net::IPAddress("0.0.0.0", 6666));
- while(!connection.receive());
+ while(running) connection.sendreceive();
}
catch(Mad::Net::Exception &e) {
std::cerr << "Connection error: " << e.what() << std::endl;