diff options
-rw-r--r-- | src/Net/ClientConnection.cpp | 4 | ||||
-rw-r--r-- | src/Net/ClientConnection.h | 2 | ||||
-rw-r--r-- | src/Net/Connection.cpp | 128 | ||||
-rw-r--r-- | src/Net/Connection.h | 50 | ||||
-rw-r--r-- | src/Net/ServerConnection.cpp | 4 | ||||
-rw-r--r-- | src/Net/ServerConnection.h | 2 | ||||
-rw-r--r-- | src/mad-core.cpp | 7 |
7 files changed, 125 insertions, 72 deletions
diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index 1897b50..d49054d 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -62,6 +62,8 @@ void ClientConnection::connect(const IPAddress &address) throw(ConnectionExcepti throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret)); } + + enterReceiveLoop(); } void ClientConnection::disconnect() { @@ -80,7 +82,7 @@ void ClientConnection::disconnect() { connected = false; } -bool ClientConnection::dataPending() { +bool ClientConnection::dataPending() const { if(!connected) return false; diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h index 72a0a92..e3797a5 100644 --- a/src/Net/ClientConnection.h +++ b/src/Net/ClientConnection.h @@ -57,7 +57,7 @@ class ClientConnection : public Connection { void connect(const IPAddress &address) throw(ConnectionException); void disconnect(); - virtual bool dataPending(); + virtual bool dataPending() const; virtual bool isConnected() const {return connected;} virtual const IPAddress* getPeer() const {return peer;} diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 31821ac..768d827 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -18,71 +18,88 @@ */ #include "Connection.h" -#include "Packet.h" namespace Mad { namespace Net { -bool Connection::send(const Packet &packet) { +void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) { + if(length != sizeof(Packet::Data)) + return; // Error... disconnect? + + header = *reinterpret_cast<const Packet::Data*>(data); + + if(header.length == 0) { + signal(this, Packet(header.type, header.requestId)); + + enterReceiveLoop(); + } + else { + rawReceive(header.length, sigc::mem_fun(this, &Connection::packetDataReceiveHandler)); + } +} + +void Connection::packetDataReceiveHandler(const void *data, unsigned long length) { + if(length != header.length) + return; // Error... disconnect? + + signal(this, Packet(header.type, header.requestId, data, length)); + + enterReceiveLoop(); +} + +void Connection::doReceive() { if(!isConnected()) - return false; + return; + + if(!dataPending()) + return; - const unsigned char *data = reinterpret_cast<const unsigned char*>(packet.getRawData()); - unsigned long dataLength = packet.getRawDataLength(); + if(receiveComplete()) + return; - while(dataLength > 0) { - ssize_t ret = gnutls_record_send(getSession(), data, dataLength); + ssize_t ret = gnutls_record_recv(getSession(), transR.data+transR.transmitted, transR.length-transR.transmitted); + + if(ret < 0) { + if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + return; - if(ret < 0) { - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) - continue; - - return false; - } + // Error... disconnect? + return; + } + + transR.transmitted += ret; + + if(receiveComplete()) { + transR.notify(transR.data, transR.length); - data += ret; - dataLength -= ret; + delete [] transR.data; + transR.data = 0; } +} + +bool Connection::rawReceive(unsigned long length, + const sigc::slot<void,const void*,unsigned long> ¬ify) +{ + if(!isConnected()) + return false; + + if(!receiveComplete()) + return false; + + transR.data = new unsigned char[length]; + transR.length = length; + transR.transmitted = 0; + transR.notify = notify; return true; } -bool Connection::receive() { - unsigned char *headerData = reinterpret_cast<unsigned char*>(&header); - ssize_t ret; - +bool Connection::rawSend(const unsigned char *data, unsigned long length) { if(!isConnected()) return false; - while(true) { - if(!dataPending()) - return false; - - if(read < sizeof(Packet::Data)) { - ret = gnutls_record_recv(getSession(), headerData+read, sizeof(Packet::Data)-read); - - if(ret < 0) { - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) - continue; - - return false; - } - - read += ret; - - if(read < sizeof(Packet::Data)) - continue; - - if(!header.length) { - signal(this, Packet(header.type, header.requestId)); - - return true; - } - - data = new unsigned char[header.length]; - } - - ret = gnutls_record_recv(getSession(), data+read-sizeof(Packet::Data), header.length+sizeof(Packet::Data)-read); + while(length > 0) { + ssize_t ret = gnutls_record_send(getSession(), data, length); if(ret < 0) { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) @@ -91,18 +108,11 @@ bool Connection::receive() { return false; } - read += ret; - - if(read < header.length+sizeof(Packet::Data)) - continue; - - signal(this, Packet(header.type, header.requestId, data, header.length)); - - delete [] data; - data = 0; - - return true; + data += ret; + length -= ret; } + + return true; } } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index 063ffb7..7e73955 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -20,6 +20,7 @@ #ifndef MAD_NET_CONNECTION_H_ #define MAD_NET_CONNECTION_H_ +#include <queue> #include <gnutls/gnutls.h> #include <sigc++/signal.h> #include "Packet.h" @@ -32,27 +33,60 @@ class Packet; class Connection { private: - Packet::Data header; - unsigned char *data; + struct Transmission { + unsigned long length; + unsigned long transmitted; + + unsigned char *data; + + sigc::slot<void,const void*,unsigned long> notify; + }; + + Transmission transR; - unsigned long read; + Packet::Data header; sigc::signal<void,const Connection*,const Packet&> signal; + void packetHeaderReceiveHandler(const void *data, unsigned long length); + void packetDataReceiveHandler(const void *data, unsigned long length); + + void doReceive(); + + bool receiveComplete() const { + return (transR.length == transR.transmitted); + } + protected: virtual gnutls_session_t& getSession() = 0; + bool rawReceive(unsigned long length, + const sigc::slot<void,const void*,unsigned long> ¬ify); + bool rawSend(const unsigned char *data, unsigned long length); + + bool enterReceiveLoop() { + return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler)); + } + public: - Connection() : data(0), read(0) {} - virtual ~Connection() {if(data) delete [] data;} + Connection() { + transR.length = transR.transmitted = 0; + } + + virtual ~Connection() {} virtual bool isConnected() const = 0; virtual const IPAddress* getPeer() const = 0; - virtual bool dataPending() = 0; + virtual bool dataPending() const = 0; - bool send(const Packet &packet); - bool receive(); + bool send(const Packet &packet) { + return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength()); + } + + void sendreceive() { + doReceive(); + } sigc::signal<void,const Connection*,const Packet&> signalReceive() const {return signal;} diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp index 7c9f9b8..66274b4 100644 --- a/src/Net/ServerConnection.cpp +++ b/src/Net/ServerConnection.cpp @@ -85,6 +85,8 @@ void ServerConnection::listen(const IPAddress &address) throw(ConnectionExceptio throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret)); } + + enterReceiveLoop(); } void ServerConnection::disconnect() { @@ -103,7 +105,7 @@ void ServerConnection::disconnect() { connected = false; } -bool ServerConnection::dataPending() { +bool ServerConnection::dataPending() const { if(!connected) return false; diff --git a/src/Net/ServerConnection.h b/src/Net/ServerConnection.h index 09f9d6b..ebddf8a 100644 --- a/src/Net/ServerConnection.h +++ b/src/Net/ServerConnection.h @@ -62,7 +62,7 @@ class ServerConnection : public Connection { void listen(const IPAddress &address) throw(ConnectionException); void disconnect(); - virtual bool dataPending(); + virtual bool dataPending() const; virtual bool isConnected() const {return connected;} virtual const IPAddress* getPeer() const {return peer;} diff --git a/src/mad-core.cpp b/src/mad-core.cpp index 5ac0794..04b18cc 100644 --- a/src/mad-core.cpp +++ b/src/mad-core.cpp @@ -21,11 +21,16 @@ #include "Net/IPAddress.h" #include <iostream> + +bool running = true; + void receiveHandler(const Mad::Net::Connection*, const Mad::Net::Packet &packet) { std::cout << "Received packet:" << std::endl; std::cout << " Type: " << packet.getType() << std::endl; std::cout << " Request ID: " << packet.getRequestId() << std::endl; std::cout << " Length: " << packet.getLength() << std::endl; + + running = false; } int main() { @@ -37,7 +42,7 @@ int main() { try { connection.listen(Mad::Net::IPAddress("0.0.0.0", 6666)); - while(!connection.receive()); + while(running) connection.sendreceive(); } catch(Mad::Net::Exception &e) { std::cerr << "Connection error: " << e.what() << std::endl; |