summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/Net/Connection.cpp96
-rw-r--r--src/Net/Connection.h92
-rw-r--r--src/madc.cpp26
3 files changed, 104 insertions, 110 deletions
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp
index 8af02c6..5d221fb 100644
--- a/src/Net/Connection.cpp
+++ b/src/Net/Connection.cpp
@@ -22,25 +22,23 @@
#include <cstring>
#include <sys/socket.h>
-#include <iostream>
-
namespace Mad {
namespace Net {
void Connection::doHandshake() {
if(state != HANDSHAKE)
return;
-
+
int ret = gnutls_handshake(session);
if(ret < 0) {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
return;
-
+
// TODO: Error
doDisconnect();
return;
}
-
+
state = CONNECTION_HEADER;
connectionHeader();
}
@@ -48,37 +46,35 @@ void Connection::doHandshake() {
void Connection::doBye() {
if(state != BYE)
return;
-
- int ret = gnutls_bye(session, GNUTLS_SHUT_WR);
+
+ int ret = gnutls_bye(session, GNUTLS_SHUT_RDWR);
if(ret < 0) {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
return;
-
+
// TODO: Error
doDisconnect();
return;
}
-
- std::cout << "Bye!" << std::endl;
-
+
doDisconnect();
}
void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) {
if(state != PACKET_HEADER)
return;
-
+
if(length != sizeof(Packet::Data)) {
// TODO: Error
doDisconnect();
return;
}
-
+
header = *reinterpret_cast<const Packet::Data*>(data);
-
+
if(header.length == 0) {
signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId));
-
+
enterReceiveLoop();
}
else {
@@ -90,45 +86,45 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng
void Connection::packetDataReceiveHandler(const void *data, unsigned long length) {
if(state != PACKET_DATA)
return;
-
+
if(length != header.length) {
// TODO: Error
doDisconnect();
return;
}
-
+
signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId, data, length));
-
+
enterReceiveLoop();
}
void Connection::doReceive() {
if(!isConnected())
return;
-
+
if(receiveComplete())
return;
-
+
ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted);
-
+
if(ret < 0) {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
return;
-
+
// TODO: Error
doDisconnect();
return;
}
-
+
transR.transmitted += ret;
-
+
if(receiveComplete()) {
// Save data pointer, as transR.notify might start a new reception
unsigned char *data = transR.data;
transR.data = 0;
-
+
transR.notify(data, transR.length);
-
+
delete [] data;
}
}
@@ -138,37 +134,37 @@ bool Connection::rawReceive(unsigned long length,
{
if(!isConnected())
return false;
-
+
if(!receiveComplete())
return false;
-
+
transR.data = new unsigned char[length];
transR.length = length;
transR.transmitted = 0;
transR.notify = notify;
-
+
return true;
}
void Connection::doSend() {
if(!isConnected())
return;
-
+
while(!sendQueueEmpty()) {
- ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted,
+ ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted,
transS.front().length-transS.front().transmitted);
-
+
if(ret < 0) {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
return;
-
+
// TODO: Error
doDisconnect();
return;
}
-
+
transS.front().transmitted += ret;
-
+
if(transS.front().transmitted == transS.front().length) {
delete [] transS.front().data;
transS.pop();
@@ -179,11 +175,11 @@ void Connection::doSend() {
bool Connection::rawSend(const unsigned char *data, unsigned long length) {
if(!isConnected())
return false;
-
+
Transmission trans = {length, 0, new unsigned char[length], sigc::slot<void,const void*,unsigned long>()};
std::memcpy(trans.data, data, length);
transS.push(trans);
-
+
return true;
}
@@ -192,49 +188,49 @@ void Connection::sendReceive(short events) {
doDisconnect();
return;
}
-
+
if(state == HANDSHAKE) {
doHandshake();
return;
}
-
+
if(state == BYE) {
doBye();
return;
}
-
+
if(events & POLLIN)
doReceive();
-
+
if(events & POLLOUT)
doSend();
-
- if(state == DISCONNECT && sendQueueEmpty())
+
+ if(state == DISCONNECT && sendQueueEmpty())
bye();
}
void Connection::doDisconnect() {
if(!isConnected())
return;
-
+
shutdown(sock, SHUT_RDWR);
close(sock);
-
+
gnutls_deinit(session);
-
+
if(peer)
delete peer;
peer = 0;
-
+
state = DISCONNECTED;
}
struct pollfd Connection::getPollfd() const {
struct pollfd fd = {sock, (receiveComplete() ? 0 : POLLIN) | (sendQueueEmpty() ? 0 : POLLOUT), 0};
-
- if(state == HANDSHAKE)
+
+ if(state == HANDSHAKE || state == BYE)
fd.events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT);
-
+
return fd;
}
diff --git a/src/Net/Connection.h b/src/Net/Connection.h
index 01926e1..0949ec4 100644
--- a/src/Net/Connection.h
+++ b/src/Net/Connection.h
@@ -37,152 +37,152 @@ class Connection {
struct Transmission {
unsigned long length;
unsigned long transmitted;
-
+
unsigned char *data;
-
+
sigc::slot<void,const void*,unsigned long> notify;
};
-
+
enum State {
DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE
} state;
-
+
Transmission transR;
std::queue<Transmission> transS;
-
+
Packet::Data header;
-
+
sigc::signal<void,Connection*,const Packet&> signal;
-
+
void doHandshake();
-
+
void packetHeaderReceiveHandler(const void *data, unsigned long length);
void packetDataReceiveHandler(const void *data, unsigned long length);
-
+
void doReceive();
void doSend();
-
+
void doBye();
-
+
void doDisconnect();
-
+
bool receiveComplete() const {
return (transR.length == transR.transmitted);
}
-
+
void bye() {
if(state != DISCONNECT)
return;
-
+
state = BYE;
-
+
doBye();
}
-
+
// Prevent shallow copy
Connection(const Connection &o);
Connection& operator=(const Connection &o);
-
+
protected:
struct ConnectionHeader {
unsigned char m;
unsigned char a;
unsigned char d;
unsigned char type;
-
+
unsigned char versionMajor;
unsigned char versionMinor;
unsigned char protVerMin;
unsigned char protVerMax;
};
-
+
int sock;
gnutls_session_t session;
-
+
IPAddress *peer;
-
+
void handshake() {
if(isConnected())
return;
-
+
state = HANDSHAKE;
-
+
doHandshake();
}
-
+
virtual void connectionHeader() = 0;
-
+
bool rawReceive(unsigned long length, const sigc::slot<void,const void*,unsigned long> &notify);
bool rawSend(const unsigned char *data, unsigned long length);
-
+
bool enterReceiveLoop() {
if(!isConnected() || isDisconnecting())
return false;
-
+
state = PACKET_HEADER;
-
+
return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler));
}
-
+
public:
Connection() : state(DISCONNECTED), peer(0) {
transR.length = transR.transmitted = 0;
transR.data = 0;
}
-
+
virtual ~Connection() {
if(isConnected())
doDisconnect();
-
+
if(transR.data)
delete [] transR.data;
-
+
while(!sendQueueEmpty()) {
delete [] transS.front().data;
transS.pop();
}
}
-
+
bool isConnected() const {return (state != DISCONNECTED);}
bool isConnecting() const {
return (state == HANDSHAKE || state == CONNECTION_HEADER);
}
-
+
bool isDisconnecting() const {
return (state == DISCONNECT || state == BYE);
}
-
+
const IPAddress* getPeer() {return peer;}
int getSocket() const {return sock;}
-
+
void disconnect() {
if(isConnected() && !isDisconnecting()) {
state = DISCONNECT;
-
+
if(sendQueueEmpty())
- doDisconnect();
+ bye();
}
}
-
+
struct pollfd getPollfd() const;
-
+
bool send(const Packet &packet) {
if(!isConnected() || isConnecting() || isDisconnecting())
return false;
-
+
return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength());
}
-
+
void sendReceive(short events = POLLIN|POLLOUT);
-
+
bool sendQueueEmpty() const {return transS.empty();}
-
+
sigc::signal<void,Connection*,const Packet&> signalReceive() const {return signal;}
-
+
static void init() {
gnutls_global_init();
}
-
+
static void deinit() {
gnutls_global_deinit();
}
diff --git a/src/madc.cpp b/src/madc.cpp
index 08fb3b3..790baf5 100644
--- a/src/madc.cpp
+++ b/src/madc.cpp
@@ -24,34 +24,34 @@
int main() {
Mad::Net::Connection::init();
-
+
Mad::Net::ClientConnection connection;
-
+
try {
connection.connect(Mad::Net::IPAddress("127.0.0.1", 6666));
-
+
while(connection.isConnecting()) {
struct pollfd fd = connection.getPollfd();
-
+
if(poll(&fd, 1, 10000) > 0)
connection.sendReceive(fd.revents);
}
-
+
connection.send(Mad::Net::Packet(Mad::Net::Packet::TYPE_DEBUG, 0x1234));
connection.send(Mad::Net::Packet(Mad::Net::Packet::TYPE_DISCONNECT_REQ, 0xABCD));
-
+
while(!connection.sendQueueEmpty()) {
struct pollfd fd = connection.getPollfd();
-
+
if(poll(&fd, 1, 10000) > 0)
connection.sendReceive(fd.revents);
}
-
+
connection.disconnect();
-
+
while(connection.isConnected()) {
struct pollfd fd = connection.getPollfd();
-
+
if(poll(&fd, 1, 10000) > 0)
connection.sendReceive(fd.revents);
}
@@ -59,10 +59,8 @@ int main() {
catch(Mad::Net::Exception &e) {
std::cerr << "Connection error: " << e.what() << std::endl;
}
-
- connection.disconnect();
-
+
Mad::Net::Connection::deinit();
-
+
return 0;
}