summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Common/ActionManager.cpp87
-rw-r--r--src/Common/ActionManager.h61
-rw-r--r--src/Common/CMakeLists.txt2
-rw-r--r--src/Common/ClientConnection.cpp8
-rw-r--r--src/Common/ClientConnection.h9
-rw-r--r--src/Common/Connection.h4
-rw-r--r--src/Common/Requests/CMakeLists.txt2
-rw-r--r--src/Common/Requests/IdentifyRequest.cpp (renamed from src/Daemon/Requests/IdentifyRequest.cpp)8
-rw-r--r--src/Common/Requests/IdentifyRequest.h (renamed from src/Daemon/Requests/IdentifyRequest.h)7
-rw-r--r--src/Daemon/Requests/CMakeLists.txt2
-rw-r--r--src/Net/CMakeLists.txt6
-rw-r--r--src/Net/ClientConnection.cpp87
-rw-r--r--src/Net/ClientConnection.h23
-rw-r--r--src/Net/Connection.cpp364
-rw-r--r--src/Net/Connection.h144
-rw-r--r--src/Net/Exception.cpp3
-rw-r--r--src/Net/Exception.h2
-rw-r--r--src/Net/FdManager.cpp174
-rw-r--r--src/Net/FdManager.h77
-rw-r--r--src/Net/IPAddress.cpp92
-rw-r--r--src/Net/IPAddress.h58
-rw-r--r--src/Net/Listener.cpp95
-rw-r--r--src/Net/Listener.h41
-rw-r--r--src/Net/ServerConnection.cpp90
-rw-r--r--src/Net/ServerConnection.h57
-rw-r--r--src/Net/ThreadManager.cpp14
-rw-r--r--src/Net/ThreadManager.h6
-rw-r--r--src/Server/ConnectionManager.cpp88
-rw-r--r--src/Server/ConnectionManager.h54
-rw-r--r--src/Server/RequestHandlers/CMakeLists.txt2
-rw-r--r--src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp4
-rw-r--r--src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp4
-rw-r--r--src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp4
-rw-r--r--src/Server/RequestHandlers/IdentifyRequestHandler.cpp5
-rw-r--r--src/mad-server.conf2
-rw-r--r--src/mad-server.cpp8
-rw-r--r--src/mad.cpp9
-rw-r--r--src/madc.cpp75
38 files changed, 403 insertions, 1375 deletions
diff --git a/src/Common/ActionManager.cpp b/src/Common/ActionManager.cpp
deleted file mode 100644
index fc3c034..0000000
--- a/src/Common/ActionManager.cpp
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * ActionManager.cpp
- *
- * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de>
- *
- * This program is free software: you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but
- * WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include "ActionManager.h"
-#include <Net/FdManager.h>
-
-#include <fcntl.h>
-#include <signal.h>
-
-namespace Mad {
-namespace Common {
-
-ActionManager ActionManager::actionManager;
-
-
-void ActionManager::doInit() {
- // TODO Error handling
-
- pipe(notifyPipe);
-
- fcntl(notifyPipe[0], F_SETFL, fcntl(notifyPipe[0], F_GETFL) | O_NONBLOCK);
- fcntl(notifyPipe[1], F_SETFL, fcntl(notifyPipe[1], F_GETFL) | O_NONBLOCK);
-
- Net::FdManager::get()->registerFd(notifyPipe[0], boost::bind(&ActionManager::run, this), POLLIN);
-}
-
-void ActionManager::doDeinit() {
- Net::FdManager::get()->unregisterFd(notifyPipe[0]);
-
- close(notifyPipe[0]);
- close(notifyPipe[1]);
-}
-
-
-void ActionManager::run() {
- // Empty pipe
- char buf[16];
- while(read(notifyPipe[0], buf, sizeof(buf)) > 0) {}
-
- while(true) {
- sigset_t set, oldset;
- sigfillset(&set);
- sigprocmask(SIG_SETMASK, &set, &oldset);
-
- if(actions.empty()) {
- sigprocmask(SIG_SETMASK, &oldset, 0);
- return;
- }
-
- boost::function0<void> action = actions.front();
- actions.pop();
-
- sigprocmask(SIG_SETMASK, &oldset, 0);
-
- action();
- }
-}
-
-void ActionManager::add(const boost::function0<void> &action) {
- sigset_t set, oldset;
- sigfillset(&set);
- sigprocmask(SIG_SETMASK, &set, &oldset);
-
- actions.push(action);
- write(notifyPipe[1], "", 1);
-
- sigprocmask(SIG_SETMASK, &oldset, 0);
-}
-
-}
-}
diff --git a/src/Common/ActionManager.h b/src/Common/ActionManager.h
deleted file mode 100644
index 5d3dc15..0000000
--- a/src/Common/ActionManager.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * ActionManager.h
- *
- * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de>
- *
- * This program is free software: you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but
- * WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef MAD_COMMON_ACTIONMANAGER_H_
-#define MAD_COMMON_ACTIONMANAGER_H_
-
-#include "Initializable.h"
-
-#include <queue>
-#include <unistd.h>
-
-#include <boost/function.hpp>
-
-namespace Mad {
-namespace Common {
-
-class ActionManager : public Initializable {
- private:
- std::queue<boost::function0<void> > actions;
- int notifyPipe[2];
-
- static ActionManager actionManager;
-
- ActionManager() {}
-
- protected:
- void doInit();
- void doDeinit();
-
- public:
- void run();
- void add(const boost::function0<void> &action);
-
- static ActionManager *get() {
- if(!actionManager.isInitialized())
- actionManager.init();
-
- return &actionManager;
- }
-};
-
-}
-}
-
-#endif /* MAD_COMMON_ACTIONMANAGER_H_ */
diff --git a/src/Common/CMakeLists.txt b/src/Common/CMakeLists.txt
index e2f7b4b..fe18760 100644
--- a/src/Common/CMakeLists.txt
+++ b/src/Common/CMakeLists.txt
@@ -6,7 +6,7 @@ include_directories(${INCLUDES})
link_directories(${LTDL_LIBRARY_DIR})
add_library(Common
- ActionManager.cpp Base64Encoder.cpp ClientConnection.cpp ConfigEntry.cpp
+ Base64Encoder.cpp ClientConnection.cpp ConfigEntry.cpp
ConfigManager.cpp Connection.cpp Initializable.cpp Logger.cpp LogManager.cpp
ModuleManager.cpp Request.cpp RequestManager.cpp SystemManager.cpp Tokenizer.cpp
XmlPacket.cpp
diff --git a/src/Common/ClientConnection.cpp b/src/Common/ClientConnection.cpp
index 381f822..d061c8d 100644
--- a/src/Common/ClientConnection.cpp
+++ b/src/Common/ClientConnection.cpp
@@ -33,8 +33,8 @@ bool ClientConnection::send(const Net::Packet &packet) {
return connection->send(packet);
}
-void ClientConnection::connect(const Net::IPAddress &address, bool daemon) throw(Net::Exception) {
- connection->connect(address, daemon);
+void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) throw(Net::Exception) {
+ connection->connect(address);
}
bool ClientConnection::isConnecting() const {
@@ -50,7 +50,7 @@ bool ClientConnection::disconnect() {
return true;
}
-void* ClientConnection::getCertificate(size_t *size) const {
+/*void* ClientConnection::getCertificate(size_t *size) const {
const gnutls_datum_t *cert = connection->getCertificate();
*size = cert->size;
@@ -62,7 +62,7 @@ void* ClientConnection::getPeerCertificate(size_t *size) const {
*size = cert->size;
return cert->data;
-}
+}*/
}
}
diff --git a/src/Common/ClientConnection.h b/src/Common/ClientConnection.h
index 09ca4db..4710bd4 100644
--- a/src/Common/ClientConnection.h
+++ b/src/Common/ClientConnection.h
@@ -23,11 +23,12 @@
#include "Connection.h"
#include <Net/Exception.h>
+#include <boost/asio.hpp>
+
namespace Mad {
namespace Net {
class ClientConnection;
-class IPAddress;
}
namespace Common {
@@ -43,14 +44,14 @@ class ClientConnection : public Connection {
ClientConnection();
virtual ~ClientConnection() {}
- void connect(const Net::IPAddress &address, bool daemon = false) throw(Net::Exception);
+ void connect(const boost::asio::ip::tcp::endpoint &address) throw(Net::Exception);
bool isConnecting() const;
bool isConnected() const;
virtual bool disconnect();
- virtual void* getCertificate(size_t *size) const;
- virtual void* getPeerCertificate(size_t *size) const;
+ //virtual void* getCertificate(size_t *size) const;
+ //virtual void* getPeerCertificate(size_t *size) const;
};
}
diff --git a/src/Common/Connection.h b/src/Common/Connection.h
index ea90c4c..0cfc742 100644
--- a/src/Common/Connection.h
+++ b/src/Common/Connection.h
@@ -62,8 +62,8 @@ class Connection {
}
virtual bool disconnect() = 0;
- virtual void* getCertificate(size_t *size) const = 0;
- virtual void* getPeerCertificate(size_t *size) const = 0;
+ //virtual void* getCertificate(size_t *size) const = 0;
+ //virtual void* getPeerCertificate(size_t *size) const = 0;
virtual
diff --git a/src/Common/Requests/CMakeLists.txt b/src/Common/Requests/CMakeLists.txt
index 1b632e2..1d49f0e 100644
--- a/src/Common/Requests/CMakeLists.txt
+++ b/src/Common/Requests/CMakeLists.txt
@@ -1,6 +1,6 @@
include_directories(${INCLUDES})
add_library(Requests
- DisconnectRequest.cpp GSSAPIAuthRequest.cpp SimpleRequest.cpp UserInfoRequest.cpp
+ DisconnectRequest.cpp IdentifyRequest.cpp SimpleRequest.cpp UserInfoRequest.cpp
)
target_link_libraries(Requests ${KRB5_LIBRARIES})
diff --git a/src/Daemon/Requests/IdentifyRequest.cpp b/src/Common/Requests/IdentifyRequest.cpp
index dd035bf..6cf09d1 100644
--- a/src/Daemon/Requests/IdentifyRequest.cpp
+++ b/src/Common/Requests/IdentifyRequest.cpp
@@ -19,16 +19,16 @@
#include "IdentifyRequest.h"
-#include <Common/Logger.h>
-
namespace Mad {
-namespace Daemon {
+namespace Common {
namespace Requests {
void IdentifyRequest::sendRequest() {
Common::XmlPacket packet;
packet.setType("Identify");
- packet.add("hostname", hostname);
+
+ if(!hostname.empty())
+ packet.add("hostname", hostname);
sendPacket(packet);
}
diff --git a/src/Daemon/Requests/IdentifyRequest.h b/src/Common/Requests/IdentifyRequest.h
index 51a30e3..3798909 100644
--- a/src/Daemon/Requests/IdentifyRequest.h
+++ b/src/Common/Requests/IdentifyRequest.h
@@ -20,11 +20,12 @@
#ifndef MAD_DAEMON_REQUESTS_IDENTIFYREQUEST_H_
#define MAD_DAEMON_REQUESTS_IDENTIFYREQUEST_H_
-#include <Common/Request.h>
+#include "../Request.h"
+
#include <string>
namespace Mad {
-namespace Daemon {
+namespace Common {
namespace Requests {
class IdentifyRequest : public Common::Request {
@@ -35,7 +36,7 @@ class IdentifyRequest : public Common::Request {
virtual void sendRequest();
public:
- IdentifyRequest(Common::Connection *connection, uint16_t requestId, slot_type slot, const std::string &hostname0)
+ IdentifyRequest(Common::Connection *connection, uint16_t requestId, slot_type slot, const std::string &hostname0 = std::string())
: Common::Request(connection, requestId, slot), hostname(hostname0) {}
};
diff --git a/src/Daemon/Requests/CMakeLists.txt b/src/Daemon/Requests/CMakeLists.txt
index 99fbd45..74d8072 100644
--- a/src/Daemon/Requests/CMakeLists.txt
+++ b/src/Daemon/Requests/CMakeLists.txt
@@ -1,5 +1,5 @@
include_directories(${INCLUDES})
add_library(DaemonRequests
- IdentifyRequest.cpp LogRequest.cpp
+ LogRequest.cpp
)
diff --git a/src/Net/CMakeLists.txt b/src/Net/CMakeLists.txt
index aa3f857..fae358b 100644
--- a/src/Net/CMakeLists.txt
+++ b/src/Net/CMakeLists.txt
@@ -2,7 +2,7 @@ include_directories(${INCLUDES})
link_directories(${Boost_LIBRARY_DIRS})
add_library(Net
- ClientConnection.cpp Connection.cpp Exception.cpp FdManager.cpp IPAddress.cpp
- Listener.cpp Packet.cpp ServerConnection.cpp ThreadManager.cpp
+ ClientConnection.cpp Connection.cpp Exception.cpp Listener.cpp
+ Packet.cpp ThreadManager.cpp
)
-target_link_libraries(Net ${Boost_LIBRARIES} ${GNUTLS_LIBRARIES})
+target_link_libraries(Net ${Boost_LIBRARIES} ${OPENSSL_LIBRARIES})
diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp
index 087d95f..9cdf796 100644
--- a/src/Net/ClientConnection.cpp
+++ b/src/Net/ClientConnection.cpp
@@ -18,99 +18,36 @@
*/
#include "ClientConnection.h"
-#include "FdManager.h"
-#include "IPAddress.h"
-#include <boost/thread/locks.hpp>
-
-#include <cstring>
-#include <cerrno>
-#include <sys/socket.h>
-#include <fcntl.h>
+#include <Common/Logger.h>
namespace Mad {
namespace Net {
-// TODO Error handling
-void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) {
- if(length != sizeof(ConnectionHeader))
- // Error... disconnect
- return;
-
- const ConnectionHeader *header = (const ConnectionHeader*)(data);
-
- if(header->m != 'M' || header->a != 'A' || header->d != 'D')
- // Error... disconnect
+void ClientConnection::handleConnect(const boost::system::error_code& error) {
+ if(error) {
+ // TODO Error handling
+ doDisconnect();
return;
+ }
- if(header->protVerMin != 1)
- // Unsupported protocol... disconnect
- return;
-
- enterReceiveLoop();
-}
-
-void ClientConnection::connectionHeader() {
- ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1};
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
- rawSend((uint8_t*)&header, sizeof(header));
- rawReceive(sizeof(ConnectionHeader), boost::bind(&ClientConnection::connectionHeaderReceiveHandler, this, _1, _2));
+ socket.async_handshake(boost::asio::ssl::stream_base::client, boost::bind(&ClientConnection::handleHandshake, this, boost::asio::placeholders::error));
}
-void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(Exception) {
- boost::unique_lock<boost::shared_mutex> lock(stateLock);
-
- daemon = daemon0;
+void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) throw(Exception) {
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
if(_isConnected()) {
return;
// TODO Error
}
- sock = socket(PF_INET, SOCK_STREAM, 0);
- if(sock < 0) {
- throw Exception("socket()", Exception::INTERNAL_ERRNO, errno);
- }
-
- if(peer)
- delete peer;
- peer = new IPAddress(address);
-
- if(::connect(sock, peer->getSockAddr(), peer->getSockAddrLength()) < 0) {
- close(sock);
- delete peer;
- peer = 0;
-
- throw Exception("connect()", Exception::INTERNAL_ERRNO, errno);
- }
-
- // Set non-blocking flag
- int flags = fcntl(sock, F_GETFL, 0);
-
- if(flags < 0) {
- close(sock);
-
- throw Exception("fcntl()", Exception::INTERNAL_ERRNO, errno);
- }
-
- fcntl(sock, F_SETFL, flags | O_NONBLOCK);
-
- // Don't linger
- struct linger linger = {1, 0};
- setsockopt(sock, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger));
-
- gnutls_init(&session, GNUTLS_CLIENT);
- gnutls_set_default_priority(session);
- gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred);
- gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)sock);
-
- FdManager::get()->registerFd(sock, boost::bind(&ClientConnection::sendReceive, this, _1));
-
+ peer = address;
state = CONNECT;
- lock.unlock();
-
- updateEvents();
+ socket.lowest_layer().async_connect(address, boost::bind(&ClientConnection::handleConnect, this, boost::asio::placeholders::error));
}
}
diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h
index bdd7872..f93c2fc 100644
--- a/src/Net/ClientConnection.h
+++ b/src/Net/ClientConnection.h
@@ -23,24 +23,27 @@
#include "Connection.h"
#include "Exception.h"
+#include <boost/utility/base_from_member.hpp>
+
+
namespace Mad {
namespace Net {
class IPAddress;
-class ClientConnection : public Connection {
+class ClientConnection : private boost::base_from_member<boost::asio::ssl::context>, public Connection {
private:
- bool daemon;
-
- void connectionHeaderReceiveHandler(const void *data, unsigned long length);
-
- protected:
- virtual void connectionHeader();
+ void handleConnect(const boost::system::error_code& error);
public:
- ClientConnection() : daemon(0) {}
-
- void connect(const IPAddress &address, bool daemon0 = false) throw(Exception);
+ ClientConnection()
+ : boost::base_from_member<boost::asio::ssl::context>(boost::ref(Connection::ioService), boost::asio::ssl::context::sslv23),
+ Connection(member)
+ {
+ member.set_verify_mode(boost::asio::ssl::context::verify_none);
+ }
+
+ void connect(const boost::asio::ip::tcp::endpoint &address) throw(Exception);
};
}
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp
index 4e5029f..b9691cb 100644
--- a/src/Net/Connection.cpp
+++ b/src/Net/Connection.cpp
@@ -18,361 +18,225 @@
*/
#include "Connection.h"
-#include "FdManager.h"
-#include "IPAddress.h"
#include "ThreadManager.h"
-#include <cstring>
-#include <sys/socket.h>
+#include <Common/Logger.h>
+#include <cstring>
#include <boost/bind.hpp>
namespace Mad {
namespace Net {
-
-Connection::StaticInit Connection::staticInit;
+boost::asio::io_service Connection::ioService;
Connection::~Connection() {
if(_isConnected())
doDisconnect();
-
- if(transR.data)
- delete [] transR.data;
-
- while(!_sendQueueEmpty()) {
- delete [] transS.front().data;
- transS.pop();
- }
-
- gnutls_certificate_free_credentials(x509_cred);
-
- if(peer)
- delete peer;
-}
-
-void Connection::handshake() {
- boost::unique_lock<boost::shared_mutex> lock(stateLock);
- if(state != CONNECT)
- return;
-
- state = HANDSHAKE;
- lock.unlock();
-
- doHandshake();
}
-void Connection::bye() {
- boost::unique_lock<boost::shared_mutex> lock(stateLock);
- if(state != DISCONNECT)
- return;
-
- state = BYE;
- lock.unlock();
+void Connection::handleHandshake(const boost::system::error_code& error) {
+ if(error) {
+ Common::Logger::logf("Error: %s", error.message().c_str());
- doBye();
-}
-
-void Connection::doHandshake() {
- boost::shared_lock<boost::shared_mutex> lock(stateLock);
- if(state != HANDSHAKE)
+ // TODO Error handling
+ doDisconnect();
return;
+ }
- int ret = gnutls_handshake(session);
- if(ret < 0) {
- lock.unlock();
+ {
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
+ state = CONNECTED;
- if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
- updateEvents();
- return;
- }
+ receiving = false;
+ sending = 0;
- // TODO: Error
- doDisconnect();
- return;
+ received = 0;
}
- state = CONNECTION_HEADER;
- lock.unlock();
+ ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &connectedSignal));
- connectionHeader();
+ enterReceiveLoop();
}
-void Connection::doBye() {
- if(state != BYE)
- return;
-
- int ret = gnutls_bye(session, GNUTLS_SHUT_RDWR);
- if(ret < 0) {
- if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
- updateEvents();
- return;
- }
+void Connection::handleShutdown(const boost::system::error_code& error) {
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
- // TODO: Error
- doDisconnect();
- return;
+ if(error) {
+ // TODO Error
}
- doDisconnect();
+ state = DISCONNECTED;
+ ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &disconnectedSignal));
}
void Connection::enterReceiveLoop() {
- boost::unique_lock<boost::shared_mutex> lock(stateLock);
-
- if(!_isConnected() || _isDisconnecting())
- return;
-
- if(_isConnecting())
- ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &connectedSignal));
+ {
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
- state = PACKET_HEADER;
- lock.unlock();
+ if(!_isConnected() || _isDisconnecting())
+ return;
+ }
- rawReceive(sizeof(Packet::Data), boost::bind(&Connection::packetHeaderReceiveHandler, this, _1, _2));
+ rawReceive(sizeof(Packet::Data), boost::bind(&Connection::handleHeaderReceive, this, _1));
}
-void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) {
- if(state != PACKET_HEADER)
- return;
+void Connection::handleHeaderReceive(const std::vector<boost::uint8_t> &data) {
+ {
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
- if(length != sizeof(Packet::Data)) {
- // TODO: Error
- doDisconnect();
- return;
+ header = *reinterpret_cast<const Packet::Data*>(data.data());
}
- header = *(const Packet::Data*)data;
-
if(header.length == 0) {
ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, Packet(ntohs(header.requestId))));
enterReceiveLoop();
}
else {
- state = PACKET_DATA;
- rawReceive(ntohs(header.length), boost::bind(&Connection::packetDataReceiveHandler, this, _1, _2));
+ rawReceive(ntohs(header.length), boost::bind(&Connection::handleDataReceive, this, _1));
}
}
-void Connection::packetDataReceiveHandler(const void *data, unsigned long length) {
- if(state != PACKET_DATA)
- return;
+void Connection::handleDataReceive(const std::vector<boost::uint8_t> &data) {
+ {
+ boost::upgrade_lock<boost::shared_mutex> lock(connectionLock);
- if(length != ntohs(header.length)) {
- // TODO: Error
- doDisconnect();
- return;
+ Packet packet(ntohs(header.requestId), data.data(), ntohs(header.length));
+ ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, packet));
}
- ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, Packet(ntohs(header.requestId), data, length)));
-
enterReceiveLoop();
}
-void Connection::doReceive() {
- if(!isConnected())
- return;
-
- boost::unique_lock<boost::mutex> lock(receiveLock);
-
- if(_receiveComplete())
- return;
-
- ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted);
+void Connection::handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > &notify) {
+ if(error || (bytes_transferred+received) < length) {
+ Common::Logger::logf(Common::Logger::VERBOSE, "Read error: %s", error.message().c_str());
- if(ret < 0) {
- lock.unlock();
- if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
- return;
-
- // TODO: Error
+ // TODO Error
doDisconnect();
return;
}
- transR.transmitted += ret;
-
- if(_receiveComplete()) {
- // Save data pointer, as transR.notify might start a new reception
- uint8_t *data = transR.data;
- transR.data = 0;
+ std::vector<boost::uint8_t> buffer;
- lock.unlock();
+ {
+ boost::shared_lock<boost::shared_mutex> lock(connectionLock);
- transR.notify(data, transR.length);
+ if(state != CONNECTED || !receiving)
+ return;
- delete [] data;
- }
- else {
- lock.unlock();
+ buffer.insert(buffer.end(), receiveBuffer.data(), receiveBuffer.data()+length);
}
- updateEvents();
-}
-
-bool Connection::rawReceive(unsigned long length,
- const boost::function2<void,const void*,unsigned long> &notify)
-{
- if(!isConnected())
- return false;
-
- boost::unique_lock<boost::mutex> lock(receiveLock);
- if(!_receiveComplete())
- return false;
-
- transR.data = new uint8_t[length];
- transR.length = length;
- transR.transmitted = 0;
- transR.notify = notify;
+ {
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
- lock.unlock();
+ receiving = false;
+ received = received + bytes_transferred - length;
- updateEvents();
+ if(received)
+ std::memmove(receiveBuffer.data(), receiveBuffer.data()+length, received);
+ }
- return true;
+ notify(buffer);
}
-void Connection::doSend() {
- if(!isConnected())
- return;
+void Connection::rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > &notify) {
+ boost::upgrade_lock<boost::shared_mutex> lock(connectionLock);
- boost::unique_lock<boost::mutex> lock(sendLock);
- while(!_sendQueueEmpty()) {
- ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted,
- transS.front().length-transS.front().transmitted);
-
- if(ret < 0) {
- lock.unlock();
+ if(!_isConnected())
+ return;
- if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
- return;
+ {
+ boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock);
- // TODO: Error
- doDisconnect();
+ if(receiving)
return;
- }
- transS.front().transmitted += ret;
+ receiving = true;
- if(transS.front().transmitted == transS.front().length) {
- delete [] transS.front().data;
- transS.pop();
+ if(length > received) {
+ boost::asio::async_read(socket, boost::asio::buffer(receiveBuffer.data()+received, receiveBuffer.size()-received), boost::asio::transfer_at_least(length),
+ boost::bind(&Connection::handleRead, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred,
+ length, notify));
+
+ return;
}
}
lock.unlock();
- updateEvents();
+ handleRead(boost::system::error_code(), 0, length, notify);
}
-bool Connection::rawSend(const uint8_t *data, unsigned long length) {
- if(!isConnected())
- return false;
+void Connection::handleWrite(const boost::system::error_code& error, std::size_t) {
+ {
+ boost::unique_lock<boost::shared_mutex> lock(connectionLock);
- Transmission trans = {length, 0, new uint8_t[length], boost::function2<void,const void*,unsigned long>()};
- std::memcpy(trans.data, data, length);
+ sending--;
- sendLock.lock();
- transS.push(trans);
- sendLock.unlock();
-
- updateEvents();
+ if(state == DISCONNECT && !sending) {
+ lock.unlock();
+ doDisconnect();
+ return;
+ }
+ }
- return true;
-}
+ if(error) {
+ Common::Logger::logf(Common::Logger::VERBOSE, "Write error: %s", error.message().c_str());
-void Connection::sendReceive(short events) {
- if(events & POLLHUP || events & POLLERR) {
+ // TODO Error
doDisconnect();
- return;
}
+}
- switch(state) {
- case CONNECT:
- handshake();
- return;
- case HANDSHAKE:
- doHandshake();
- return;
- case DISCONNECT:
- if(!_sendQueueEmpty())
- break;
+void Connection::rawSend(const uint8_t *data, std::size_t length) {
+ boost::upgrade_lock<boost::shared_mutex> lock(connectionLock);
- bye();
- return;
- case BYE:
- doBye();
- return;
- default:
- break;
- }
+ if(!_isConnected())
+ return;
- if(events & POLLIN)
- doReceive();
+ {
+ boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock);
- if(events & POLLOUT)
- doSend();
+ sending++;
+ boost::asio::async_write(socket, Buffer(data, length), boost::bind(&Connection::handleWrite, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred));
+ }
}
bool Connection::send(const Packet &packet) {
- stateLock.lock_shared();
- bool err = (!_isConnected() || _isConnecting() || _isDisconnecting());
- stateLock.unlock_shared();
-
- if(err)
- return false;
+ {
+ boost::shared_lock<boost::shared_mutex> lock(connectionLock);
+ if(!_isConnected() || _isConnecting() || _isDisconnecting())
+ return false;
+ }
- return rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength());
+ rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength());
+ return true;
}
void Connection::disconnect() {
- boost::unique_lock<boost::shared_mutex> lock(stateLock);
- if(!_isConnected() || _isDisconnecting())
- return;
+ {
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
+ if(!_isConnected() || _isDisconnecting())
+ return;
- state = DISCONNECT;
+ state = DISCONNECT;
- lock.unlock();
+ if(sending)
+ return;
+ }
- updateEvents();
+ doDisconnect();
}
void Connection::doDisconnect() {
- boost::unique_lock<boost::shared_mutex> lock(stateLock);
-
- if(_isConnected()) {
- FdManager::get()->unregisterFd(sock);
-
- shutdown(sock, SHUT_RDWR);
- close(sock);
+ boost::lock_guard<boost::shared_mutex> lock(connectionLock);
- gnutls_deinit(session);
-
- ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &disconnectedSignal));
-
- state = DISCONNECTED;
- }
-}
-
-void Connection::updateEvents() {
- receiveLock.lock();
- short events = (_receiveComplete() ? 0 : POLLIN);
- receiveLock.unlock();
-
- sendLock.lock();
- events |= (_sendQueueEmpty() ? 0 : POLLOUT);
- sendLock.unlock();
-
- stateLock.lock_shared();
- if(state == HANDSHAKE || state == BYE)
- events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT);
- else if(state == CONNECT || state == DISCONNECT)
- events |= POLLOUT;
-
- FdManager::get()->setFdEvents(sock, events);
- stateLock.unlock_shared();
+ if(_isConnected())
+ socket.async_shutdown(boost::bind(&Connection::handleShutdown, this, boost::asio::placeholders::error));
}
}
diff --git a/src/Net/Connection.h b/src/Net/Connection.h
index a0b95ea..303485d 100644
--- a/src/Net/Connection.h
+++ b/src/Net/Connection.h
@@ -24,155 +24,111 @@
#include "Packet.h"
-#include <queue>
-#include <string>
-#include <gnutls/gnutls.h>
-#include <poll.h>
+#include <boost/asio.hpp>
+#include <boost/asio/ssl.hpp>
#include <boost/signal.hpp>
-#include <boost/thread/mutex.hpp>
-#include <boost/thread/shared_mutex.hpp>
-#include <iostream>
+#include <boost/thread/shared_mutex.hpp>
namespace Mad {
namespace Net {
-class IPAddress;
-class Packet;
+class ThreadManager;
-class Connection {
+class Connection : boost::noncopyable {
private:
- class StaticInit {
- public:
- StaticInit() {
- gnutls_global_init();
- }
+ friend class ThreadManager;
- ~StaticInit() {
- gnutls_global_deinit();
- }
- };
- static StaticInit staticInit;
+ class Buffer {
+ public:
+ Buffer(const uint8_t *data0, std::size_t length) : data(new std::vector<uint8_t>(data0, data0+length)), buffer(boost::asio::buffer(*data)) {}
- struct Transmission {
- unsigned long length;
- unsigned long transmitted;
+ typedef boost::asio::const_buffer value_type;
+ typedef const boost::asio::const_buffer* const_iterator;
- uint8_t *data;
+ const boost::asio::const_buffer* begin() const { return &buffer; }
+ const boost::asio::const_buffer* end() const { return &buffer + 1; }
- boost::function2<void,const void*,unsigned long> notify;
+ private:
+ boost::shared_ptr<std::vector<uint8_t> > data;
+ boost::asio::const_buffer buffer;
};
- boost::mutex receiveLock;
- Transmission transR;
- boost::mutex sendLock;
- std::queue<Transmission> transS;
+ std::vector<boost::uint8_t> receiveBuffer;
+ std::size_t received;
Packet::Data header;
- boost::signal1<void,const Packet&> receiveSignal;
+ boost::signal1<void, const Packet&> receiveSignal;
boost::signal0<void> connectedSignal;
boost::signal0<void> disconnectedSignal;
- void doHandshake();
-
- void packetHeaderReceiveHandler(const void *data, unsigned long length);
- void packetDataReceiveHandler(const void *data, unsigned long length);
+ bool receiving;
+ unsigned long sending;
- void doReceive();
- void doSend();
-
- void doBye();
-
- void doDisconnect();
+ void enterReceiveLoop();
- bool _receiveComplete() const {
- return (transR.length == transR.transmitted);
- }
+ void handleHeaderReceive(const std::vector<boost::uint8_t> &data);
+ void handleDataReceive(const std::vector<boost::uint8_t> &data);
- bool _sendQueueEmpty() const {return transS.empty();}
+ void handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > &notify);
+ void handleWrite(const boost::system::error_code& error, std::size_t);
- void bye();
+ void handleShutdown(const boost::system::error_code& error);
- // Prevent shallow copy
- Connection(const Connection &o);
- Connection& operator=(const Connection &o);
+ void rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > &notify);
+ void rawSend(const uint8_t *data, std::size_t length);
protected:
- struct ConnectionHeader {
- uint8_t m;
- uint8_t a;
- uint8_t d;
- uint8_t type;
-
- uint8_t versionMajor;
- uint8_t versionMinor;
- uint8_t protVerMin;
- uint8_t protVerMax;
- };
+ static boost::asio::io_service ioService;
- boost::shared_mutex stateLock;
+ boost::shared_mutex connectionLock;
enum State {
- DISCONNECTED, CONNECT, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE
+ DISCONNECTED, CONNECT, CONNECTED, DISCONNECT
} state;
- int sock;
- gnutls_session_t session;
- gnutls_certificate_credentials_t x509_cred;
- IPAddress *peer;
+ boost::asio::ssl::stream<boost::asio::ip::tcp::socket> socket;
+ boost::asio::ip::tcp::endpoint peer;
- void handshake();
-
- virtual void connectionHeader() = 0;
-
- bool rawReceive(unsigned long length, const boost::function2<void,const void*,unsigned long> &notify);
- bool rawSend(const uint8_t *data, unsigned long length);
-
- void enterReceiveLoop();
-
- void sendReceive(short events);
+ void handleHandshake(const boost::system::error_code& error);
bool _isConnected() const {return (state != DISCONNECTED);}
bool _isConnecting() const {
- return (state == CONNECT || state == HANDSHAKE || state == CONNECTION_HEADER);
+ return (state == CONNECT);
}
bool _isDisconnecting() const {
- return (state == DISCONNECT || state == BYE);
+ return (state == DISCONNECT);
}
- void updateEvents();
-
- public:
- Connection() : state(DISCONNECTED), peer(0) {
- transR.length = transR.transmitted = 0;
- transR.data = 0;
+ void doDisconnect();
- gnutls_certificate_allocate_credentials(&x509_cred);
- }
+ Connection(boost::asio::ssl::context &sslContext) :
+ receiveBuffer(1024*1024), state(DISCONNECTED), socket(ioService, sslContext) {}
+ public:
virtual ~Connection();
bool isConnected() {
- boost::shared_lock<boost::shared_mutex> lock(stateLock);
+ boost::shared_lock<boost::shared_mutex> lock(connectionLock);
return _isConnected();
}
bool isConnecting() {
- boost::shared_lock<boost::shared_mutex> lock(stateLock);
+ boost::shared_lock<boost::shared_mutex> lock(connectionLock);
return _isConnecting();
}
bool isDisconnecting() {
- boost::shared_lock<boost::shared_mutex> lock(stateLock);
+ boost::shared_lock<boost::shared_mutex> lock(connectionLock);
return _isDisconnecting();
}
- const gnutls_datum_t* getCertificate() const {
+ /*const gnutls_datum_t* getCertificate() const {
// TODO Thread-safeness
return gnutls_certificate_get_ours(session);
}
@@ -181,16 +137,18 @@ class Connection {
// TODO Thread-safeness
unsigned int n;
return gnutls_certificate_get_peers(session, &n);
- }
+ }*/
- // TODO Thread-safeness
- const IPAddress* getPeer() const {return peer;}
+ boost::asio::ip::tcp::endpoint getPeer() {
+ boost::shared_lock<boost::shared_mutex> lock(connectionLock);
+ return peer;
+ }
void disconnect();
bool send(const Packet &packet);
- boost::signal1<void,const Packet&>& signalReceive() {return receiveSignal;}
+ boost::signal1<void, const Packet&>& signalReceive() {return receiveSignal;}
boost::signal0<void>& signalConnected() {return connectedSignal;}
boost::signal0<void>& signalDisconnected() {return disconnectedSignal;}
};
diff --git a/src/Net/Exception.cpp b/src/Net/Exception.cpp
index 34b8033..e082948 100644
--- a/src/Net/Exception.cpp
+++ b/src/Net/Exception.cpp
@@ -20,7 +20,6 @@
#include "Exception.h"
#include <cstring>
-#include <gnutls/gnutls.h>
namespace Mad {
namespace Net {
@@ -46,8 +45,6 @@ std::string Exception::strerror() const {
return ret + "Not implemented";
case INTERNAL_ERRNO:
return ret + std::strerror(subCode);
- case INTERNAL_GNUTLS:
- return ret + "GnuTLS error: " + gnutls_strerror(subCode);
case INVALID_ADDRESS:
return ret + "Invalid address";
case ALREADY_IDENTIFIED:
diff --git a/src/Net/Exception.h b/src/Net/Exception.h
index 48e86d1..8522528 100644
--- a/src/Net/Exception.h
+++ b/src/Net/Exception.h
@@ -29,7 +29,7 @@ class Exception {
public:
enum ErrorCode {
SUCCESS = 0x0000, UNEXPECTED_PACKET = 0x0001, INVALID_ACTION = 0x0002, NOT_AVAILABLE = 0x0003, NOT_FINISHED = 0x0004, NOT_IMPLEMENTED = 0x0005,
- INTERNAL_ERRNO = 0x0010, INTERNAL_GNUTLS = 0x0011,
+ INTERNAL_ERRNO = 0x0010,
INVALID_ADDRESS = 0x0020,
ALREADY_IDENTIFIED = 0x0030, UNKNOWN_DAEMON = 0x0031
};
diff --git a/src/Net/FdManager.cpp b/src/Net/FdManager.cpp
deleted file mode 100644
index d8faef4..0000000
--- a/src/Net/FdManager.cpp
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * FdManager.cpp
- *
- * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de>
- *
- * This program is free software: you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but
- * WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include "FdManager.h"
-#include "ThreadManager.h"
-
-#include <signal.h>
-#include <unistd.h>
-#include <sys/fcntl.h>
-
-
-namespace Mad {
-namespace Net {
-
-FdManager FdManager::fdManager;
-
-
-FdManager::FdManager() : running(false) {
- pipe(interruptPipe);
-
- int flags = fcntl(interruptPipe[0], F_GETFL, 0);
- fcntl(interruptPipe[0], F_SETFL, flags | O_NONBLOCK);
-
- flags = fcntl(interruptPipe[1], F_GETFL, 0);
- fcntl(interruptPipe[1], F_SETFL, flags | O_NONBLOCK);
-
- registerFd(interruptPipe[0], boost::bind(&FdManager::readInterrupt, this), POLLIN);
-}
-
-FdManager::~FdManager() {
- unregisterFd(interruptPipe[0]);
-
- close(interruptPipe[0]);
- close(interruptPipe[1]);
-}
-
-
-bool FdManager::registerFd(int fd, const boost::function1<void, short> &handler, short events) {
- struct pollfd pollfd = {fd, events, 0};
-
- boost::lock(handlerLock, eventLock);
- pollfds.insert(std::make_pair(fd, pollfd));
-
- bool ret = handlers.insert(std::make_pair(fd, handler)).second;
-
- eventLock.unlock();
- handlerLock.unlock();
-
- interrupt();
-
- return ret;
-}
-
-bool FdManager::unregisterFd(int fd) {
- boost::lock(handlerLock, eventLock);
- pollfds.erase(fd);
- bool ret = handlers.erase(fd);
- eventLock.unlock();
- handlerLock.unlock();
-
- interrupt();
-
- return ret;
-}
-
-bool FdManager::setFdEvents(int fd, short events) {
- boost::unique_lock<boost::shared_mutex> lock(eventLock);
-
- std::map<int, struct pollfd>::iterator pollfd = pollfds.find(fd);
-
- if(pollfd == pollfds.end())
- return false;
-
- if(pollfd->second.events != events) {
- pollfd->second.events = events;
- interrupt();
- }
-
- return true;
-}
-
-short FdManager::getFdEvents(int fd) {
- boost::shared_lock<boost::shared_mutex> lock(eventLock);
-
- std::map<int, struct pollfd>::const_iterator pollfd = pollfds.find(fd);
-
- if(pollfd == pollfds.end())
- return -1;
-
- return pollfd->second.events;
-}
-
-void FdManager::readInterrupt() {
- char buf[20];
-
- while(read(interruptPipe[0], buf, sizeof(buf)) > 0) {}
-}
-
-void FdManager::interrupt() {
- char buf = 0;
-
- write(interruptPipe[1], &buf, sizeof(buf));
-}
-
-void FdManager::ioThread() {
- runLock.lock();
- running = true;
- runLock.unlock_and_lock_shared();
-
- while(running) {
- runLock.unlock_shared();
-
- handlerLock.lock_shared();
- eventLock.lock_shared();
- readInterrupt();
-
- size_t count = pollfds.size();
- struct pollfd *fdarray = new struct pollfd[count];
-
- std::map<int, struct pollfd>::iterator pollfd = pollfds.begin();
-
- for(size_t n = 0; n < count; ++n) {
- fdarray[n] = pollfd->second;
- ++pollfd;
- }
-
- eventLock.unlock_shared();
- handlerLock.unlock_shared();
-
- if(poll(fdarray, count, -1) > 0) {
- handlerLock.lock_shared();
-
- std::queue<boost::function0<void> > calls;
-
- for(size_t n = 0; n < count; ++n) {
- if(fdarray[n].revents)
- calls.push(boost::bind(handlers[fdarray[n].fd], fdarray[n].revents));
- }
-
- handlerLock.unlock_shared();
-
- while(!calls.empty()) {
- calls.front()();
- calls.pop();
- }
-
- }
-
- delete [] fdarray;
-
- runLock.lock_shared();
- }
-
- runLock.unlock_shared();
-}
-
-}
-}
diff --git a/src/Net/FdManager.h b/src/Net/FdManager.h
deleted file mode 100644
index 1cb95bc..0000000
--- a/src/Net/FdManager.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * FdManager.h
- *
- * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de>
- *
- * This program is free software: you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but
- * WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef MAD_NET_FDMANAGER_H_
-#define MAD_NET_FDMANAGER_H_
-
-#include <map>
-#include <poll.h>
-
-#include <boost/function.hpp>
-#include <boost/thread/shared_mutex.hpp>
-
-namespace Mad {
-namespace Net {
-
-class ThreadManager;
-
-class FdManager {
- private:
- friend class ThreadManager;
-
- static FdManager fdManager;
-
- boost::shared_mutex runLock, handlerLock, eventLock;
- bool running;
-
- std::map<int, struct pollfd> pollfds;
- std::map<int, boost::function1<void, short> > handlers;
-
- int interruptPipe[2];
-
- void readInterrupt();
- void interrupt();
-
- FdManager();
-
- void ioThread();
- void stopIOThread() {
- runLock.lock();
- running = false;
- runLock.unlock();
-
- interrupt();
- }
-
- public:
- virtual ~FdManager();
-
- static FdManager *get() {return &fdManager;}
-
- bool registerFd(int fd, const boost::function1<void, short> &handler, short events = 0);
- bool unregisterFd(int fd);
-
- bool setFdEvents(int fd, short events);
- short getFdEvents(int fd);
-};
-
-}
-}
-
-#endif /* MAD_NET_FDMANAGER_H_ */
diff --git a/src/Net/IPAddress.cpp b/src/Net/IPAddress.cpp
deleted file mode 100644
index eb9d3be..0000000
--- a/src/Net/IPAddress.cpp
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * IPAddress.cpp
- *
- * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de>
- *
- * This program is free software: you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but
- * WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include "IPAddress.h"
-
-#include <cstdlib>
-
-namespace Mad {
-namespace Net {
-
-IPAddress::IPAddress(uint16_t port0) : addr(INADDR_ANY), port(port0) {
- sa.sin_family = AF_INET;
- sa.sin_port = htons(port);
- sa.sin_addr.s_addr = INADDR_ANY;
-}
-
-IPAddress::IPAddress(uint32_t address, uint16_t port0) : addr(address), port(port0) {
- sa.sin_family = AF_INET;
- sa.sin_port = htons(port);
- sa.sin_addr.s_addr = htonl(addr);
-}
-
-IPAddress::IPAddress(const std::string &address) throw(Exception) {
- std::string ip;
- size_t pos = address.find_first_of(':');
-
- if(pos == std::string::npos) {
- ip = address;
- // TODO Default port
- port = 6666;
- }
- else {
- ip = address.substr(0, pos);
-
- char *endptr;
- port = std::strtol(address.substr(pos+1).c_str(), &endptr, 10);
- if(*endptr != 0 || port == 0)
- throw Exception(Exception::INVALID_ADDRESS);
- }
-
- sa.sin_family = AF_INET;
- sa.sin_port = htons(port);
-
- if(ip == "*")
- sa.sin_addr.s_addr = INADDR_ANY;
- else if(!inet_pton(AF_INET, ip.c_str(), &sa.sin_addr))
- throw Exception(Exception::INVALID_ADDRESS);
-
- addr = ntohl(sa.sin_addr.s_addr);
-}
-
-IPAddress::IPAddress(const std::string &address, uint16_t port0) throw(Exception) : port(port0) {
- sa.sin_family = AF_INET;
- sa.sin_port = htons(port);
-
- if(!inet_pton(AF_INET, address.c_str(), &sa.sin_addr))
- throw Exception(Exception::INVALID_ADDRESS);
-
- addr = ntohl(sa.sin_addr.s_addr);
-}
-
-IPAddress::IPAddress(const struct sockaddr_in &address) : sa(address) {
- port = ntohs(sa.sin_port);
- addr = ntohl(sa.sin_addr.s_addr);
-}
-
-std::string IPAddress::getAddressString() const {
- char buf[INET_ADDRSTRLEN];
- uint32_t address = htonl(addr);
-
- inet_ntop(AF_INET, &address, buf, sizeof(buf));
- return std::string(buf);
-}
-
-}
-}
diff --git a/src/Net/IPAddress.h b/src/Net/IPAddress.h
deleted file mode 100644
index 3541891..0000000
--- a/src/Net/IPAddress.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * IPAddress.h
- *
- * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de>
- *
- * This program is free software: you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but
- * WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef MAD_NET_IPADDRESS_H_
-#define MAD_NET_IPADDRESS_H_
-
-#include "Exception.h"
-
-#include <string>
-#include <arpa/inet.h>
-#include <stdint.h>
-
-namespace Mad {
-namespace Net {
-
-class IPAddress {
- private:
- uint32_t addr;
- uint16_t port;
- struct sockaddr_in sa;
-
- public:
- // TODO Default port
- IPAddress(uint16_t port0 = 6666);
- IPAddress(uint32_t address, uint16_t port0);
- IPAddress(const std::string &address) throw(Exception);
- IPAddress(const std::string &address, uint16_t port0) throw(Exception);
- IPAddress(const struct sockaddr_in &address);
-
- uint32_t getAddress() const {return addr;}
- uint16_t getPort() const {return port;}
-
- std::string getAddressString() const;
-
- struct sockaddr* getSockAddr() {return (struct sockaddr*)&sa;}
- socklen_t getSockAddrLength() const {return sizeof(sa);}
-};
-
-}
-}
-
-#endif /*MAD_NET_IPADDRESS_H_*/
diff --git a/src/Net/Listener.cpp b/src/Net/Listener.cpp
index 11cbaf5..6f49a74 100644
--- a/src/Net/Listener.cpp
+++ b/src/Net/Listener.cpp
@@ -18,8 +18,6 @@
*/
#include "Listener.h"
-#include "FdManager.h"
-#include "ServerConnection.h"
#include <Common/Logger.h>
@@ -30,26 +28,29 @@
namespace Mad {
namespace Net {
-void Listener::acceptHandler(int) {
- int sd;
- struct sockaddr_in sa;
- socklen_t addrlen = sizeof(sa);
+void Listener::handleAccept(const boost::system::error_code &error, boost::shared_ptr<ServerConnection> con) {
+ if(error)
+ return;
+
+ {
+ boost::lock_guard<boost::shared_mutex> lock(con->connectionLock);
+ con->state = ServerConnection::CONNECT;
- while((sd = accept(sock, (struct sockaddr*)&sa, &addrlen)) >= 0) {
- ServerConnection *con = new ServerConnection(sd, IPAddress(sa), dh_params, x905CertFile, x905KeyFile);
- boost::signals::connection con1 = con->signalConnected().connect(boost::bind(&Listener::connectHandler, this, con));
- boost::signals::connection con2 = con->signalDisconnected().connect(boost::bind(&Listener::disconnectHandler, this, con));
+ boost::signals::connection con1 = con->signalConnected().connect(boost::bind(&Listener::handleConnect, this, con));
+ boost::signals::connection con2 = con->signalDisconnected().connect(boost::bind(&Listener::handleDisconnect, this, con));
connections.insert(std::make_pair(con, std::make_pair(con1, con2)));
- addrlen = sizeof(sa);
+ con->socket.async_handshake(boost::asio::ssl::stream_base::server, boost::bind(&ServerConnection::handleHandshake, con, boost::asio::placeholders::error));
}
-}
+ con.reset(new ServerConnection(sslContext));
+ acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con));
+}
-void Listener::connectHandler(ServerConnection *con) {
- std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con);
+void Listener::handleConnect(boost::shared_ptr<ServerConnection> con) {
+ std::map<boost::shared_ptr<ServerConnection>, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con);
if(it == connections.end())
return;
@@ -62,67 +63,33 @@ void Listener::connectHandler(ServerConnection *con) {
signal(con);
}
-void Listener::disconnectHandler(ServerConnection *con) {
- std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con);
-
- if(it == connections.end())
- return;
-
- delete it->first;
- connections.erase(it);
+void Listener::handleDisconnect(boost::shared_ptr<ServerConnection> con) {
+ connections.erase(con);
}
-Listener::Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, const IPAddress &address0) throw(Exception)
-: x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0) {
- gnutls_dh_params_init(&dh_params);
- gnutls_dh_params_generate2(dh_params, 768);
+Listener::Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0,
+ const boost::asio::ip::tcp::endpoint &address0) throw(Exception)
+: x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0), acceptor(Connection::ioService, address),
+sslContext(Connection::ioService, boost::asio::ssl::context::sslv23)
+{
+ sslContext.set_options(boost::asio::ssl::context::default_workarounds
+ | boost::asio::ssl::context::no_sslv2
+ | boost::asio::ssl::context::single_dh_use);
+ sslContext.use_certificate_chain_file(x905CertFile0);
+ sslContext.use_private_key_file(x905KeyFile0, boost::asio::ssl::context::pem);
- sock = socket(PF_INET, SOCK_STREAM, 0);
- if(sock < 0)
- throw Exception("socket()", Exception::INTERNAL_ERRNO, errno);
- // Set non-blocking flag
- int flags = fcntl(sock, F_GETFL, 0);
-
- if(flags < 0) {
- close(sock);
-
- throw Exception("fcntl()", Exception::INTERNAL_ERRNO, errno);
- }
-
- fcntl(sock, F_SETFL, flags | O_NONBLOCK);
-
- // Don't linger
- struct linger linger = {1, 0};
- setsockopt(sock, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger));
-
- if(bind(sock, address.getSockAddr(), address.getSockAddrLength()) < 0) {
- close(sock);
-
- throw Exception("bind()", Exception::INTERNAL_ERRNO, errno);
- }
-
- if(listen(sock, 64) < 0) {
- close(sock);
-
- throw Exception("listen()", Exception::INTERNAL_ERRNO, errno);
- }
-
- FdManager::get()->registerFd(sock, boost::bind(&Listener::acceptHandler, this, _1), POLLIN);
+ boost::shared_ptr<ServerConnection> con(new ServerConnection(sslContext));
+ acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con));
}
Listener::~Listener() {
- for(std::map<ServerConnection*,std::pair<boost::signals::connection, boost::signals::connection> >::iterator con = connections.begin(); con != connections.end(); ++con) {
+ for(std::map<boost::shared_ptr<ServerConnection>,std::pair<boost::signals::connection, boost::signals::connection> >::iterator con = connections.begin(); con != connections.end(); ++con) {
con->first->disconnect();
- delete con->first;
+ // TODO wait...
}
-
- shutdown(sock, SHUT_RDWR);
- close(sock);
-
- gnutls_dh_params_deinit(dh_params);
}
}
diff --git a/src/Net/Listener.h b/src/Net/Listener.h
index 26dffab..0833cdf 100644
--- a/src/Net/Listener.h
+++ b/src/Net/Listener.h
@@ -20,46 +20,45 @@
#ifndef MAD_NET_LISTENER_H_
#define MAD_NET_LISTENER_H_
-#include "IPAddress.h"
-
-#include <gnutls/gnutls.h>
#include <map>
#include <string>
-#include <boost/signal.hpp>
+#include "Connection.h"
+#include "Exception.h"
namespace Mad {
namespace Net {
-class ServerConnection;
-
// TODO XXX Thread-safeness XXX
-class Listener {
+class Listener : boost::noncopyable {
private:
- std::string x905CertFile, x905KeyFile;
- IPAddress address;
- int sock;
+ class ServerConnection : public Connection {
+ public:
+ friend class Listener;
- gnutls_dh_params_t dh_params;
+ ServerConnection(boost::asio::ssl::context &sslContext) : Connection(sslContext) {}
+ };
- std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> > connections;
+ std::string x905CertFile, x905KeyFile;
+ boost::asio::ip::tcp::endpoint address;
+ boost::asio::ip::tcp::acceptor acceptor;
+ boost::asio::ssl::context sslContext;
- boost::signal1<void, ServerConnection*> signal;
+ std::map<boost::shared_ptr<ServerConnection>, std::pair<boost::signals::connection, boost::signals::connection> > connections;
- void acceptHandler(int);
+ boost::signal1<void, boost::shared_ptr<Connection> > signal;
- void connectHandler(ServerConnection *con);
- void disconnectHandler(ServerConnection *con);
+ void handleAccept(const boost::system::error_code &error, boost::shared_ptr<ServerConnection> con);
- // Prevent shallow copy
- Listener(const Listener &o);
- Listener& operator=(const Listener &o);
+ void handleConnect(boost::shared_ptr<ServerConnection> con);
+ void handleDisconnect(boost::shared_ptr<ServerConnection> con);
public:
- Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, const IPAddress &address0 = IPAddress()) throw(Exception);
+ Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0,
+ const boost::asio::ip::tcp::endpoint &address0 = boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 6666)) throw(Exception);
virtual ~Listener();
- boost::signal1<void, ServerConnection*>& signalNewConnection() {return signal;}
+ boost::signal1<void, boost::shared_ptr<Connection> >& signalNewConnection() {return signal;}
};
}
diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp
deleted file mode 100644
index 1f01ce5..0000000
--- a/src/Net/ServerConnection.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * ServerConnection.cpp
- *
- * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de>
- *
- * This program is free software: you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but
- * WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include "ServerConnection.h"
-#include "FdManager.h"
-#include "IPAddress.h"
-
-#include <boost/thread/locks.hpp>
-
-#include <cstring>
-#include <cerrno>
-#include <sys/socket.h>
-#include <fcntl.h>
-
-namespace Mad {
-namespace Net {
-
-void ServerConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) {
- if(length != sizeof(ConnectionHeader))
- // Error... disconnect
- return;
-
- const ConnectionHeader *header = (const ConnectionHeader*)data;
-
- if(header->m != 'M' || header->a != 'A' || header->d != 'D')
- // Error... disconnect
- return;
-
- if(header->protVerMin > 1 || header->protVerMax < 1)
- // Unsupported protocol... disconnect
- return;
-
- if(header->type == 'C')
- daemon = false;
- else if(header->type == 'D')
- daemon = true;
- else
- // Error... disconnect
- return;
-
- ConnectionHeader header2 = {'M', 'A', 'D', 0, 0, 1, 1, 0};
-
- enterReceiveLoop();
-
- rawSend((uint8_t*)&header2, sizeof(header2));
-}
-
-ServerConnection::ServerConnection(int sock0, const IPAddress &address, gnutls_dh_params_t dh_params, const std::string &x905CertFile, const std::string &x905KeyFile)
-: daemon(false) {
- boost::unique_lock<boost::shared_mutex> lock(stateLock);
-
- sock = sock0;
-
- peer = new IPAddress(address);
-
- gnutls_certificate_set_dh_params(x509_cred, dh_params);
- gnutls_certificate_set_x509_key_file(x509_cred, x905CertFile.c_str(), x905KeyFile.c_str(), GNUTLS_X509_FMT_PEM);
-
- gnutls_init(&session, GNUTLS_SERVER);
- gnutls_set_default_priority(session);
- gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred);
- gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)sock);
-
- FdManager::get()->registerFd(sock, boost::bind(&ServerConnection::sendReceive, this, _1));
-
- state = CONNECT;
-
- lock.unlock();
-
- updateEvents();
-}
-
-}
-}
diff --git a/src/Net/ServerConnection.h b/src/Net/ServerConnection.h
deleted file mode 100644
index d52cd7c..0000000
--- a/src/Net/ServerConnection.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * ServerConnection.h
- *
- * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de>
- *
- * This program is free software: you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but
- * WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
- * See the GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef MAD_NET_SERVERCONNECTION_H_
-#define MAD_NET_SERVERCONNECTION_H_
-
-#include "Connection.h"
-#include <string>
-
-namespace Mad {
-namespace Net {
-
-class Listener;
-
-class ServerConnection : public Connection {
- friend class Listener;
-
- private:
- IPAddress *peer;
-
- bool daemon;
-
- gnutls_anon_server_credentials_t anoncred;
-
- void connectionHeaderReceiveHandler(const void *data, unsigned long length);
-
- protected:
- ServerConnection(int sock0, const IPAddress &address, gnutls_dh_params_t dh_params, const std::string &x905certFile, const std::string &x905keyFile);
-
- virtual void connectionHeader() {
- rawReceive(sizeof(ConnectionHeader), boost::bind(&ServerConnection::connectionHeaderReceiveHandler, this, _1, _2));
- }
-
- public:
- bool isDaemonConnection() const {return daemon;}
-};
-
-}
-}
-
-#endif /*MAD_NET_SERVERCONNECTION_H_*/
diff --git a/src/Net/ThreadManager.cpp b/src/Net/ThreadManager.cpp
index 71a754e..0fb0716 100644
--- a/src/Net/ThreadManager.cpp
+++ b/src/Net/ThreadManager.cpp
@@ -18,11 +18,13 @@
*/
#include "ThreadManager.h"
-#include "FdManager.h"
+#include "Connection.h"
#include <Common/Logger.h>
#include <Common/LogManager.h>
+#include <boost/bind.hpp>
+
#include <fcntl.h>
namespace Mad {
@@ -97,10 +99,12 @@ void ThreadManager::doInit() {
threadLock.lock();
+ ioWorker.reset(new boost::asio::io_service::work(Connection::ioService));
+
mainThreadId = boost::this_thread::get_id();
- workerThread = new boost::thread(std::mem_fun(&ThreadManager::workerFunc), this);
- loggerThread = new boost::thread(std::mem_fun(&Common::LogManager::loggerThread), Common::LogManager::get());
- ioThread = new boost::thread(std::mem_fun(&FdManager::ioThread), FdManager::get());
+ workerThread = new boost::thread(&ThreadManager::workerFunc, this);
+ loggerThread = new boost::thread(&Common::LogManager::loggerThread, Common::LogManager::get());
+ ioThread = new boost::thread((std::size_t(boost::asio::io_service::*)())&boost::asio::io_service::run, &Connection::ioService);
threadLock.unlock();
}
@@ -128,7 +132,7 @@ void ThreadManager::doDeinit() {
threads.join_all();
// IO thread is next
- FdManager::get()->stopIOThread();
+ ioWorker.reset();
ioThread->join();
delete ioThread;
diff --git a/src/Net/ThreadManager.h b/src/Net/ThreadManager.h
index fd903af..2c57747 100644
--- a/src/Net/ThreadManager.h
+++ b/src/Net/ThreadManager.h
@@ -27,7 +27,8 @@
#include <queue>
#include <set>
-#include <boost/function.hpp>
+#include <boost/asio.hpp>
+
#include <boost/thread/thread.hpp>
#include <boost/thread/condition_variable.hpp>
#include <boost/thread/locks.hpp>
@@ -50,11 +51,14 @@ class ThreadManager : public Common::Initializable {
boost::condition_variable workCond;
std::queue<boost::function0<void> > work;
+ boost::scoped_ptr<boost::asio::io_service::work> ioWorker;
+
static ThreadManager threadManager;
ThreadManager() {}
void workerFunc();
+ void ioFunc();
void threadFinished(boost::thread *thread) {
threadLock.lock();
diff --git a/src/Server/ConnectionManager.cpp b/src/Server/ConnectionManager.cpp
index 6ef918a..bbaf673 100644
--- a/src/Server/ConnectionManager.cpp
+++ b/src/Server/ConnectionManager.cpp
@@ -33,8 +33,6 @@
#include "RequestHandlers/LogRequestHandler.h"
#include "RequestHandlers/UserInfoRequestHandler.h"
#include "RequestHandlers/UserListRequestHandler.h"
-#include <Net/FdManager.h>
-#include <Net/ServerConnection.h>
#include <Net/Packet.h>
#include <Net/Listener.h>
@@ -46,49 +44,46 @@ namespace Server {
ConnectionManager ConnectionManager::connectionManager;
-bool ConnectionManager::Connection::send(const Net::Packet &packet) {
+bool ConnectionManager::ServerConnection::send(const Net::Packet &packet) {
return connection->send(packet);
}
-ConnectionManager::Connection::Connection(Net::ServerConnection *connection0)
-: connection(connection0), type(connection0->isDaemonConnection() ? DAEMON : CLIENT), hostInfo(0) {
- connection->signalReceive().connect(boost::bind(&Connection::receive, this, _1));
+ConnectionManager::ServerConnection::ServerConnection(boost::shared_ptr<Net::Connection> connection0)
+: connection(connection0), type(UNKNOWN), hostInfo(0) {
+ connection->signalReceive().connect(boost::bind(&ServerConnection::receive, this, _1));
}
-ConnectionManager::Connection::~Connection() {
- delete connection;
-}
-
-bool ConnectionManager::Connection::isConnected() const {
+bool ConnectionManager::ServerConnection::isConnected() const {
return connection->isConnected();
}
-bool ConnectionManager::Connection::disconnect() {
+bool ConnectionManager::ServerConnection::disconnect() {
connection->disconnect();
return true;
}
-void* ConnectionManager::Connection::getCertificate(size_t *size) const {
+/*void* ConnectionManager::ServerConnection::getCertificate(size_t *size) const {
const gnutls_datum_t *cert = connection->getCertificate();
*size = cert->size;
return cert->data;
}
-void* ConnectionManager::Connection::getPeerCertificate(size_t *size) const {
+void* ConnectionManager::ServerConnection::getPeerCertificate(size_t *size) const {
const gnutls_datum_t *cert = connection->getPeerCertificate();
*size = cert->size;
return cert->data;
-}
+}*/
void ConnectionManager::updateState(Common::HostInfo *hostInfo, Common::HostInfo::State state) {
hostInfo->setState(state);
- for(std::set<Connection*>::iterator con = connections.begin(); con != connections.end(); ++con) {
- if((*con)->getConnectionType() == Connection::CLIENT)
- Common::RequestManager::get()->sendRequest<Requests::DaemonStateUpdateRequest>(*con, boost::bind(&ConnectionManager::updateStateFinished, this, _1), hostInfo->getName(), state);
+
+ for(std::set<boost::shared_ptr<ServerConnection> >::iterator con = connections.begin(); con != connections.end(); ++con) {
+ if((*con)->getConnectionType() == ServerConnection::CLIENT)
+ Common::RequestManager::get()->sendRequest<Requests::DaemonStateUpdateRequest>(con->get(), boost::bind(&ConnectionManager::updateStateFinished, this, _1), hostInfo->getName(), state);
}
}
@@ -98,7 +93,7 @@ bool ConnectionManager::handleConfigEntry(const Common::ConfigEntry &entry, bool
if(entry[0].getKey().matches("Listen") && entry[1].empty()) {
try {
- listenerAddresses.push_back(Net::IPAddress(entry[0][0]));
+ listenerAddresses.push_back(boost::asio::ip::tcp::endpoint(boost::asio::ip::address::from_string(entry[0][0]), 6666));
}
catch(Net::Exception &e) {
// TODO Log error
@@ -147,8 +142,8 @@ bool ConnectionManager::handleConfigEntry(const Common::ConfigEntry &entry, bool
void ConnectionManager::configFinished() {
if(listenerAddresses.empty()) {
try {
- Net::Listener *listener = new Net::Listener(x509CertFile, x509KeyFile);
- listener->signalNewConnection().connect(boost::bind(&ConnectionManager::newConnectionHandler, this, _1));
+ boost::shared_ptr<Net::Listener> listener(new Net::Listener(x509CertFile, x509KeyFile));
+ listener->signalNewConnection().connect(boost::bind(&ConnectionManager::handleNewConnection, this, _1));
listeners.push_back(listener);
}
catch(Net::Exception &e) {
@@ -156,10 +151,10 @@ void ConnectionManager::configFinished() {
}
}
else {
- for(std::vector<Net::IPAddress>::const_iterator address = listenerAddresses.begin(); address != listenerAddresses.end(); ++address) {
+ for(std::vector<boost::asio::ip::tcp::endpoint>::const_iterator address = listenerAddresses.begin(); address != listenerAddresses.end(); ++address) {
try {
- Net::Listener *listener = new Net::Listener(x509CertFile, x509KeyFile, *address);
- listener->signalNewConnection().connect(boost::bind(&ConnectionManager::newConnectionHandler, this, _1));
+ boost::shared_ptr<Net::Listener> listener(new Net::Listener(x509CertFile, x509KeyFile, *address));
+ listener->signalNewConnection().connect(boost::bind(&ConnectionManager::handleNewConnection, this, _1));
listeners.push_back(listener);
}
catch(Net::Exception &e) {
@@ -169,28 +164,27 @@ void ConnectionManager::configFinished() {
}
}
-void ConnectionManager::newConnectionHandler(Net::ServerConnection *con) {
- Connection *connection = new Connection(con);
- con->signalDisconnected().connect(boost::bind(&ConnectionManager::disconnectHandler, this, connection));
+void ConnectionManager::handleNewConnection(boost::shared_ptr<Net::Connection> con) {
+ boost::shared_ptr<ServerConnection> connection(new ServerConnection(con));
+ con->signalDisconnected().connect(boost::bind(&ConnectionManager::handleDisconnect, this, connection));
connections.insert(connection);
- Common::RequestManager::get()->registerConnection(connection);
+ Common::RequestManager::get()->registerConnection(connection.get());
}
-void ConnectionManager::disconnectHandler(Connection *con) {
- if(con->isIdentified())
+void ConnectionManager::handleDisconnect(boost::shared_ptr<ServerConnection> con) {
+ if(con->getHostInfo())
updateState(con->getHostInfo(), Common::HostInfo::INACTIVE);
connections.erase(con);
- Common::RequestManager::get()->unregisterConnection(con);
- delete con;
+ Common::RequestManager::get()->unregisterConnection(con.get());
}
void ConnectionManager::doInit() {
Common::RequestManager::get()->setServer(true);
- Common::RequestManager::get()->registerPacketType<RequestHandlers::GSSAPIAuthRequestHandler>("AuthGSSAPI");
+ //Common::RequestManager::get()->registerPacketType<RequestHandlers::GSSAPIAuthRequestHandler>("AuthGSSAPI");
Common::RequestManager::get()->registerPacketType<RequestHandlers::DaemonCommandRequestHandler>("DaemonCommand");
Common::RequestManager::get()->registerPacketType<RequestHandlers::DaemonFSInfoRequestHandler>("DaemonFSInfo");
Common::RequestManager::get()->registerPacketType<Common::RequestHandlers::FSInfoRequestHandler>("FSInfo");
@@ -204,11 +198,9 @@ void ConnectionManager::doInit() {
}
void ConnectionManager::doDeinit() {
- for(std::set<Connection*>::iterator con = connections.begin(); con != connections.end(); ++con)
- delete *con;
+ connections.clear();
-
- Common::RequestManager::get()->unregisterPacketType("AuthGSSAPI");
+ //Common::RequestManager::get()->unregisterPacketType("AuthGSSAPI");
Common::RequestManager::get()->unregisterPacketType("DaemonCommand");
Common::RequestManager::get()->unregisterPacketType("DaemonFSInfo");
Common::RequestManager::get()->unregisterPacketType("FSInfo");
@@ -221,7 +213,7 @@ void ConnectionManager::doDeinit() {
Common::RequestManager::get()->unregisterPacketType("Log");
}
-Common::Connection* ConnectionManager::getDaemonConnection(const std::string &name) const throw (Net::Exception&) {
+boost::shared_ptr<Common::Connection> ConnectionManager::getDaemonConnection(const std::string &name) const throw (Net::Exception&) {
const Common::HostInfo *hostInfo;
try {
@@ -232,7 +224,7 @@ Common::Connection* ConnectionManager::getDaemonConnection(const std::string &na
}
if(hostInfo->getState() != Common::HostInfo::INACTIVE) {
- for(std::set<Connection*>::const_iterator it = connections.begin(); it != connections.end(); ++it) {
+ for(std::set<boost::shared_ptr<ServerConnection> >::const_iterator it = connections.begin(); it != connections.end(); ++it) {
if((*it)->getHostInfo() == hostInfo) {
return *it;
}
@@ -243,7 +235,7 @@ Common::Connection* ConnectionManager::getDaemonConnection(const std::string &na
}
std::string ConnectionManager::getDaemonName(const Common::Connection *con) const throw (Net::Exception&) {
- const Connection *connection = dynamic_cast<const Connection*>(con);
+ const ServerConnection *connection = dynamic_cast<const ServerConnection*>(con);
if(connection) {
if(connection->isIdentified()) {
@@ -257,9 +249,9 @@ std::string ConnectionManager::getDaemonName(const Common::Connection *con) cons
void ConnectionManager::identifyDaemonConnection(Common::Connection *con, const std::string &name) throw (Net::Exception&) {
// TODO Logging
- Connection *connection = dynamic_cast<Connection*>(con);
+ ServerConnection *connection = dynamic_cast<ServerConnection*>(con);
- if(!connection || (connection->getConnectionType() != Connection::DAEMON))
+ if(!connection)
throw Net::Exception(Net::Exception::INVALID_ACTION);
if(connection->isIdentified())
@@ -284,6 +276,18 @@ void ConnectionManager::identifyDaemonConnection(Common::Connection *con, const
Common::Logger::logf("Identified as '%s'.", name.c_str());
}
+void ConnectionManager::identifyClientConnection(Common::Connection *con) throw (Net::Exception&) {
+ ServerConnection *connection = dynamic_cast<ServerConnection*>(con);
+
+ if(!connection)
+ throw Net::Exception(Net::Exception::INVALID_ACTION);
+
+ if(connection->isIdentified())
+ throw Net::Exception(Net::Exception::ALREADY_IDENTIFIED);
+
+ connection->identify();
+}
+
std::vector<Common::HostInfo> ConnectionManager::getDaemonList() const {
std::vector<Common::HostInfo> ret;
diff --git a/src/Server/ConnectionManager.h b/src/Server/ConnectionManager.h
index 7d97edc..710665d 100644
--- a/src/Server/ConnectionManager.h
+++ b/src/Server/ConnectionManager.h
@@ -20,38 +20,38 @@
#ifndef MAD_SERVER_CONNECTIONMANAGER_H_
#define MAD_SERVER_CONNECTIONMANAGER_H_
-#include <list>
-#include <vector>
-#include <map>
-#include <set>
-
#include <Common/Configurable.h>
#include <Common/HostInfo.h>
#include <Common/Initializable.h>
#include <Common/RequestManager.h>
-#include <Net/IPAddress.h>
+#include <list>
+#include <vector>
+#include <map>
+#include <set>
+
+#include <boost/asio.hpp>
namespace Mad {
namespace Net {
+class Connection;
class Listener;
-class ServerConnection;
class Packet;
}
namespace Server {
-class ConnectionManager : public Common::Configurable, public Common::Initializable {
+class ConnectionManager : public Common::Configurable, public Common::Initializable, boost::noncopyable {
private:
- class Connection : public Common::Connection {
+ class ServerConnection : public Common::Connection {
public:
enum ConnectionType {
- DAEMON, CLIENT
+ UNKNOWN = 0, DAEMON, CLIENT
};
private:
- Net::ServerConnection *connection;
+ boost::shared_ptr<Net::Connection> connection;
ConnectionType type;
Common::HostInfo *hostInfo;
@@ -59,14 +59,13 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa
virtual bool send(const Net::Packet &packet);
public:
- Connection(Net::ServerConnection *connection0);
- virtual ~Connection();
+ ServerConnection(boost::shared_ptr<Net::Connection> connection0);
bool isConnected() const;
virtual bool disconnect();
- virtual void* getCertificate(size_t *size) const;
- virtual void* getPeerCertificate(size_t *size) const;
+ //virtual void* getCertificate(size_t *size) const;
+ //virtual void* getPeerCertificate(size_t *size) const;
ConnectionType getConnectionType() const {
return type;
@@ -77,10 +76,15 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa
}
bool isIdentified() const {
- return hostInfo;
+ return (type != UNKNOWN);
+ }
+
+ void identify() {
+ type = CLIENT;
}
void identify(Common::HostInfo *info) {
+ type = DAEMON;
hostInfo = info;
}
};
@@ -89,17 +93,13 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa
std::string x509TrustFile, x509CrlFile, x509CertFile, x509KeyFile;
- std::vector<Net::IPAddress> listenerAddresses;
- std::list<Net::Listener*> listeners;
+ std::vector<boost::asio::ip::tcp::endpoint> listenerAddresses;
+ std::list<boost::shared_ptr<Net::Listener> > listeners;
- std::set<Connection*> connections;
+ std::set<boost::shared_ptr<ServerConnection> > connections;
std::map<std::string,Common::HostInfo> daemonInfo;
- // Prevent shallow copy
- ConnectionManager(const ConnectionManager &o);
- ConnectionManager& operator=(const ConnectionManager &o);
-
void updateState(Common::HostInfo *hostInfo, Common::HostInfo::State state);
void updateStateFinished(const Common::Request&) {
// TODO Error handling (updateStateFinished)
@@ -107,8 +107,8 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa
ConnectionManager() {}
- void newConnectionHandler(Net::ServerConnection *con);
- void disconnectHandler(Connection *con);
+ void handleNewConnection(boost::shared_ptr<Net::Connection> con);
+ void handleDisconnect(boost::shared_ptr<ServerConnection> con);
protected:
virtual bool handleConfigEntry(const Common::ConfigEntry &entry, bool handled);
@@ -122,10 +122,12 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa
return &connectionManager;
}
- Common::Connection* getDaemonConnection(const std::string &name) const throw (Net::Exception&);
+ boost::shared_ptr<Common::Connection> getDaemonConnection(const std::string &name) const throw (Net::Exception&);
std::string getDaemonName(const Common::Connection *con) const throw (Net::Exception&);
void identifyDaemonConnection(Common::Connection *con, const std::string &name) throw (Net::Exception&);
+ void identifyClientConnection(Common::Connection *con) throw (Net::Exception&);
+
std::vector<Common::HostInfo> getDaemonList() const;
};
diff --git a/src/Server/RequestHandlers/CMakeLists.txt b/src/Server/RequestHandlers/CMakeLists.txt
index 2c6305d..e340a1e 100644
--- a/src/Server/RequestHandlers/CMakeLists.txt
+++ b/src/Server/RequestHandlers/CMakeLists.txt
@@ -3,7 +3,7 @@ include_directories(${INCLUDES})
add_library(ServerRequestHandlers
DaemonCommandRequestHandler.cpp DaemonFSInfoRequestHandler.cpp
DaemonListRequestHandler.cpp DaemonStatusRequestHandler.cpp
- GSSAPIAuthRequestHandler.cpp IdentifyRequestHandler.cpp
+ IdentifyRequestHandler.cpp
LogRequestHandler.cpp UserInfoRequestHandler.cpp
UserListRequestHandler.cpp
)
diff --git a/src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp b/src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp
index 17a7e5d..e4b511e 100644
--- a/src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp
+++ b/src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp
@@ -46,8 +46,8 @@ void DaemonCommandRequestHandler::handlePacket(const Common::XmlPacket &packet)
std::string command = packet["command"];
try {
- Common::Connection *daemonCon = ConnectionManager::get()->getDaemonConnection(packet["daemon"]);
- Common::RequestManager::get()->sendRequest<Requests::CommandRequest>(daemonCon,
+ boost::shared_ptr<Common::Connection> daemonCon = ConnectionManager::get()->getDaemonConnection(packet["daemon"]);
+ Common::RequestManager::get()->sendRequest<Requests::CommandRequest>(daemonCon.get(),
boost::bind(&DaemonCommandRequestHandler::requestFinished, this, _1), command == "reboot");
}
catch(Net::Exception &e) {
diff --git a/src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp b/src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp
index df57a94..11a3f09 100644
--- a/src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp
+++ b/src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp
@@ -44,8 +44,8 @@ void DaemonFSInfoRequestHandler::handlePacket(const Common::XmlPacket &packet) {
// TODO Require authentication
try {
- Common::Connection *daemonCon = ConnectionManager::get()->getDaemonConnection(packet["daemon"]);
- Common::RequestManager::get()->sendRequest<Common::Requests::FSInfoRequest>(daemonCon,
+ boost::shared_ptr<Common::Connection> daemonCon = ConnectionManager::get()->getDaemonConnection(packet["daemon"]);
+ Common::RequestManager::get()->sendRequest<Common::Requests::FSInfoRequest>(daemonCon.get(),
boost::bind(&DaemonFSInfoRequestHandler::requestFinished, this, _1));
}
catch(Net::Exception &e) {
diff --git a/src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp b/src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp
index 3d99c57..c1e2fcd 100644
--- a/src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp
+++ b/src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp
@@ -46,8 +46,8 @@ void DaemonStatusRequestHandler::handlePacket(const Common::XmlPacket &packet) {
std::string daemonName = packet["daemonName"];
try {
- Common::Connection *daemonCon = ConnectionManager::get()->getDaemonConnection(daemonName);
- Common::RequestManager::get()->sendRequest<Common::Requests::StatusRequest>(daemonCon,
+ boost::shared_ptr<Common::Connection> daemonCon = ConnectionManager::get()->getDaemonConnection(daemonName);
+ Common::RequestManager::get()->sendRequest<Common::Requests::StatusRequest>(daemonCon.get(),
boost::bind(&DaemonStatusRequestHandler::requestFinished, this, _1));
}
catch(Net::Exception &e) {
diff --git a/src/Server/RequestHandlers/IdentifyRequestHandler.cpp b/src/Server/RequestHandlers/IdentifyRequestHandler.cpp
index f69b3f5..b47b505 100644
--- a/src/Server/RequestHandlers/IdentifyRequestHandler.cpp
+++ b/src/Server/RequestHandlers/IdentifyRequestHandler.cpp
@@ -42,7 +42,10 @@ void IdentifyRequestHandler::handlePacket(const Common::XmlPacket &packet) {
// TODO Require authentication
try {
- ConnectionManager::get()->identifyDaemonConnection(getConnection(), packet["hostname"]);
+ if(packet["hostname"].isEmpty())
+ ConnectionManager::get()->identifyClientConnection(getConnection());
+ else
+ ConnectionManager::get()->identifyDaemonConnection(getConnection(), packet["hostname"]);
Common::XmlPacket ret;
ret.setType("OK");
diff --git a/src/mad-server.conf b/src/mad-server.conf
index bd4672c..45927cf 100644
--- a/src/mad-server.conf
+++ b/src/mad-server.conf
@@ -2,7 +2,7 @@ Logger Console
Logger File "mad-server.log"
-Listen *
+#Listen *
X509TrustFile ../Cert/ca-cert.pem
diff --git a/src/mad-server.cpp b/src/mad-server.cpp
index 6fb0e21..d55c049 100644
--- a/src/mad-server.cpp
+++ b/src/mad-server.cpp
@@ -22,18 +22,10 @@
#include "Net/ThreadManager.h"
#include "Server/ConnectionManager.h"
-#include <signal.h>
-
using namespace Mad;
int main() {
- sigset_t signals;
-
- sigemptyset(&signals);
- sigaddset(&signals, SIGPIPE);
- sigprocmask(SIG_BLOCK, &signals, 0);
-
Net::ThreadManager::get()->init();
Server::ConnectionManager::get()->init();
diff --git a/src/mad.cpp b/src/mad.cpp
index 035c0be..22a3b44 100644
--- a/src/mad.cpp
+++ b/src/mad.cpp
@@ -17,8 +17,6 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#include "Net/FdManager.h"
-#include "Net/IPAddress.h"
#include "Net/ThreadManager.h"
#include "Common/ConfigManager.h"
#include "Common/LogManager.h"
@@ -26,10 +24,10 @@
#include "Common/ModuleManager.h"
#include "Common/RequestManager.h"
#include "Common/ClientConnection.h"
+#include "Common/Requests/IdentifyRequest.h"
#include "Common/RequestHandlers/FSInfoRequestHandler.h"
#include "Common/RequestHandlers/StatusRequestHandler.h"
#include "Daemon/Backends/NetworkLogger.h"
-#include "Daemon/Requests/IdentifyRequest.h"
#include "Daemon/RequestHandlers/CommandRequestHandler.h"
#include <unistd.h>
@@ -57,7 +55,8 @@ int main() {
Common::ClientConnection *connection = new Common::ClientConnection;
try {
- connection->connect(Net::IPAddress("127.0.0.1"), true);
+ connection->connect(boost::asio::ip::tcp::endpoint(
+ boost::asio::ip::address_v4::from_string("127.0.0.1"), 6666));
while(connection->isConnecting())
usleep(100000);
@@ -71,7 +70,7 @@ int main() {
//char hostname[256];
//gethostname(hostname, sizeof(hostname));
//Common::RequestManager::get()->sendRequest<Daemon::Requests::IdentifyRequest>(connection, sigc::ptr_fun(requestFinished), hostname);
- Common::RequestManager::get()->sendRequest<Daemon::Requests::IdentifyRequest>(connection, &requestFinished, "test");
+ Common::RequestManager::get()->sendRequest<Common::Requests::IdentifyRequest>(connection, &requestFinished, "test");
while(connection->isConnected())
usleep(100000);
diff --git a/src/madc.cpp b/src/madc.cpp
index 106617c..a902144 100644
--- a/src/madc.cpp
+++ b/src/madc.cpp
@@ -17,11 +17,10 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#include "Net/FdManager.h"
-#include "Net/IPAddress.h"
#include "Net/ThreadManager.h"
#include "Common/ClientConnection.h"
#include "Common/ConfigManager.h"
+#include "Common/Requests/IdentifyRequest.h"
#include "Common/Logger.h"
#include "Common/RequestManager.h"
#include "Client/CommandParser.h"
@@ -35,37 +34,14 @@
using namespace Mad;
-static void usage(const std::string &cmd) {
- std::cerr << "Usage: " << cmd << " address[:port]" << std::endl;
-}
-
-static void handleCommand(char *cmd) {
- if(!cmd)
- Client::CommandParser::get()->requestDisconnect();
- else if(!*cmd)
- return;
- else {
- Client::CommandParser::get()->parse(cmd);
- add_history(cmd);
- }
-
- if(Client::CommandManager::get()->requestsActive()) {
- rl_callback_handler_remove();
- Net::FdManager::get()->setFdEvents(STDIN_FILENO, 0);
- }
-}
+static bool commandRunning = false;
-static void charHandler(short events) {
- if(events & POLLIN)
- rl_callback_read_char();
+static void usage(const std::string &cmd) {
+ std::cerr << "Usage: " << cmd << " address" << std::endl;
}
-static void activateReadline() {
- if(Client::CommandManager::get()->willDisconnect())
- return;
-
- rl_callback_handler_install("mad> ", handleCommand);
- Net::FdManager::get()->setFdEvents(STDIN_FILENO, POLLIN);
+static void commandFinished() {
+ commandRunning = false;
}
int main(int argc, char *argv[]) {
@@ -82,17 +58,22 @@ int main(int argc, char *argv[]) {
Common::ClientConnection *connection = new Common::ClientConnection;
try {
- connection->connect(Net::IPAddress(argv[1]));
+ connection->connect(boost::asio::ip::tcp::endpoint(boost::asio::ip::address::from_string(argv[1]), 6666));
std::cerr << "Connecting to " << argv[1] << "..." << std::flush;
while(connection->isConnecting())
usleep(100000);
- std::cerr << " connected." << std::endl;
-
Common::RequestManager::get()->registerConnection(connection);
+ commandRunning = true;
+ Common::RequestManager::get()->sendRequest<Common::Requests::IdentifyRequest>(connection, boost::bind(&commandFinished));
+ while(commandRunning) {
+ usleep(100000);
+ }
+
+ std::cerr << " connected." << std::endl;
std::cerr << "Receiving host list..." << std::flush;
Client::InformationManager::get()->updateDaemonList(connection);
@@ -103,16 +84,24 @@ int main(int argc, char *argv[]) {
std::cerr << " done." << std::endl << std::endl;
Client::CommandParser::get()->setConnection(connection);
- Client::CommandManager::get()->signalFinished().connect(&activateReadline);
-
- Net::FdManager::get()->registerFd(STDIN_FILENO,&charHandler);
-
- activateReadline();
-
- while(connection->isConnected())
- usleep(100000);
-
- Net::FdManager::get()->unregisterFd(STDIN_FILENO);
+ Client::CommandManager::get()->signalFinished().connect(&commandFinished);
+
+ while(connection->isConnected()) {
+ char *cmd = readline("mad> ");
+
+ if(!cmd) {
+ commandRunning = true;
+ Client::CommandParser::get()->requestDisconnect();
+ }
+ else if(*cmd) {
+ commandRunning = true;
+ Client::CommandParser::get()->parse(cmd);
+ add_history(cmd);
+ }
+
+ while(Client::CommandManager::get()->requestsActive())
+ usleep(100000);
+ }
Common::RequestManager::get()->unregisterConnection(connection);
}