summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Core/ConnectionManager.cpp49
-rw-r--r--src/Core/ConnectionManager.h3
-rw-r--r--src/Net/Connection.cpp18
-rw-r--r--src/Net/Connection.h34
-rw-r--r--src/Net/Packet.h11
-rw-r--r--src/mad-core.cpp7
-rw-r--r--src/madc.cpp12
7 files changed, 84 insertions, 50 deletions
diff --git a/src/Core/ConnectionManager.cpp b/src/Core/ConnectionManager.cpp
index b338f62..5ff5e29 100644
--- a/src/Core/ConnectionManager.cpp
+++ b/src/Core/ConnectionManager.cpp
@@ -29,8 +29,6 @@ namespace Mad {
namespace Core {
void ConnectionManager::refreshPollfds() {
- // TODO: refreshPollfds()
-
pollfds.clear();
pollfdMap.clear();
@@ -54,31 +52,26 @@ void ConnectionManager::refreshPollfds() {
}
}
-void ConnectionManager::daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) {
- std::cout << "Received daemon packet:" << std::endl;
- std::cout << " Type: " << packet.getType() << std::endl;
- std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl;
- std::cout << " Length: " << packet.getLength() << std::endl;
-
- for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end(); ++con) {
- if(*con == connection) {
- (*con)->disconnect();
+void ConnectionManager::receiveHandler(Net::Connection *connection, const Net::Packet &packet) {
+ switch(packet.getType()) {
+ case Net::Packet::TYPE_UNKNOWN:
break;
- }
- }
-}
-
-void ConnectionManager::clientReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) {
- std::cout << "Received client packet:" << std::endl;
- std::cout << " Type: " << packet.getType() << std::endl;
- std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl;
- std::cout << " Length: " << packet.getLength() << std::endl;
-
- for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end(); ++con) {
- if(*con == connection) {
- (*con)->disconnect();
+ case Net::Packet::TYPE_DEBUG:
+ std::cout << "Received debug packet." << std::endl;
+ std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl;
break;
- }
+ case Net::Packet::TYPE_PING:
+ connection->send(Net::Packet(Net::Packet::TYPE_PONG, packet.getRequestId(), packet.getData(), packet.getLength()));
+ break;
+ case Net::Packet::TYPE_PONG:
+ // TODO: Pong!
+ break;
+ case Net::Packet::TYPE_DISCONNECT_REQ:
+ connection->send(Net::Packet(Net::Packet::TYPE_DISCONNECT_REP, packet.getRequestId()));
+ connection->disconnect();
+ break;
+ case Net::Packet::TYPE_DISCONNECT_REP:
+ connection->disconnect();
}
}
@@ -117,6 +110,7 @@ void ConnectionManager::run() {
++con;
}
else {
+ std::cout << "A daemon connection was dropped." << std::endl;
delete *con;
daemonConnections.erase(con++);
}
@@ -134,6 +128,7 @@ void ConnectionManager::run() {
++con;
}
else {
+ std::cout << "A client connection was dropped." << std::endl;
delete *con;
clientConnections.erase(con++);
}
@@ -145,11 +140,11 @@ void ConnectionManager::run() {
while((con = (*listener)->getConnection(pollfdMap)) != 0) {
if(con->isDaemonConnection()) {
daemonConnections.push_back(con);
- con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::daemonReceiveHandler));
+ con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::receiveHandler));
}
else {
clientConnections.push_back(con);
- con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::clientReceiveHandler));
+ con->signalReceive().connect(sigc::mem_fun(this, &ConnectionManager::receiveHandler));
}
}
}
diff --git a/src/Core/ConnectionManager.h b/src/Core/ConnectionManager.h
index cbfc812..608ec27 100644
--- a/src/Core/ConnectionManager.h
+++ b/src/Core/ConnectionManager.h
@@ -52,8 +52,7 @@ class ConnectionManager {
void refreshPollfds();
- void daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet);
- void clientReceiveHandler(const Net::Connection *connection, const Net::Packet &packet);
+ void receiveHandler(Net::Connection *connection, const Net::Packet &packet);
public:
ConnectionManager();
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp
index 19e7bf1..d069fc9 100644
--- a/src/Net/Connection.cpp
+++ b/src/Net/Connection.cpp
@@ -22,6 +22,8 @@
#include <cstring>
#include <sys/socket.h>
+#include <iostream>
+
namespace Mad {
namespace Net {
@@ -49,14 +51,14 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng
if(length != sizeof(Packet::Data)) {
// TODO: Error
- disconnect();
+ doDisconnect();
return;
}
header = *reinterpret_cast<const Packet::Data*>(data);
if(header.length == 0) {
- signal(this, Packet(header.type, header.requestId));
+ signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId));
enterReceiveLoop();
}
@@ -72,11 +74,11 @@ void Connection::packetDataReceiveHandler(const void *data, unsigned long length
if(length != header.length) {
// TODO: Error
- disconnect();
+ doDisconnect();
return;
}
- signal(this, Packet(header.type, header.requestId, data, length));
+ signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId, data, length));
enterReceiveLoop();
}
@@ -95,7 +97,7 @@ void Connection::doReceive() {
return;
// TODO: Error
- disconnect();
+ doDisconnect();
return;
}
@@ -142,7 +144,7 @@ void Connection::doSend() {
return;
// TODO: Error
- disconnect();
+ doDisconnect();
return;
}
@@ -166,11 +168,11 @@ bool Connection::rawSend(const unsigned char *data, unsigned long length) {
return true;
}
-void Connection::disconnect() {
+void Connection::doDisconnect() {
if(!isConnected())
return;
- gnutls_bye(session, GNUTLS_SHUT_RDWR);
+ gnutls_bye(session, GNUTLS_SHUT_WR);
shutdown(sock, SHUT_RDWR);
close(sock);
diff --git a/src/Net/Connection.h b/src/Net/Connection.h
index adb677a..a3670b0 100644
--- a/src/Net/Connection.h
+++ b/src/Net/Connection.h
@@ -44,7 +44,7 @@ class Connection {
};
enum State {
- DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA
+ DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE
} state;
Transmission transR;
@@ -52,7 +52,7 @@ class Connection {
Packet::Data header;
- sigc::signal<void,const Connection*,const Packet&> signal;
+ sigc::signal<void,Connection*,const Packet&> signal;
void doHandshake();
@@ -62,6 +62,8 @@ class Connection {
void doReceive();
void doSend();
+ void doDisconnect();
+
bool receiveComplete() const {
return (transR.length == transR.transmitted);
}
@@ -103,7 +105,7 @@ class Connection {
bool rawSend(const unsigned char *data, unsigned long length);
bool enterReceiveLoop() {
- if(!isConnected())
+ if(!isConnected() || isDisconnecting())
return false;
state = PACKET_HEADER;
@@ -119,7 +121,7 @@ class Connection {
virtual ~Connection() {
if(isConnected())
- disconnect();
+ doDisconnect();
if(transR.data)
delete [] transR.data;
@@ -131,19 +133,30 @@ class Connection {
}
bool isConnected() const {return (state != DISCONNECTED);}
- bool isConnecting() {
+ bool isConnecting() const {
return (state == HANDSHAKE || state == CONNECTION_HEADER);
}
+ bool isDisconnecting() const {
+ return (state == DISCONNECT || state == BYE);
+ }
+
const IPAddress* getPeer() {return peer;}
int getSocket() const {return sock;}
- void disconnect();
+ void disconnect() {
+ if(isConnected() && !isDisconnecting()) {
+ state = DISCONNECT;
+
+ if(sendQueueEmpty())
+ doDisconnect();
+ }
+ }
struct pollfd getPollfd() const;
bool send(const Packet &packet) {
- if(!isConnected() || isConnecting())
+ if(!isConnected() || isConnecting() || isDisconnecting())
return false;
return rawSend(reinterpret_cast<const unsigned char*>(packet.getRawData()), packet.getRawDataLength());
@@ -151,7 +164,7 @@ class Connection {
void sendReceive(short events = POLLIN|POLLOUT) {
if(events & POLLHUP || events & POLLERR) {
- disconnect();
+ doDisconnect();
return;
}
@@ -165,11 +178,14 @@ class Connection {
if(events & POLLOUT)
doSend();
+
+ if(state == DISCONNECT && sendQueueEmpty())
+ doDisconnect();
}
bool sendQueueEmpty() const {return transS.empty();}
- sigc::signal<void,const Connection*,const Packet&> signalReceive() const {return signal;}
+ sigc::signal<void,Connection*,const Packet&> signalReceive() const {return signal;}
static void init() {
gnutls_global_init();
diff --git a/src/Net/Packet.h b/src/Net/Packet.h
index 32aba18..c2bd4d1 100644
--- a/src/Net/Packet.h
+++ b/src/Net/Packet.h
@@ -28,6 +28,11 @@ namespace Net {
class Packet {
public:
+ enum Type {
+ TYPE_UNKNOWN = 0x0000, TYPE_DEBUG = 0x0001, TYPE_PING = 0x0002, TYPE_PONG = 0x0003,
+ TYPE_DISCONNECT_REQ = 0x0010, TYPE_DISCONNECT_REP = 0x0011
+ };
+
struct Data {
unsigned short type;
unsigned short requestId;
@@ -39,7 +44,7 @@ class Packet {
Data *rawData;
public:
- Packet(unsigned short type, unsigned short requestId, const void *data = NULL, unsigned long length = 0) {
+ Packet(Type type, unsigned short requestId, const void *data = NULL, unsigned long length = 0) {
rawData = (Data*)std::malloc(sizeof(Data)+length);
rawData->type = type;
@@ -71,8 +76,8 @@ class Packet {
std::free(rawData);
}
- unsigned short getType() const {
- return rawData->type;
+ Type getType() const {
+ return (Type)rawData->type;
}
unsigned short getRequestId() const {
diff --git a/src/mad-core.cpp b/src/mad-core.cpp
index c46b932..36e889b 100644
--- a/src/mad-core.cpp
+++ b/src/mad-core.cpp
@@ -19,8 +19,15 @@
#include "Net/Connection.h"
#include "Core/ConnectionManager.h"
+#include <signal.h>
int main() {
+ sigset_t signals;
+
+ sigemptyset(&signals);
+ sigaddset(&signals, SIGPIPE);
+ sigprocmask(SIG_BLOCK, &signals, 0);
+
Mad::Net::Connection::init();
Mad::Core::ConnectionManager connectionManager;
diff --git a/src/madc.cpp b/src/madc.cpp
index aaf7197..08fb3b3 100644
--- a/src/madc.cpp
+++ b/src/madc.cpp
@@ -37,7 +37,8 @@ int main() {
connection.sendReceive(fd.revents);
}
- connection.send(Mad::Net::Packet(0x0001, 0xABCD));
+ connection.send(Mad::Net::Packet(Mad::Net::Packet::TYPE_DEBUG, 0x1234));
+ connection.send(Mad::Net::Packet(Mad::Net::Packet::TYPE_DISCONNECT_REQ, 0xABCD));
while(!connection.sendQueueEmpty()) {
struct pollfd fd = connection.getPollfd();
@@ -45,6 +46,15 @@ int main() {
if(poll(&fd, 1, 10000) > 0)
connection.sendReceive(fd.revents);
}
+
+ connection.disconnect();
+
+ while(connection.isConnected()) {
+ struct pollfd fd = connection.getPollfd();
+
+ if(poll(&fd, 1, 10000) > 0)
+ connection.sendReceive(fd.revents);
+ }
}
catch(Mad::Net::Exception &e) {
std::cerr << "Connection error: " << e.what() << std::endl;