summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Core/ConnectionManager.cpp39
-rw-r--r--src/Core/ConnectionManager.h9
-rw-r--r--src/Net/Connection.cpp37
-rw-r--r--src/Net/Connection.h17
-rw-r--r--src/mad-core.cpp3
5 files changed, 81 insertions, 24 deletions
diff --git a/src/Core/ConnectionManager.cpp b/src/Core/ConnectionManager.cpp
index 3b99875..fe9f639 100644
--- a/src/Core/ConnectionManager.cpp
+++ b/src/Core/ConnectionManager.cpp
@@ -29,21 +29,48 @@ namespace Mad {
namespace Core {
void ConnectionManager::refreshPollfds() {
+ // TODO: refreshPollfds()
+ pollfds.clear();
+
+ for(std::list<Net::Listener*>::iterator listener = listeners.begin(); listener != listeners.end(); ++listener) {
+ std::vector<struct pollfd> fds = (*listener)->getPollfds();
+ pollfds.insert(pollfds.end(), fds.begin(), fds.end());
+ }
+
+ for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end(); ++con)
+ pollfds.push_back((*con)->getPollfd());
+
+ for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end(); ++con)
+ pollfds.push_back((*con)->getPollfd());
}
-void ConnectionManager::daemonReceiveHandler(const Net::Connection*, const Net::Packet &packet) const {
+void ConnectionManager::daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) {
std::cout << "Received daemon packet:" << std::endl;
std::cout << " Type: " << packet.getType() << std::endl;
std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl;
std::cout << " Length: " << packet.getLength() << std::endl;
+
+ for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end(); ++con) {
+ if(*con == connection) {
+ (*con)->disconnect();
+ break;
+ }
+ }
}
-void ConnectionManager::clientReceiveHandler(const Net::Connection*, const Net::Packet &packet) const {
+void ConnectionManager::clientReceiveHandler(const Net::Connection *connection, const Net::Packet &packet) {
std::cout << "Received client packet:" << std::endl;
std::cout << " Type: " << packet.getType() << std::endl;
std::cout << " Request ID: 0x" << std::hex << std::uppercase << packet.getRequestId() << std::dec << std::endl;
std::cout << " Length: " << packet.getLength() << std::endl;
+
+ for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end(); ++con) {
+ if(*con == connection) {
+ (*con)->disconnect();
+ break;
+ }
+ }
}
ConnectionManager::ConnectionManager() {
@@ -64,12 +91,6 @@ ConnectionManager::~ConnectionManager() {
delete *con;
}
-void ConnectionManager::wait(int timeout) {
- // TODO: wait()
-
- usleep(timeout);
-}
-
void ConnectionManager::run() {
// TODO: Logging
@@ -109,6 +130,8 @@ void ConnectionManager::run() {
}
}
}
+
+ refreshPollfds();
}
}
diff --git a/src/Core/ConnectionManager.h b/src/Core/ConnectionManager.h
index 49c9398..b68919c 100644
--- a/src/Core/ConnectionManager.h
+++ b/src/Core/ConnectionManager.h
@@ -50,14 +50,17 @@ class ConnectionManager {
void refreshPollfds();
- void daemonReceiveHandler(const Net::Connection*, const Net::Packet &packet) const;
- void clientReceiveHandler(const Net::Connection*, const Net::Packet &packet) const;
+ void daemonReceiveHandler(const Net::Connection *connection, const Net::Packet &packet);
+ void clientReceiveHandler(const Net::Connection *connection, const Net::Packet &packet);
public:
ConnectionManager();
virtual ~ConnectionManager();
- void wait(int timeout);
+ void wait(int timeout) {
+ poll(pollfds.data(), pollfds.size(), timeout);
+ }
+
void run();
};
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp
index 4e3fee4..19e7bf1 100644
--- a/src/Net/Connection.cpp
+++ b/src/Net/Connection.cpp
@@ -25,15 +25,17 @@
namespace Mad {
namespace Net {
-void Connection::handshake() {
- state = HANDSHAKE;
+void Connection::doHandshake() {
+ if(state != HANDSHAKE)
+ return;
int ret = gnutls_handshake(session);
if(ret < 0) {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
return;
- // Error... disconnect
+ // TODO: Error
+ disconnect();
return;
}
@@ -42,8 +44,14 @@ void Connection::handshake() {
}
void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) {
- if(length != sizeof(Packet::Data))
- return; // Error... disconnect?
+ if(state != PACKET_HEADER)
+ return;
+
+ if(length != sizeof(Packet::Data)) {
+ // TODO: Error
+ disconnect();
+ return;
+ }
header = *reinterpret_cast<const Packet::Data*>(data);
@@ -59,8 +67,14 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng
}
void Connection::packetDataReceiveHandler(const void *data, unsigned long length) {
- if(length != header.length)
- return; // Error... disconnect?
+ if(state != PACKET_DATA)
+ return;
+
+ if(length != header.length) {
+ // TODO: Error
+ disconnect();
+ return;
+ }
signal(this, Packet(header.type, header.requestId, data, length));
@@ -80,7 +94,8 @@ void Connection::doReceive() {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
return;
- // Error... disconnect?
+ // TODO: Error
+ disconnect();
return;
}
@@ -126,7 +141,8 @@ void Connection::doSend() {
if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
return;
- // Error... disconnect?
+ // TODO: Error
+ disconnect();
return;
}
@@ -171,6 +187,9 @@ void Connection::disconnect() {
struct pollfd Connection::getPollfd() const {
struct pollfd fd = {sock, (receiveComplete() ? 0 : POLLIN) | (sendQueueEmpty() ? 0 : POLLOUT), 0};
+ if(state == HANDSHAKE)
+ fd.events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT);
+
return fd;
}
diff --git a/src/Net/Connection.h b/src/Net/Connection.h
index 0880036..e147ad2 100644
--- a/src/Net/Connection.h
+++ b/src/Net/Connection.h
@@ -54,6 +54,8 @@ class Connection {
sigc::signal<void,const Connection*,const Packet&> signal;
+ void doHandshake();
+
void packetHeaderReceiveHandler(const void *data, unsigned long length);
void packetDataReceiveHandler(const void *data, unsigned long length);
@@ -86,13 +88,24 @@ class Connection {
IPAddress *peer;
- void handshake();
+ void handshake() {
+ if(isConnected())
+ return;
+
+ state = HANDSHAKE;
+
+ doHandshake();
+ }
+
virtual void connectionHeader() = 0;
bool rawReceive(unsigned long length, const sigc::slot<void,const void*,unsigned long> &notify);
bool rawSend(const unsigned char *data, unsigned long length);
bool enterReceiveLoop() {
+ if(!isConnected())
+ return false;
+
state = PACKET_HEADER;
return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler));
@@ -137,7 +150,7 @@ class Connection {
void sendReceive() {
if(state == HANDSHAKE) {
- handshake();
+ doHandshake();
return;
}
diff --git a/src/mad-core.cpp b/src/mad-core.cpp
index ea0eded..461f595 100644
--- a/src/mad-core.cpp
+++ b/src/mad-core.cpp
@@ -20,7 +20,6 @@
#include "Net/Connection.h"
#include "Core/ConnectionManager.h"
-
int main() {
Mad::Net::Connection::init();
@@ -28,7 +27,7 @@ int main() {
while(true) {
connectionManager.run();
- connectionManager.wait(100000);
+ connectionManager.wait(10000);
}
Mad::Net::Connection::deinit();