diff options
40 files changed, 408 insertions, 1416 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index a16eff8..4f422f9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,11 +4,11 @@ project(MAD) set(CMAKE_MODULE_PATH ${MAD_SOURCE_DIR}) find_package(LibXml2 REQUIRED) -find_package(GnuTLS REQUIRED) find_package(LTDL REQUIRED) find_package(Readline REQUIRED) find_package(KRB5 REQUIRED gssapi) -find_package(Boost REQUIRED regex signals thread) +find_package(OpenSSL REQUIRED) +find_package(Boost REQUIRED regex signals system thread) configure_file(config.h.in ${MAD_BINARY_DIR}/config.h) @@ -20,6 +20,8 @@ set(INCLUDES ${LTDL_INCLUDE_DIR} ${READLINE_INCLUDE_DIR} ${KRB5_INCLUDE_DIRS} - ${Boost_INCLUDE_DIR}) + ${OPENSSL_INCLUDE_DIR} + ${Boost_INCLUDE_DIR} +) add_subdirectory(src) diff --git a/FindGnuTLS.cmake b/FindGnuTLS.cmake deleted file mode 100644 index 869444b..0000000 --- a/FindGnuTLS.cmake +++ /dev/null @@ -1,38 +0,0 @@ -INCLUDE( FindPkgConfig ) - -IF ( GNUTLS_FIND_REQUIRED ) - SET( _pkgconfig_REQUIRED "REQUIRED" ) -ELSE( GNUTLS_FIND_REQUIRED ) - SET( _pkgconfig_REQUIRED "" ) -ENDIF ( GNUTLS_FIND_REQUIRED ) - -IF ( GNUTLS_MIN_VERSION ) - PKG_SEARCH_MODULE( GNUTLS ${_pkgconfig_REQUIRED} gnutls>=${GNUTLS_MIN_VERSION} ) -ELSE ( GNUTLS_MIN_VERSION ) - PKG_SEARCH_MODULE( GNUTLS ${_pkgconfig_REQUIRED} gnutls ) -ENDIF ( GNUTLS_MIN_VERSION ) - - -IF( NOT GNUTLS_FOUND AND NOT PKG_CONFIG_FOUND ) - FIND_PATH( GNUTLS_INCLUDE_DIRS gnutls/gnutls.h ) - FIND_LIBRARY( GNUTLS_LIBRARIES gnutls) - - # Report results - IF ( GNUTLS_LIBRARIES AND GNUTLS_INCLUDE_DIRS ) - SET( GNUTLS_FOUND 1 ) - IF ( NOT GNUTLS_FIND_QUIETLY ) - MESSAGE( STATUS "Found gnutls: ${GNUTLS_LIBRARIES}" ) - ENDIF ( NOT GNUTLS_FIND_QUIETLY ) - ELSE ( GNUTLS_LIBRARIES AND GNUTLS_INCLUDE_DIRS ) - IF ( GNUTLS_FIND_REQUIRED ) - MESSAGE( SEND_ERROR "Could NOT find gnutls" ) - ELSE ( GNUTLS_FIND_REQUIRED ) - IF ( NOT GNUTLS_FIND_QUIETLY ) - MESSAGE( STATUS "Could NOT find gnutls" ) - ENDIF ( NOT GNUTLS_FIND_QUIETLY ) - ENDIF ( GNUTLS_FIND_REQUIRED ) - ENDIF ( GNUTLS_LIBRARIES AND GNUTLS_INCLUDE_DIRS ) -ENDIF( NOT GNUTLS_FOUND AND NOT PKG_CONFIG_FOUND ) - -MARK_AS_ADVANCED( GNUTLS_LIBRARIES GNUTLS_INCLUDE_DIRS ) - diff --git a/src/Common/ActionManager.cpp b/src/Common/ActionManager.cpp deleted file mode 100644 index fc3c034..0000000 --- a/src/Common/ActionManager.cpp +++ /dev/null @@ -1,87 +0,0 @@ -/* - * ActionManager.cpp - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#include "ActionManager.h" -#include <Net/FdManager.h> - -#include <fcntl.h> -#include <signal.h> - -namespace Mad { -namespace Common { - -ActionManager ActionManager::actionManager; - - -void ActionManager::doInit() { - // TODO Error handling - - pipe(notifyPipe); - - fcntl(notifyPipe[0], F_SETFL, fcntl(notifyPipe[0], F_GETFL) | O_NONBLOCK); - fcntl(notifyPipe[1], F_SETFL, fcntl(notifyPipe[1], F_GETFL) | O_NONBLOCK); - - Net::FdManager::get()->registerFd(notifyPipe[0], boost::bind(&ActionManager::run, this), POLLIN); -} - -void ActionManager::doDeinit() { - Net::FdManager::get()->unregisterFd(notifyPipe[0]); - - close(notifyPipe[0]); - close(notifyPipe[1]); -} - - -void ActionManager::run() { - // Empty pipe - char buf[16]; - while(read(notifyPipe[0], buf, sizeof(buf)) > 0) {} - - while(true) { - sigset_t set, oldset; - sigfillset(&set); - sigprocmask(SIG_SETMASK, &set, &oldset); - - if(actions.empty()) { - sigprocmask(SIG_SETMASK, &oldset, 0); - return; - } - - boost::function0<void> action = actions.front(); - actions.pop(); - - sigprocmask(SIG_SETMASK, &oldset, 0); - - action(); - } -} - -void ActionManager::add(const boost::function0<void> &action) { - sigset_t set, oldset; - sigfillset(&set); - sigprocmask(SIG_SETMASK, &set, &oldset); - - actions.push(action); - write(notifyPipe[1], "", 1); - - sigprocmask(SIG_SETMASK, &oldset, 0); -} - -} -} diff --git a/src/Common/ActionManager.h b/src/Common/ActionManager.h deleted file mode 100644 index 5d3dc15..0000000 --- a/src/Common/ActionManager.h +++ /dev/null @@ -1,61 +0,0 @@ -/* - * ActionManager.h - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef MAD_COMMON_ACTIONMANAGER_H_ -#define MAD_COMMON_ACTIONMANAGER_H_ - -#include "Initializable.h" - -#include <queue> -#include <unistd.h> - -#include <boost/function.hpp> - -namespace Mad { -namespace Common { - -class ActionManager : public Initializable { - private: - std::queue<boost::function0<void> > actions; - int notifyPipe[2]; - - static ActionManager actionManager; - - ActionManager() {} - - protected: - void doInit(); - void doDeinit(); - - public: - void run(); - void add(const boost::function0<void> &action); - - static ActionManager *get() { - if(!actionManager.isInitialized()) - actionManager.init(); - - return &actionManager; - } -}; - -} -} - -#endif /* MAD_COMMON_ACTIONMANAGER_H_ */ diff --git a/src/Common/CMakeLists.txt b/src/Common/CMakeLists.txt index e2f7b4b..fe18760 100644 --- a/src/Common/CMakeLists.txt +++ b/src/Common/CMakeLists.txt @@ -6,7 +6,7 @@ include_directories(${INCLUDES}) link_directories(${LTDL_LIBRARY_DIR}) add_library(Common - ActionManager.cpp Base64Encoder.cpp ClientConnection.cpp ConfigEntry.cpp + Base64Encoder.cpp ClientConnection.cpp ConfigEntry.cpp ConfigManager.cpp Connection.cpp Initializable.cpp Logger.cpp LogManager.cpp ModuleManager.cpp Request.cpp RequestManager.cpp SystemManager.cpp Tokenizer.cpp XmlPacket.cpp diff --git a/src/Common/ClientConnection.cpp b/src/Common/ClientConnection.cpp index 381f822..d061c8d 100644 --- a/src/Common/ClientConnection.cpp +++ b/src/Common/ClientConnection.cpp @@ -33,8 +33,8 @@ bool ClientConnection::send(const Net::Packet &packet) { return connection->send(packet); } -void ClientConnection::connect(const Net::IPAddress &address, bool daemon) throw(Net::Exception) { - connection->connect(address, daemon); +void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) throw(Net::Exception) { + connection->connect(address); } bool ClientConnection::isConnecting() const { @@ -50,7 +50,7 @@ bool ClientConnection::disconnect() { return true; } -void* ClientConnection::getCertificate(size_t *size) const { +/*void* ClientConnection::getCertificate(size_t *size) const { const gnutls_datum_t *cert = connection->getCertificate(); *size = cert->size; @@ -62,7 +62,7 @@ void* ClientConnection::getPeerCertificate(size_t *size) const { *size = cert->size; return cert->data; -} +}*/ } } diff --git a/src/Common/ClientConnection.h b/src/Common/ClientConnection.h index 09ca4db..4710bd4 100644 --- a/src/Common/ClientConnection.h +++ b/src/Common/ClientConnection.h @@ -23,11 +23,12 @@ #include "Connection.h" #include <Net/Exception.h> +#include <boost/asio.hpp> + namespace Mad { namespace Net { class ClientConnection; -class IPAddress; } namespace Common { @@ -43,14 +44,14 @@ class ClientConnection : public Connection { ClientConnection(); virtual ~ClientConnection() {} - void connect(const Net::IPAddress &address, bool daemon = false) throw(Net::Exception); + void connect(const boost::asio::ip::tcp::endpoint &address) throw(Net::Exception); bool isConnecting() const; bool isConnected() const; virtual bool disconnect(); - virtual void* getCertificate(size_t *size) const; - virtual void* getPeerCertificate(size_t *size) const; + //virtual void* getCertificate(size_t *size) const; + //virtual void* getPeerCertificate(size_t *size) const; }; } diff --git a/src/Common/Connection.h b/src/Common/Connection.h index ea90c4c..0cfc742 100644 --- a/src/Common/Connection.h +++ b/src/Common/Connection.h @@ -62,8 +62,8 @@ class Connection { } virtual bool disconnect() = 0; - virtual void* getCertificate(size_t *size) const = 0; - virtual void* getPeerCertificate(size_t *size) const = 0; + //virtual void* getCertificate(size_t *size) const = 0; + //virtual void* getPeerCertificate(size_t *size) const = 0; virtual diff --git a/src/Common/Requests/CMakeLists.txt b/src/Common/Requests/CMakeLists.txt index 1b632e2..1d49f0e 100644 --- a/src/Common/Requests/CMakeLists.txt +++ b/src/Common/Requests/CMakeLists.txt @@ -1,6 +1,6 @@ include_directories(${INCLUDES}) add_library(Requests - DisconnectRequest.cpp GSSAPIAuthRequest.cpp SimpleRequest.cpp UserInfoRequest.cpp + DisconnectRequest.cpp IdentifyRequest.cpp SimpleRequest.cpp UserInfoRequest.cpp ) target_link_libraries(Requests ${KRB5_LIBRARIES}) diff --git a/src/Daemon/Requests/IdentifyRequest.cpp b/src/Common/Requests/IdentifyRequest.cpp index dd035bf..6cf09d1 100644 --- a/src/Daemon/Requests/IdentifyRequest.cpp +++ b/src/Common/Requests/IdentifyRequest.cpp @@ -19,16 +19,16 @@ #include "IdentifyRequest.h" -#include <Common/Logger.h> - namespace Mad { -namespace Daemon { +namespace Common { namespace Requests { void IdentifyRequest::sendRequest() { Common::XmlPacket packet; packet.setType("Identify"); - packet.add("hostname", hostname); + + if(!hostname.empty()) + packet.add("hostname", hostname); sendPacket(packet); } diff --git a/src/Daemon/Requests/IdentifyRequest.h b/src/Common/Requests/IdentifyRequest.h index 51a30e3..3798909 100644 --- a/src/Daemon/Requests/IdentifyRequest.h +++ b/src/Common/Requests/IdentifyRequest.h @@ -20,11 +20,12 @@ #ifndef MAD_DAEMON_REQUESTS_IDENTIFYREQUEST_H_ #define MAD_DAEMON_REQUESTS_IDENTIFYREQUEST_H_ -#include <Common/Request.h> +#include "../Request.h" + #include <string> namespace Mad { -namespace Daemon { +namespace Common { namespace Requests { class IdentifyRequest : public Common::Request { @@ -35,7 +36,7 @@ class IdentifyRequest : public Common::Request { virtual void sendRequest(); public: - IdentifyRequest(Common::Connection *connection, uint16_t requestId, slot_type slot, const std::string &hostname0) + IdentifyRequest(Common::Connection *connection, uint16_t requestId, slot_type slot, const std::string &hostname0 = std::string()) : Common::Request(connection, requestId, slot), hostname(hostname0) {} }; diff --git a/src/Daemon/Requests/CMakeLists.txt b/src/Daemon/Requests/CMakeLists.txt index 99fbd45..74d8072 100644 --- a/src/Daemon/Requests/CMakeLists.txt +++ b/src/Daemon/Requests/CMakeLists.txt @@ -1,5 +1,5 @@ include_directories(${INCLUDES}) add_library(DaemonRequests - IdentifyRequest.cpp LogRequest.cpp + LogRequest.cpp ) diff --git a/src/Net/CMakeLists.txt b/src/Net/CMakeLists.txt index aa3f857..fae358b 100644 --- a/src/Net/CMakeLists.txt +++ b/src/Net/CMakeLists.txt @@ -2,7 +2,7 @@ include_directories(${INCLUDES}) link_directories(${Boost_LIBRARY_DIRS}) add_library(Net - ClientConnection.cpp Connection.cpp Exception.cpp FdManager.cpp IPAddress.cpp - Listener.cpp Packet.cpp ServerConnection.cpp ThreadManager.cpp + ClientConnection.cpp Connection.cpp Exception.cpp Listener.cpp + Packet.cpp ThreadManager.cpp ) -target_link_libraries(Net ${Boost_LIBRARIES} ${GNUTLS_LIBRARIES}) +target_link_libraries(Net ${Boost_LIBRARIES} ${OPENSSL_LIBRARIES}) diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index 087d95f..9cdf796 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -18,99 +18,36 @@ */ #include "ClientConnection.h" -#include "FdManager.h" -#include "IPAddress.h" -#include <boost/thread/locks.hpp> - -#include <cstring> -#include <cerrno> -#include <sys/socket.h> -#include <fcntl.h> +#include <Common/Logger.h> namespace Mad { namespace Net { -// TODO Error handling -void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) { - if(length != sizeof(ConnectionHeader)) - // Error... disconnect - return; - - const ConnectionHeader *header = (const ConnectionHeader*)(data); - - if(header->m != 'M' || header->a != 'A' || header->d != 'D') - // Error... disconnect +void ClientConnection::handleConnect(const boost::system::error_code& error) { + if(error) { + // TODO Error handling + doDisconnect(); return; + } - if(header->protVerMin != 1) - // Unsupported protocol... disconnect - return; - - enterReceiveLoop(); -} - -void ClientConnection::connectionHeader() { - ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1}; + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - rawSend((uint8_t*)&header, sizeof(header)); - rawReceive(sizeof(ConnectionHeader), boost::bind(&ClientConnection::connectionHeaderReceiveHandler, this, _1, _2)); + socket.async_handshake(boost::asio::ssl::stream_base::client, boost::bind(&ClientConnection::handleHandshake, this, boost::asio::placeholders::error)); } -void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(Exception) { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - daemon = daemon0; +void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) throw(Exception) { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); if(_isConnected()) { return; // TODO Error } - sock = socket(PF_INET, SOCK_STREAM, 0); - if(sock < 0) { - throw Exception("socket()", Exception::INTERNAL_ERRNO, errno); - } - - if(peer) - delete peer; - peer = new IPAddress(address); - - if(::connect(sock, peer->getSockAddr(), peer->getSockAddrLength()) < 0) { - close(sock); - delete peer; - peer = 0; - - throw Exception("connect()", Exception::INTERNAL_ERRNO, errno); - } - - // Set non-blocking flag - int flags = fcntl(sock, F_GETFL, 0); - - if(flags < 0) { - close(sock); - - throw Exception("fcntl()", Exception::INTERNAL_ERRNO, errno); - } - - fcntl(sock, F_SETFL, flags | O_NONBLOCK); - - // Don't linger - struct linger linger = {1, 0}; - setsockopt(sock, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)); - - gnutls_init(&session, GNUTLS_CLIENT); - gnutls_set_default_priority(session); - gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred); - gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)sock); - - FdManager::get()->registerFd(sock, boost::bind(&ClientConnection::sendReceive, this, _1)); - + peer = address; state = CONNECT; - lock.unlock(); - - updateEvents(); + socket.lowest_layer().async_connect(address, boost::bind(&ClientConnection::handleConnect, this, boost::asio::placeholders::error)); } } diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h index bdd7872..f93c2fc 100644 --- a/src/Net/ClientConnection.h +++ b/src/Net/ClientConnection.h @@ -23,24 +23,27 @@ #include "Connection.h" #include "Exception.h" +#include <boost/utility/base_from_member.hpp> + + namespace Mad { namespace Net { class IPAddress; -class ClientConnection : public Connection { +class ClientConnection : private boost::base_from_member<boost::asio::ssl::context>, public Connection { private: - bool daemon; - - void connectionHeaderReceiveHandler(const void *data, unsigned long length); - - protected: - virtual void connectionHeader(); + void handleConnect(const boost::system::error_code& error); public: - ClientConnection() : daemon(0) {} - - void connect(const IPAddress &address, bool daemon0 = false) throw(Exception); + ClientConnection() + : boost::base_from_member<boost::asio::ssl::context>(boost::ref(Connection::ioService), boost::asio::ssl::context::sslv23), + Connection(member) + { + member.set_verify_mode(boost::asio::ssl::context::verify_none); + } + + void connect(const boost::asio::ip::tcp::endpoint &address) throw(Exception); }; } diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 4e5029f..b9691cb 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -18,361 +18,225 @@ */ #include "Connection.h" -#include "FdManager.h" -#include "IPAddress.h" #include "ThreadManager.h" -#include <cstring> -#include <sys/socket.h> +#include <Common/Logger.h> +#include <cstring> #include <boost/bind.hpp> namespace Mad { namespace Net { - -Connection::StaticInit Connection::staticInit; +boost::asio::io_service Connection::ioService; Connection::~Connection() { if(_isConnected()) doDisconnect(); - - if(transR.data) - delete [] transR.data; - - while(!_sendQueueEmpty()) { - delete [] transS.front().data; - transS.pop(); - } - - gnutls_certificate_free_credentials(x509_cred); - - if(peer) - delete peer; -} - -void Connection::handshake() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(state != CONNECT) - return; - - state = HANDSHAKE; - lock.unlock(); - - doHandshake(); } -void Connection::bye() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(state != DISCONNECT) - return; - - state = BYE; - lock.unlock(); +void Connection::handleHandshake(const boost::system::error_code& error) { + if(error) { + Common::Logger::logf("Error: %s", error.message().c_str()); - doBye(); -} - -void Connection::doHandshake() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); - if(state != HANDSHAKE) + // TODO Error handling + doDisconnect(); return; + } - int ret = gnutls_handshake(session); - if(ret < 0) { - lock.unlock(); + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); + state = CONNECTED; - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { - updateEvents(); - return; - } + receiving = false; + sending = 0; - // TODO: Error - doDisconnect(); - return; + received = 0; } - state = CONNECTION_HEADER; - lock.unlock(); + ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &connectedSignal)); - connectionHeader(); + enterReceiveLoop(); } -void Connection::doBye() { - if(state != BYE) - return; - - int ret = gnutls_bye(session, GNUTLS_SHUT_RDWR); - if(ret < 0) { - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { - updateEvents(); - return; - } +void Connection::handleShutdown(const boost::system::error_code& error) { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - // TODO: Error - doDisconnect(); - return; + if(error) { + // TODO Error } - doDisconnect(); + state = DISCONNECTED; + ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &disconnectedSignal)); } void Connection::enterReceiveLoop() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - if(!_isConnected() || _isDisconnecting()) - return; - - if(_isConnecting()) - ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &connectedSignal)); + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - state = PACKET_HEADER; - lock.unlock(); + if(!_isConnected() || _isDisconnecting()) + return; + } - rawReceive(sizeof(Packet::Data), boost::bind(&Connection::packetHeaderReceiveHandler, this, _1, _2)); + rawReceive(sizeof(Packet::Data), boost::bind(&Connection::handleHeaderReceive, this, _1)); } -void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) { - if(state != PACKET_HEADER) - return; +void Connection::handleHeaderReceive(const std::vector<boost::uint8_t> &data) { + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - if(length != sizeof(Packet::Data)) { - // TODO: Error - doDisconnect(); - return; + header = *reinterpret_cast<const Packet::Data*>(data.data()); } - header = *(const Packet::Data*)data; - if(header.length == 0) { ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, Packet(ntohs(header.requestId)))); enterReceiveLoop(); } else { - state = PACKET_DATA; - rawReceive(ntohs(header.length), boost::bind(&Connection::packetDataReceiveHandler, this, _1, _2)); + rawReceive(ntohs(header.length), boost::bind(&Connection::handleDataReceive, this, _1)); } } -void Connection::packetDataReceiveHandler(const void *data, unsigned long length) { - if(state != PACKET_DATA) - return; +void Connection::handleDataReceive(const std::vector<boost::uint8_t> &data) { + { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - if(length != ntohs(header.length)) { - // TODO: Error - doDisconnect(); - return; + Packet packet(ntohs(header.requestId), data.data(), ntohs(header.length)); + ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, packet)); } - ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, Packet(ntohs(header.requestId), data, length))); - enterReceiveLoop(); } -void Connection::doReceive() { - if(!isConnected()) - return; - - boost::unique_lock<boost::mutex> lock(receiveLock); - - if(_receiveComplete()) - return; - - ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted); +void Connection::handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify) { + if(error || (bytes_transferred+received) < length) { + Common::Logger::logf(Common::Logger::VERBOSE, "Read error: %s", error.message().c_str()); - if(ret < 0) { - lock.unlock(); - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) - return; - - // TODO: Error + // TODO Error doDisconnect(); return; } - transR.transmitted += ret; - - if(_receiveComplete()) { - // Save data pointer, as transR.notify might start a new reception - uint8_t *data = transR.data; - transR.data = 0; + std::vector<boost::uint8_t> buffer; - lock.unlock(); + { + boost::shared_lock<boost::shared_mutex> lock(connectionLock); - transR.notify(data, transR.length); + if(state != CONNECTED || !receiving) + return; - delete [] data; - } - else { - lock.unlock(); + buffer.insert(buffer.end(), receiveBuffer.data(), receiveBuffer.data()+length); } - updateEvents(); -} - -bool Connection::rawReceive(unsigned long length, - const boost::function2<void,const void*,unsigned long> ¬ify) -{ - if(!isConnected()) - return false; - - boost::unique_lock<boost::mutex> lock(receiveLock); - if(!_receiveComplete()) - return false; - - transR.data = new uint8_t[length]; - transR.length = length; - transR.transmitted = 0; - transR.notify = notify; + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - lock.unlock(); + receiving = false; + received = received + bytes_transferred - length; - updateEvents(); + if(received) + std::memmove(receiveBuffer.data(), receiveBuffer.data()+length, received); + } - return true; + notify(buffer); } -void Connection::doSend() { - if(!isConnected()) - return; +void Connection::rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify) { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - boost::unique_lock<boost::mutex> lock(sendLock); - while(!_sendQueueEmpty()) { - ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted, - transS.front().length-transS.front().transmitted); - - if(ret < 0) { - lock.unlock(); + if(!_isConnected()) + return; - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) - return; + { + boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock); - // TODO: Error - doDisconnect(); + if(receiving) return; - } - transS.front().transmitted += ret; + receiving = true; - if(transS.front().transmitted == transS.front().length) { - delete [] transS.front().data; - transS.pop(); + if(length > received) { + boost::asio::async_read(socket, boost::asio::buffer(receiveBuffer.data()+received, receiveBuffer.size()-received), boost::asio::transfer_at_least(length), + boost::bind(&Connection::handleRead, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred, + length, notify)); + + return; } } lock.unlock(); - updateEvents(); + handleRead(boost::system::error_code(), 0, length, notify); } -bool Connection::rawSend(const uint8_t *data, unsigned long length) { - if(!isConnected()) - return false; +void Connection::handleWrite(const boost::system::error_code& error, std::size_t) { + { + boost::unique_lock<boost::shared_mutex> lock(connectionLock); - Transmission trans = {length, 0, new uint8_t[length], boost::function2<void,const void*,unsigned long>()}; - std::memcpy(trans.data, data, length); + sending--; - sendLock.lock(); - transS.push(trans); - sendLock.unlock(); - - updateEvents(); + if(state == DISCONNECT && !sending) { + lock.unlock(); + doDisconnect(); + return; + } + } - return true; -} + if(error) { + Common::Logger::logf(Common::Logger::VERBOSE, "Write error: %s", error.message().c_str()); -void Connection::sendReceive(short events) { - if(events & POLLHUP || events & POLLERR) { + // TODO Error doDisconnect(); - return; } +} - switch(state) { - case CONNECT: - handshake(); - return; - case HANDSHAKE: - doHandshake(); - return; - case DISCONNECT: - if(!_sendQueueEmpty()) - break; +void Connection::rawSend(const uint8_t *data, std::size_t length) { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - bye(); - return; - case BYE: - doBye(); - return; - default: - break; - } + if(!_isConnected()) + return; - if(events & POLLIN) - doReceive(); + { + boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock); - if(events & POLLOUT) - doSend(); + sending++; + boost::asio::async_write(socket, Buffer(data, length), boost::bind(&Connection::handleWrite, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + } } bool Connection::send(const Packet &packet) { - stateLock.lock_shared(); - bool err = (!_isConnected() || _isConnecting() || _isDisconnecting()); - stateLock.unlock_shared(); - - if(err) - return false; + { + boost::shared_lock<boost::shared_mutex> lock(connectionLock); + if(!_isConnected() || _isConnecting() || _isDisconnecting()) + return false; + } - return rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength()); + rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength()); + return true; } void Connection::disconnect() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(!_isConnected() || _isDisconnecting()) - return; + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); + if(!_isConnected() || _isDisconnecting()) + return; - state = DISCONNECT; + state = DISCONNECT; - lock.unlock(); + if(sending) + return; + } - updateEvents(); + doDisconnect(); } void Connection::doDisconnect() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - if(_isConnected()) { - FdManager::get()->unregisterFd(sock); - - shutdown(sock, SHUT_RDWR); - close(sock); + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - gnutls_deinit(session); - - ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &disconnectedSignal)); - - state = DISCONNECTED; - } -} - -void Connection::updateEvents() { - receiveLock.lock(); - short events = (_receiveComplete() ? 0 : POLLIN); - receiveLock.unlock(); - - sendLock.lock(); - events |= (_sendQueueEmpty() ? 0 : POLLOUT); - sendLock.unlock(); - - stateLock.lock_shared(); - if(state == HANDSHAKE || state == BYE) - events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT); - else if(state == CONNECT || state == DISCONNECT) - events |= POLLOUT; - - FdManager::get()->setFdEvents(sock, events); - stateLock.unlock_shared(); + if(_isConnected()) + socket.async_shutdown(boost::bind(&Connection::handleShutdown, this, boost::asio::placeholders::error)); } } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index a0b95ea..303485d 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -24,155 +24,111 @@ #include "Packet.h" -#include <queue> -#include <string> -#include <gnutls/gnutls.h> -#include <poll.h> +#include <boost/asio.hpp> +#include <boost/asio/ssl.hpp> #include <boost/signal.hpp> -#include <boost/thread/mutex.hpp> -#include <boost/thread/shared_mutex.hpp> -#include <iostream> +#include <boost/thread/shared_mutex.hpp> namespace Mad { namespace Net { -class IPAddress; -class Packet; +class ThreadManager; -class Connection { +class Connection : boost::noncopyable { private: - class StaticInit { - public: - StaticInit() { - gnutls_global_init(); - } + friend class ThreadManager; - ~StaticInit() { - gnutls_global_deinit(); - } - }; - static StaticInit staticInit; + class Buffer { + public: + Buffer(const uint8_t *data0, std::size_t length) : data(new std::vector<uint8_t>(data0, data0+length)), buffer(boost::asio::buffer(*data)) {} - struct Transmission { - unsigned long length; - unsigned long transmitted; + typedef boost::asio::const_buffer value_type; + typedef const boost::asio::const_buffer* const_iterator; - uint8_t *data; + const boost::asio::const_buffer* begin() const { return &buffer; } + const boost::asio::const_buffer* end() const { return &buffer + 1; } - boost::function2<void,const void*,unsigned long> notify; + private: + boost::shared_ptr<std::vector<uint8_t> > data; + boost::asio::const_buffer buffer; }; - boost::mutex receiveLock; - Transmission transR; - boost::mutex sendLock; - std::queue<Transmission> transS; + std::vector<boost::uint8_t> receiveBuffer; + std::size_t received; Packet::Data header; - boost::signal1<void,const Packet&> receiveSignal; + boost::signal1<void, const Packet&> receiveSignal; boost::signal0<void> connectedSignal; boost::signal0<void> disconnectedSignal; - void doHandshake(); - - void packetHeaderReceiveHandler(const void *data, unsigned long length); - void packetDataReceiveHandler(const void *data, unsigned long length); + bool receiving; + unsigned long sending; - void doReceive(); - void doSend(); - - void doBye(); - - void doDisconnect(); + void enterReceiveLoop(); - bool _receiveComplete() const { - return (transR.length == transR.transmitted); - } + void handleHeaderReceive(const std::vector<boost::uint8_t> &data); + void handleDataReceive(const std::vector<boost::uint8_t> &data); - bool _sendQueueEmpty() const {return transS.empty();} + void handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify); + void handleWrite(const boost::system::error_code& error, std::size_t); - void bye(); + void handleShutdown(const boost::system::error_code& error); - // Prevent shallow copy - Connection(const Connection &o); - Connection& operator=(const Connection &o); + void rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify); + void rawSend(const uint8_t *data, std::size_t length); protected: - struct ConnectionHeader { - uint8_t m; - uint8_t a; - uint8_t d; - uint8_t type; - - uint8_t versionMajor; - uint8_t versionMinor; - uint8_t protVerMin; - uint8_t protVerMax; - }; + static boost::asio::io_service ioService; - boost::shared_mutex stateLock; + boost::shared_mutex connectionLock; enum State { - DISCONNECTED, CONNECT, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE + DISCONNECTED, CONNECT, CONNECTED, DISCONNECT } state; - int sock; - gnutls_session_t session; - gnutls_certificate_credentials_t x509_cred; - IPAddress *peer; + boost::asio::ssl::stream<boost::asio::ip::tcp::socket> socket; + boost::asio::ip::tcp::endpoint peer; - void handshake(); - - virtual void connectionHeader() = 0; - - bool rawReceive(unsigned long length, const boost::function2<void,const void*,unsigned long> ¬ify); - bool rawSend(const uint8_t *data, unsigned long length); - - void enterReceiveLoop(); - - void sendReceive(short events); + void handleHandshake(const boost::system::error_code& error); bool _isConnected() const {return (state != DISCONNECTED);} bool _isConnecting() const { - return (state == CONNECT || state == HANDSHAKE || state == CONNECTION_HEADER); + return (state == CONNECT); } bool _isDisconnecting() const { - return (state == DISCONNECT || state == BYE); + return (state == DISCONNECT); } - void updateEvents(); - - public: - Connection() : state(DISCONNECTED), peer(0) { - transR.length = transR.transmitted = 0; - transR.data = 0; + void doDisconnect(); - gnutls_certificate_allocate_credentials(&x509_cred); - } + Connection(boost::asio::ssl::context &sslContext) : + receiveBuffer(1024*1024), state(DISCONNECTED), socket(ioService, sslContext) {} + public: virtual ~Connection(); bool isConnected() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); + boost::shared_lock<boost::shared_mutex> lock(connectionLock); return _isConnected(); } bool isConnecting() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); + boost::shared_lock<boost::shared_mutex> lock(connectionLock); return _isConnecting(); } bool isDisconnecting() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); + boost::shared_lock<boost::shared_mutex> lock(connectionLock); return _isDisconnecting(); } - const gnutls_datum_t* getCertificate() const { + /*const gnutls_datum_t* getCertificate() const { // TODO Thread-safeness return gnutls_certificate_get_ours(session); } @@ -181,16 +137,18 @@ class Connection { // TODO Thread-safeness unsigned int n; return gnutls_certificate_get_peers(session, &n); - } + }*/ - // TODO Thread-safeness - const IPAddress* getPeer() const {return peer;} + boost::asio::ip::tcp::endpoint getPeer() { + boost::shared_lock<boost::shared_mutex> lock(connectionLock); + return peer; + } void disconnect(); bool send(const Packet &packet); - boost::signal1<void,const Packet&>& signalReceive() {return receiveSignal;} + boost::signal1<void, const Packet&>& signalReceive() {return receiveSignal;} boost::signal0<void>& signalConnected() {return connectedSignal;} boost::signal0<void>& signalDisconnected() {return disconnectedSignal;} }; diff --git a/src/Net/Exception.cpp b/src/Net/Exception.cpp index 34b8033..e082948 100644 --- a/src/Net/Exception.cpp +++ b/src/Net/Exception.cpp @@ -20,7 +20,6 @@ #include "Exception.h" #include <cstring> -#include <gnutls/gnutls.h> namespace Mad { namespace Net { @@ -46,8 +45,6 @@ std::string Exception::strerror() const { return ret + "Not implemented"; case INTERNAL_ERRNO: return ret + std::strerror(subCode); - case INTERNAL_GNUTLS: - return ret + "GnuTLS error: " + gnutls_strerror(subCode); case INVALID_ADDRESS: return ret + "Invalid address"; case ALREADY_IDENTIFIED: diff --git a/src/Net/Exception.h b/src/Net/Exception.h index 48e86d1..8522528 100644 --- a/src/Net/Exception.h +++ b/src/Net/Exception.h @@ -29,7 +29,7 @@ class Exception { public: enum ErrorCode { SUCCESS = 0x0000, UNEXPECTED_PACKET = 0x0001, INVALID_ACTION = 0x0002, NOT_AVAILABLE = 0x0003, NOT_FINISHED = 0x0004, NOT_IMPLEMENTED = 0x0005, - INTERNAL_ERRNO = 0x0010, INTERNAL_GNUTLS = 0x0011, + INTERNAL_ERRNO = 0x0010, INVALID_ADDRESS = 0x0020, ALREADY_IDENTIFIED = 0x0030, UNKNOWN_DAEMON = 0x0031 }; diff --git a/src/Net/FdManager.cpp b/src/Net/FdManager.cpp deleted file mode 100644 index d8faef4..0000000 --- a/src/Net/FdManager.cpp +++ /dev/null @@ -1,174 +0,0 @@ -/* - * FdManager.cpp - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#include "FdManager.h" -#include "ThreadManager.h" - -#include <signal.h> -#include <unistd.h> -#include <sys/fcntl.h> - - -namespace Mad { -namespace Net { - -FdManager FdManager::fdManager; - - -FdManager::FdManager() : running(false) { - pipe(interruptPipe); - - int flags = fcntl(interruptPipe[0], F_GETFL, 0); - fcntl(interruptPipe[0], F_SETFL, flags | O_NONBLOCK); - - flags = fcntl(interruptPipe[1], F_GETFL, 0); - fcntl(interruptPipe[1], F_SETFL, flags | O_NONBLOCK); - - registerFd(interruptPipe[0], boost::bind(&FdManager::readInterrupt, this), POLLIN); -} - -FdManager::~FdManager() { - unregisterFd(interruptPipe[0]); - - close(interruptPipe[0]); - close(interruptPipe[1]); -} - - -bool FdManager::registerFd(int fd, const boost::function1<void, short> &handler, short events) { - struct pollfd pollfd = {fd, events, 0}; - - boost::lock(handlerLock, eventLock); - pollfds.insert(std::make_pair(fd, pollfd)); - - bool ret = handlers.insert(std::make_pair(fd, handler)).second; - - eventLock.unlock(); - handlerLock.unlock(); - - interrupt(); - - return ret; -} - -bool FdManager::unregisterFd(int fd) { - boost::lock(handlerLock, eventLock); - pollfds.erase(fd); - bool ret = handlers.erase(fd); - eventLock.unlock(); - handlerLock.unlock(); - - interrupt(); - - return ret; -} - -bool FdManager::setFdEvents(int fd, short events) { - boost::unique_lock<boost::shared_mutex> lock(eventLock); - - std::map<int, struct pollfd>::iterator pollfd = pollfds.find(fd); - - if(pollfd == pollfds.end()) - return false; - - if(pollfd->second.events != events) { - pollfd->second.events = events; - interrupt(); - } - - return true; -} - -short FdManager::getFdEvents(int fd) { - boost::shared_lock<boost::shared_mutex> lock(eventLock); - - std::map<int, struct pollfd>::const_iterator pollfd = pollfds.find(fd); - - if(pollfd == pollfds.end()) - return -1; - - return pollfd->second.events; -} - -void FdManager::readInterrupt() { - char buf[20]; - - while(read(interruptPipe[0], buf, sizeof(buf)) > 0) {} -} - -void FdManager::interrupt() { - char buf = 0; - - write(interruptPipe[1], &buf, sizeof(buf)); -} - -void FdManager::ioThread() { - runLock.lock(); - running = true; - runLock.unlock_and_lock_shared(); - - while(running) { - runLock.unlock_shared(); - - handlerLock.lock_shared(); - eventLock.lock_shared(); - readInterrupt(); - - size_t count = pollfds.size(); - struct pollfd *fdarray = new struct pollfd[count]; - - std::map<int, struct pollfd>::iterator pollfd = pollfds.begin(); - - for(size_t n = 0; n < count; ++n) { - fdarray[n] = pollfd->second; - ++pollfd; - } - - eventLock.unlock_shared(); - handlerLock.unlock_shared(); - - if(poll(fdarray, count, -1) > 0) { - handlerLock.lock_shared(); - - std::queue<boost::function0<void> > calls; - - for(size_t n = 0; n < count; ++n) { - if(fdarray[n].revents) - calls.push(boost::bind(handlers[fdarray[n].fd], fdarray[n].revents)); - } - - handlerLock.unlock_shared(); - - while(!calls.empty()) { - calls.front()(); - calls.pop(); - } - - } - - delete [] fdarray; - - runLock.lock_shared(); - } - - runLock.unlock_shared(); -} - -} -} diff --git a/src/Net/FdManager.h b/src/Net/FdManager.h deleted file mode 100644 index 1cb95bc..0000000 --- a/src/Net/FdManager.h +++ /dev/null @@ -1,77 +0,0 @@ -/* - * FdManager.h - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef MAD_NET_FDMANAGER_H_ -#define MAD_NET_FDMANAGER_H_ - -#include <map> -#include <poll.h> - -#include <boost/function.hpp> -#include <boost/thread/shared_mutex.hpp> - -namespace Mad { -namespace Net { - -class ThreadManager; - -class FdManager { - private: - friend class ThreadManager; - - static FdManager fdManager; - - boost::shared_mutex runLock, handlerLock, eventLock; - bool running; - - std::map<int, struct pollfd> pollfds; - std::map<int, boost::function1<void, short> > handlers; - - int interruptPipe[2]; - - void readInterrupt(); - void interrupt(); - - FdManager(); - - void ioThread(); - void stopIOThread() { - runLock.lock(); - running = false; - runLock.unlock(); - - interrupt(); - } - - public: - virtual ~FdManager(); - - static FdManager *get() {return &fdManager;} - - bool registerFd(int fd, const boost::function1<void, short> &handler, short events = 0); - bool unregisterFd(int fd); - - bool setFdEvents(int fd, short events); - short getFdEvents(int fd); -}; - -} -} - -#endif /* MAD_NET_FDMANAGER_H_ */ diff --git a/src/Net/IPAddress.cpp b/src/Net/IPAddress.cpp deleted file mode 100644 index eb9d3be..0000000 --- a/src/Net/IPAddress.cpp +++ /dev/null @@ -1,92 +0,0 @@ -/* - * IPAddress.cpp - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#include "IPAddress.h" - -#include <cstdlib> - -namespace Mad { -namespace Net { - -IPAddress::IPAddress(uint16_t port0) : addr(INADDR_ANY), port(port0) { - sa.sin_family = AF_INET; - sa.sin_port = htons(port); - sa.sin_addr.s_addr = INADDR_ANY; -} - -IPAddress::IPAddress(uint32_t address, uint16_t port0) : addr(address), port(port0) { - sa.sin_family = AF_INET; - sa.sin_port = htons(port); - sa.sin_addr.s_addr = htonl(addr); -} - -IPAddress::IPAddress(const std::string &address) throw(Exception) { - std::string ip; - size_t pos = address.find_first_of(':'); - - if(pos == std::string::npos) { - ip = address; - // TODO Default port - port = 6666; - } - else { - ip = address.substr(0, pos); - - char *endptr; - port = std::strtol(address.substr(pos+1).c_str(), &endptr, 10); - if(*endptr != 0 || port == 0) - throw Exception(Exception::INVALID_ADDRESS); - } - - sa.sin_family = AF_INET; - sa.sin_port = htons(port); - - if(ip == "*") - sa.sin_addr.s_addr = INADDR_ANY; - else if(!inet_pton(AF_INET, ip.c_str(), &sa.sin_addr)) - throw Exception(Exception::INVALID_ADDRESS); - - addr = ntohl(sa.sin_addr.s_addr); -} - -IPAddress::IPAddress(const std::string &address, uint16_t port0) throw(Exception) : port(port0) { - sa.sin_family = AF_INET; - sa.sin_port = htons(port); - - if(!inet_pton(AF_INET, address.c_str(), &sa.sin_addr)) - throw Exception(Exception::INVALID_ADDRESS); - - addr = ntohl(sa.sin_addr.s_addr); -} - -IPAddress::IPAddress(const struct sockaddr_in &address) : sa(address) { - port = ntohs(sa.sin_port); - addr = ntohl(sa.sin_addr.s_addr); -} - -std::string IPAddress::getAddressString() const { - char buf[INET_ADDRSTRLEN]; - uint32_t address = htonl(addr); - - inet_ntop(AF_INET, &address, buf, sizeof(buf)); - return std::string(buf); -} - -} -} diff --git a/src/Net/IPAddress.h b/src/Net/IPAddress.h deleted file mode 100644 index 3541891..0000000 --- a/src/Net/IPAddress.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * IPAddress.h - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef MAD_NET_IPADDRESS_H_ -#define MAD_NET_IPADDRESS_H_ - -#include "Exception.h" - -#include <string> -#include <arpa/inet.h> -#include <stdint.h> - -namespace Mad { -namespace Net { - -class IPAddress { - private: - uint32_t addr; - uint16_t port; - struct sockaddr_in sa; - - public: - // TODO Default port - IPAddress(uint16_t port0 = 6666); - IPAddress(uint32_t address, uint16_t port0); - IPAddress(const std::string &address) throw(Exception); - IPAddress(const std::string &address, uint16_t port0) throw(Exception); - IPAddress(const struct sockaddr_in &address); - - uint32_t getAddress() const {return addr;} - uint16_t getPort() const {return port;} - - std::string getAddressString() const; - - struct sockaddr* getSockAddr() {return (struct sockaddr*)&sa;} - socklen_t getSockAddrLength() const {return sizeof(sa);} -}; - -} -} - -#endif /*MAD_NET_IPADDRESS_H_*/ diff --git a/src/Net/Listener.cpp b/src/Net/Listener.cpp index 11cbaf5..6f49a74 100644 --- a/src/Net/Listener.cpp +++ b/src/Net/Listener.cpp @@ -18,8 +18,6 @@ */ #include "Listener.h" -#include "FdManager.h" -#include "ServerConnection.h" #include <Common/Logger.h> @@ -30,26 +28,29 @@ namespace Mad { namespace Net { -void Listener::acceptHandler(int) { - int sd; - struct sockaddr_in sa; - socklen_t addrlen = sizeof(sa); +void Listener::handleAccept(const boost::system::error_code &error, boost::shared_ptr<ServerConnection> con) { + if(error) + return; + + { + boost::lock_guard<boost::shared_mutex> lock(con->connectionLock); + con->state = ServerConnection::CONNECT; - while((sd = accept(sock, (struct sockaddr*)&sa, &addrlen)) >= 0) { - ServerConnection *con = new ServerConnection(sd, IPAddress(sa), dh_params, x905CertFile, x905KeyFile); - boost::signals::connection con1 = con->signalConnected().connect(boost::bind(&Listener::connectHandler, this, con)); - boost::signals::connection con2 = con->signalDisconnected().connect(boost::bind(&Listener::disconnectHandler, this, con)); + boost::signals::connection con1 = con->signalConnected().connect(boost::bind(&Listener::handleConnect, this, con)); + boost::signals::connection con2 = con->signalDisconnected().connect(boost::bind(&Listener::handleDisconnect, this, con)); connections.insert(std::make_pair(con, std::make_pair(con1, con2))); - addrlen = sizeof(sa); + con->socket.async_handshake(boost::asio::ssl::stream_base::server, boost::bind(&ServerConnection::handleHandshake, con, boost::asio::placeholders::error)); } -} + con.reset(new ServerConnection(sslContext)); + acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); +} -void Listener::connectHandler(ServerConnection *con) { - std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con); +void Listener::handleConnect(boost::shared_ptr<ServerConnection> con) { + std::map<boost::shared_ptr<ServerConnection>, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con); if(it == connections.end()) return; @@ -62,67 +63,33 @@ void Listener::connectHandler(ServerConnection *con) { signal(con); } -void Listener::disconnectHandler(ServerConnection *con) { - std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con); - - if(it == connections.end()) - return; - - delete it->first; - connections.erase(it); +void Listener::handleDisconnect(boost::shared_ptr<ServerConnection> con) { + connections.erase(con); } -Listener::Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, const IPAddress &address0) throw(Exception) -: x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0) { - gnutls_dh_params_init(&dh_params); - gnutls_dh_params_generate2(dh_params, 768); +Listener::Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, + const boost::asio::ip::tcp::endpoint &address0) throw(Exception) +: x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0), acceptor(Connection::ioService, address), +sslContext(Connection::ioService, boost::asio::ssl::context::sslv23) +{ + sslContext.set_options(boost::asio::ssl::context::default_workarounds + | boost::asio::ssl::context::no_sslv2 + | boost::asio::ssl::context::single_dh_use); + sslContext.use_certificate_chain_file(x905CertFile0); + sslContext.use_private_key_file(x905KeyFile0, boost::asio::ssl::context::pem); - sock = socket(PF_INET, SOCK_STREAM, 0); - if(sock < 0) - throw Exception("socket()", Exception::INTERNAL_ERRNO, errno); - // Set non-blocking flag - int flags = fcntl(sock, F_GETFL, 0); - - if(flags < 0) { - close(sock); - - throw Exception("fcntl()", Exception::INTERNAL_ERRNO, errno); - } - - fcntl(sock, F_SETFL, flags | O_NONBLOCK); - - // Don't linger - struct linger linger = {1, 0}; - setsockopt(sock, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)); - - if(bind(sock, address.getSockAddr(), address.getSockAddrLength()) < 0) { - close(sock); - - throw Exception("bind()", Exception::INTERNAL_ERRNO, errno); - } - - if(listen(sock, 64) < 0) { - close(sock); - - throw Exception("listen()", Exception::INTERNAL_ERRNO, errno); - } - - FdManager::get()->registerFd(sock, boost::bind(&Listener::acceptHandler, this, _1), POLLIN); + boost::shared_ptr<ServerConnection> con(new ServerConnection(sslContext)); + acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); } Listener::~Listener() { - for(std::map<ServerConnection*,std::pair<boost::signals::connection, boost::signals::connection> >::iterator con = connections.begin(); con != connections.end(); ++con) { + for(std::map<boost::shared_ptr<ServerConnection>,std::pair<boost::signals::connection, boost::signals::connection> >::iterator con = connections.begin(); con != connections.end(); ++con) { con->first->disconnect(); - delete con->first; + // TODO wait... } - - shutdown(sock, SHUT_RDWR); - close(sock); - - gnutls_dh_params_deinit(dh_params); } } diff --git a/src/Net/Listener.h b/src/Net/Listener.h index 26dffab..0833cdf 100644 --- a/src/Net/Listener.h +++ b/src/Net/Listener.h @@ -20,46 +20,45 @@ #ifndef MAD_NET_LISTENER_H_ #define MAD_NET_LISTENER_H_ -#include "IPAddress.h" - -#include <gnutls/gnutls.h> #include <map> #include <string> -#include <boost/signal.hpp> +#include "Connection.h" +#include "Exception.h" namespace Mad { namespace Net { -class ServerConnection; - // TODO XXX Thread-safeness XXX -class Listener { +class Listener : boost::noncopyable { private: - std::string x905CertFile, x905KeyFile; - IPAddress address; - int sock; + class ServerConnection : public Connection { + public: + friend class Listener; - gnutls_dh_params_t dh_params; + ServerConnection(boost::asio::ssl::context &sslContext) : Connection(sslContext) {} + }; - std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> > connections; + std::string x905CertFile, x905KeyFile; + boost::asio::ip::tcp::endpoint address; + boost::asio::ip::tcp::acceptor acceptor; + boost::asio::ssl::context sslContext; - boost::signal1<void, ServerConnection*> signal; + std::map<boost::shared_ptr<ServerConnection>, std::pair<boost::signals::connection, boost::signals::connection> > connections; - void acceptHandler(int); + boost::signal1<void, boost::shared_ptr<Connection> > signal; - void connectHandler(ServerConnection *con); - void disconnectHandler(ServerConnection *con); + void handleAccept(const boost::system::error_code &error, boost::shared_ptr<ServerConnection> con); - // Prevent shallow copy - Listener(const Listener &o); - Listener& operator=(const Listener &o); + void handleConnect(boost::shared_ptr<ServerConnection> con); + void handleDisconnect(boost::shared_ptr<ServerConnection> con); public: - Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, const IPAddress &address0 = IPAddress()) throw(Exception); + Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, + const boost::asio::ip::tcp::endpoint &address0 = boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 6666)) throw(Exception); virtual ~Listener(); - boost::signal1<void, ServerConnection*>& signalNewConnection() {return signal;} + boost::signal1<void, boost::shared_ptr<Connection> >& signalNewConnection() {return signal;} }; } diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp deleted file mode 100644 index 1f01ce5..0000000 --- a/src/Net/ServerConnection.cpp +++ /dev/null @@ -1,90 +0,0 @@ -/* - * ServerConnection.cpp - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#include "ServerConnection.h" -#include "FdManager.h" -#include "IPAddress.h" - -#include <boost/thread/locks.hpp> - -#include <cstring> -#include <cerrno> -#include <sys/socket.h> -#include <fcntl.h> - -namespace Mad { -namespace Net { - -void ServerConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) { - if(length != sizeof(ConnectionHeader)) - // Error... disconnect - return; - - const ConnectionHeader *header = (const ConnectionHeader*)data; - - if(header->m != 'M' || header->a != 'A' || header->d != 'D') - // Error... disconnect - return; - - if(header->protVerMin > 1 || header->protVerMax < 1) - // Unsupported protocol... disconnect - return; - - if(header->type == 'C') - daemon = false; - else if(header->type == 'D') - daemon = true; - else - // Error... disconnect - return; - - ConnectionHeader header2 = {'M', 'A', 'D', 0, 0, 1, 1, 0}; - - enterReceiveLoop(); - - rawSend((uint8_t*)&header2, sizeof(header2)); -} - -ServerConnection::ServerConnection(int sock0, const IPAddress &address, gnutls_dh_params_t dh_params, const std::string &x905CertFile, const std::string &x905KeyFile) -: daemon(false) { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - sock = sock0; - - peer = new IPAddress(address); - - gnutls_certificate_set_dh_params(x509_cred, dh_params); - gnutls_certificate_set_x509_key_file(x509_cred, x905CertFile.c_str(), x905KeyFile.c_str(), GNUTLS_X509_FMT_PEM); - - gnutls_init(&session, GNUTLS_SERVER); - gnutls_set_default_priority(session); - gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred); - gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)sock); - - FdManager::get()->registerFd(sock, boost::bind(&ServerConnection::sendReceive, this, _1)); - - state = CONNECT; - - lock.unlock(); - - updateEvents(); -} - -} -} diff --git a/src/Net/ServerConnection.h b/src/Net/ServerConnection.h deleted file mode 100644 index d52cd7c..0000000 --- a/src/Net/ServerConnection.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * ServerConnection.h - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef MAD_NET_SERVERCONNECTION_H_ -#define MAD_NET_SERVERCONNECTION_H_ - -#include "Connection.h" -#include <string> - -namespace Mad { -namespace Net { - -class Listener; - -class ServerConnection : public Connection { - friend class Listener; - - private: - IPAddress *peer; - - bool daemon; - - gnutls_anon_server_credentials_t anoncred; - - void connectionHeaderReceiveHandler(const void *data, unsigned long length); - - protected: - ServerConnection(int sock0, const IPAddress &address, gnutls_dh_params_t dh_params, const std::string &x905certFile, const std::string &x905keyFile); - - virtual void connectionHeader() { - rawReceive(sizeof(ConnectionHeader), boost::bind(&ServerConnection::connectionHeaderReceiveHandler, this, _1, _2)); - } - - public: - bool isDaemonConnection() const {return daemon;} -}; - -} -} - -#endif /*MAD_NET_SERVERCONNECTION_H_*/ diff --git a/src/Net/ThreadManager.cpp b/src/Net/ThreadManager.cpp index 71a754e..0fb0716 100644 --- a/src/Net/ThreadManager.cpp +++ b/src/Net/ThreadManager.cpp @@ -18,11 +18,13 @@ */ #include "ThreadManager.h" -#include "FdManager.h" +#include "Connection.h" #include <Common/Logger.h> #include <Common/LogManager.h> +#include <boost/bind.hpp> + #include <fcntl.h> namespace Mad { @@ -97,10 +99,12 @@ void ThreadManager::doInit() { threadLock.lock(); + ioWorker.reset(new boost::asio::io_service::work(Connection::ioService)); + mainThreadId = boost::this_thread::get_id(); - workerThread = new boost::thread(std::mem_fun(&ThreadManager::workerFunc), this); - loggerThread = new boost::thread(std::mem_fun(&Common::LogManager::loggerThread), Common::LogManager::get()); - ioThread = new boost::thread(std::mem_fun(&FdManager::ioThread), FdManager::get()); + workerThread = new boost::thread(&ThreadManager::workerFunc, this); + loggerThread = new boost::thread(&Common::LogManager::loggerThread, Common::LogManager::get()); + ioThread = new boost::thread((std::size_t(boost::asio::io_service::*)())&boost::asio::io_service::run, &Connection::ioService); threadLock.unlock(); } @@ -128,7 +132,7 @@ void ThreadManager::doDeinit() { threads.join_all(); // IO thread is next - FdManager::get()->stopIOThread(); + ioWorker.reset(); ioThread->join(); delete ioThread; diff --git a/src/Net/ThreadManager.h b/src/Net/ThreadManager.h index fd903af..2c57747 100644 --- a/src/Net/ThreadManager.h +++ b/src/Net/ThreadManager.h @@ -27,7 +27,8 @@ #include <queue> #include <set> -#include <boost/function.hpp> +#include <boost/asio.hpp> + #include <boost/thread/thread.hpp> #include <boost/thread/condition_variable.hpp> #include <boost/thread/locks.hpp> @@ -50,11 +51,14 @@ class ThreadManager : public Common::Initializable { boost::condition_variable workCond; std::queue<boost::function0<void> > work; + boost::scoped_ptr<boost::asio::io_service::work> ioWorker; + static ThreadManager threadManager; ThreadManager() {} void workerFunc(); + void ioFunc(); void threadFinished(boost::thread *thread) { threadLock.lock(); diff --git a/src/Server/ConnectionManager.cpp b/src/Server/ConnectionManager.cpp index 6ef918a..bbaf673 100644 --- a/src/Server/ConnectionManager.cpp +++ b/src/Server/ConnectionManager.cpp @@ -33,8 +33,6 @@ #include "RequestHandlers/LogRequestHandler.h" #include "RequestHandlers/UserInfoRequestHandler.h" #include "RequestHandlers/UserListRequestHandler.h" -#include <Net/FdManager.h> -#include <Net/ServerConnection.h> #include <Net/Packet.h> #include <Net/Listener.h> @@ -46,49 +44,46 @@ namespace Server { ConnectionManager ConnectionManager::connectionManager; -bool ConnectionManager::Connection::send(const Net::Packet &packet) { +bool ConnectionManager::ServerConnection::send(const Net::Packet &packet) { return connection->send(packet); } -ConnectionManager::Connection::Connection(Net::ServerConnection *connection0) -: connection(connection0), type(connection0->isDaemonConnection() ? DAEMON : CLIENT), hostInfo(0) { - connection->signalReceive().connect(boost::bind(&Connection::receive, this, _1)); +ConnectionManager::ServerConnection::ServerConnection(boost::shared_ptr<Net::Connection> connection0) +: connection(connection0), type(UNKNOWN), hostInfo(0) { + connection->signalReceive().connect(boost::bind(&ServerConnection::receive, this, _1)); } -ConnectionManager::Connection::~Connection() { - delete connection; -} - -bool ConnectionManager::Connection::isConnected() const { +bool ConnectionManager::ServerConnection::isConnected() const { return connection->isConnected(); } -bool ConnectionManager::Connection::disconnect() { +bool ConnectionManager::ServerConnection::disconnect() { connection->disconnect(); return true; } -void* ConnectionManager::Connection::getCertificate(size_t *size) const { +/*void* ConnectionManager::ServerConnection::getCertificate(size_t *size) const { const gnutls_datum_t *cert = connection->getCertificate(); *size = cert->size; return cert->data; } -void* ConnectionManager::Connection::getPeerCertificate(size_t *size) const { +void* ConnectionManager::ServerConnection::getPeerCertificate(size_t *size) const { const gnutls_datum_t *cert = connection->getPeerCertificate(); *size = cert->size; return cert->data; -} +}*/ void ConnectionManager::updateState(Common::HostInfo *hostInfo, Common::HostInfo::State state) { hostInfo->setState(state); - for(std::set<Connection*>::iterator con = connections.begin(); con != connections.end(); ++con) { - if((*con)->getConnectionType() == Connection::CLIENT) - Common::RequestManager::get()->sendRequest<Requests::DaemonStateUpdateRequest>(*con, boost::bind(&ConnectionManager::updateStateFinished, this, _1), hostInfo->getName(), state); + + for(std::set<boost::shared_ptr<ServerConnection> >::iterator con = connections.begin(); con != connections.end(); ++con) { + if((*con)->getConnectionType() == ServerConnection::CLIENT) + Common::RequestManager::get()->sendRequest<Requests::DaemonStateUpdateRequest>(con->get(), boost::bind(&ConnectionManager::updateStateFinished, this, _1), hostInfo->getName(), state); } } @@ -98,7 +93,7 @@ bool ConnectionManager::handleConfigEntry(const Common::ConfigEntry &entry, bool if(entry[0].getKey().matches("Listen") && entry[1].empty()) { try { - listenerAddresses.push_back(Net::IPAddress(entry[0][0])); + listenerAddresses.push_back(boost::asio::ip::tcp::endpoint(boost::asio::ip::address::from_string(entry[0][0]), 6666)); } catch(Net::Exception &e) { // TODO Log error @@ -147,8 +142,8 @@ bool ConnectionManager::handleConfigEntry(const Common::ConfigEntry &entry, bool void ConnectionManager::configFinished() { if(listenerAddresses.empty()) { try { - Net::Listener *listener = new Net::Listener(x509CertFile, x509KeyFile); - listener->signalNewConnection().connect(boost::bind(&ConnectionManager::newConnectionHandler, this, _1)); + boost::shared_ptr<Net::Listener> listener(new Net::Listener(x509CertFile, x509KeyFile)); + listener->signalNewConnection().connect(boost::bind(&ConnectionManager::handleNewConnection, this, _1)); listeners.push_back(listener); } catch(Net::Exception &e) { @@ -156,10 +151,10 @@ void ConnectionManager::configFinished() { } } else { - for(std::vector<Net::IPAddress>::const_iterator address = listenerAddresses.begin(); address != listenerAddresses.end(); ++address) { + for(std::vector<boost::asio::ip::tcp::endpoint>::const_iterator address = listenerAddresses.begin(); address != listenerAddresses.end(); ++address) { try { - Net::Listener *listener = new Net::Listener(x509CertFile, x509KeyFile, *address); - listener->signalNewConnection().connect(boost::bind(&ConnectionManager::newConnectionHandler, this, _1)); + boost::shared_ptr<Net::Listener> listener(new Net::Listener(x509CertFile, x509KeyFile, *address)); + listener->signalNewConnection().connect(boost::bind(&ConnectionManager::handleNewConnection, this, _1)); listeners.push_back(listener); } catch(Net::Exception &e) { @@ -169,28 +164,27 @@ void ConnectionManager::configFinished() { } } -void ConnectionManager::newConnectionHandler(Net::ServerConnection *con) { - Connection *connection = new Connection(con); - con->signalDisconnected().connect(boost::bind(&ConnectionManager::disconnectHandler, this, connection)); +void ConnectionManager::handleNewConnection(boost::shared_ptr<Net::Connection> con) { + boost::shared_ptr<ServerConnection> connection(new ServerConnection(con)); + con->signalDisconnected().connect(boost::bind(&ConnectionManager::handleDisconnect, this, connection)); connections.insert(connection); - Common::RequestManager::get()->registerConnection(connection); + Common::RequestManager::get()->registerConnection(connection.get()); } -void ConnectionManager::disconnectHandler(Connection *con) { - if(con->isIdentified()) +void ConnectionManager::handleDisconnect(boost::shared_ptr<ServerConnection> con) { + if(con->getHostInfo()) updateState(con->getHostInfo(), Common::HostInfo::INACTIVE); connections.erase(con); - Common::RequestManager::get()->unregisterConnection(con); - delete con; + Common::RequestManager::get()->unregisterConnection(con.get()); } void ConnectionManager::doInit() { Common::RequestManager::get()->setServer(true); - Common::RequestManager::get()->registerPacketType<RequestHandlers::GSSAPIAuthRequestHandler>("AuthGSSAPI"); + //Common::RequestManager::get()->registerPacketType<RequestHandlers::GSSAPIAuthRequestHandler>("AuthGSSAPI"); Common::RequestManager::get()->registerPacketType<RequestHandlers::DaemonCommandRequestHandler>("DaemonCommand"); Common::RequestManager::get()->registerPacketType<RequestHandlers::DaemonFSInfoRequestHandler>("DaemonFSInfo"); Common::RequestManager::get()->registerPacketType<Common::RequestHandlers::FSInfoRequestHandler>("FSInfo"); @@ -204,11 +198,9 @@ void ConnectionManager::doInit() { } void ConnectionManager::doDeinit() { - for(std::set<Connection*>::iterator con = connections.begin(); con != connections.end(); ++con) - delete *con; + connections.clear(); - - Common::RequestManager::get()->unregisterPacketType("AuthGSSAPI"); + //Common::RequestManager::get()->unregisterPacketType("AuthGSSAPI"); Common::RequestManager::get()->unregisterPacketType("DaemonCommand"); Common::RequestManager::get()->unregisterPacketType("DaemonFSInfo"); Common::RequestManager::get()->unregisterPacketType("FSInfo"); @@ -221,7 +213,7 @@ void ConnectionManager::doDeinit() { Common::RequestManager::get()->unregisterPacketType("Log"); } -Common::Connection* ConnectionManager::getDaemonConnection(const std::string &name) const throw (Net::Exception&) { +boost::shared_ptr<Common::Connection> ConnectionManager::getDaemonConnection(const std::string &name) const throw (Net::Exception&) { const Common::HostInfo *hostInfo; try { @@ -232,7 +224,7 @@ Common::Connection* ConnectionManager::getDaemonConnection(const std::string &na } if(hostInfo->getState() != Common::HostInfo::INACTIVE) { - for(std::set<Connection*>::const_iterator it = connections.begin(); it != connections.end(); ++it) { + for(std::set<boost::shared_ptr<ServerConnection> >::const_iterator it = connections.begin(); it != connections.end(); ++it) { if((*it)->getHostInfo() == hostInfo) { return *it; } @@ -243,7 +235,7 @@ Common::Connection* ConnectionManager::getDaemonConnection(const std::string &na } std::string ConnectionManager::getDaemonName(const Common::Connection *con) const throw (Net::Exception&) { - const Connection *connection = dynamic_cast<const Connection*>(con); + const ServerConnection *connection = dynamic_cast<const ServerConnection*>(con); if(connection) { if(connection->isIdentified()) { @@ -257,9 +249,9 @@ std::string ConnectionManager::getDaemonName(const Common::Connection *con) cons void ConnectionManager::identifyDaemonConnection(Common::Connection *con, const std::string &name) throw (Net::Exception&) { // TODO Logging - Connection *connection = dynamic_cast<Connection*>(con); + ServerConnection *connection = dynamic_cast<ServerConnection*>(con); - if(!connection || (connection->getConnectionType() != Connection::DAEMON)) + if(!connection) throw Net::Exception(Net::Exception::INVALID_ACTION); if(connection->isIdentified()) @@ -284,6 +276,18 @@ void ConnectionManager::identifyDaemonConnection(Common::Connection *con, const Common::Logger::logf("Identified as '%s'.", name.c_str()); } +void ConnectionManager::identifyClientConnection(Common::Connection *con) throw (Net::Exception&) { + ServerConnection *connection = dynamic_cast<ServerConnection*>(con); + + if(!connection) + throw Net::Exception(Net::Exception::INVALID_ACTION); + + if(connection->isIdentified()) + throw Net::Exception(Net::Exception::ALREADY_IDENTIFIED); + + connection->identify(); +} + std::vector<Common::HostInfo> ConnectionManager::getDaemonList() const { std::vector<Common::HostInfo> ret; diff --git a/src/Server/ConnectionManager.h b/src/Server/ConnectionManager.h index 7d97edc..710665d 100644 --- a/src/Server/ConnectionManager.h +++ b/src/Server/ConnectionManager.h @@ -20,38 +20,38 @@ #ifndef MAD_SERVER_CONNECTIONMANAGER_H_ #define MAD_SERVER_CONNECTIONMANAGER_H_ -#include <list> -#include <vector> -#include <map> -#include <set> - #include <Common/Configurable.h> #include <Common/HostInfo.h> #include <Common/Initializable.h> #include <Common/RequestManager.h> -#include <Net/IPAddress.h> +#include <list> +#include <vector> +#include <map> +#include <set> + +#include <boost/asio.hpp> namespace Mad { namespace Net { +class Connection; class Listener; -class ServerConnection; class Packet; } namespace Server { -class ConnectionManager : public Common::Configurable, public Common::Initializable { +class ConnectionManager : public Common::Configurable, public Common::Initializable, boost::noncopyable { private: - class Connection : public Common::Connection { + class ServerConnection : public Common::Connection { public: enum ConnectionType { - DAEMON, CLIENT + UNKNOWN = 0, DAEMON, CLIENT }; private: - Net::ServerConnection *connection; + boost::shared_ptr<Net::Connection> connection; ConnectionType type; Common::HostInfo *hostInfo; @@ -59,14 +59,13 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa virtual bool send(const Net::Packet &packet); public: - Connection(Net::ServerConnection *connection0); - virtual ~Connection(); + ServerConnection(boost::shared_ptr<Net::Connection> connection0); bool isConnected() const; virtual bool disconnect(); - virtual void* getCertificate(size_t *size) const; - virtual void* getPeerCertificate(size_t *size) const; + //virtual void* getCertificate(size_t *size) const; + //virtual void* getPeerCertificate(size_t *size) const; ConnectionType getConnectionType() const { return type; @@ -77,10 +76,15 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa } bool isIdentified() const { - return hostInfo; + return (type != UNKNOWN); + } + + void identify() { + type = CLIENT; } void identify(Common::HostInfo *info) { + type = DAEMON; hostInfo = info; } }; @@ -89,17 +93,13 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa std::string x509TrustFile, x509CrlFile, x509CertFile, x509KeyFile; - std::vector<Net::IPAddress> listenerAddresses; - std::list<Net::Listener*> listeners; + std::vector<boost::asio::ip::tcp::endpoint> listenerAddresses; + std::list<boost::shared_ptr<Net::Listener> > listeners; - std::set<Connection*> connections; + std::set<boost::shared_ptr<ServerConnection> > connections; std::map<std::string,Common::HostInfo> daemonInfo; - // Prevent shallow copy - ConnectionManager(const ConnectionManager &o); - ConnectionManager& operator=(const ConnectionManager &o); - void updateState(Common::HostInfo *hostInfo, Common::HostInfo::State state); void updateStateFinished(const Common::Request&) { // TODO Error handling (updateStateFinished) @@ -107,8 +107,8 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa ConnectionManager() {} - void newConnectionHandler(Net::ServerConnection *con); - void disconnectHandler(Connection *con); + void handleNewConnection(boost::shared_ptr<Net::Connection> con); + void handleDisconnect(boost::shared_ptr<ServerConnection> con); protected: virtual bool handleConfigEntry(const Common::ConfigEntry &entry, bool handled); @@ -122,10 +122,12 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa return &connectionManager; } - Common::Connection* getDaemonConnection(const std::string &name) const throw (Net::Exception&); + boost::shared_ptr<Common::Connection> getDaemonConnection(const std::string &name) const throw (Net::Exception&); std::string getDaemonName(const Common::Connection *con) const throw (Net::Exception&); void identifyDaemonConnection(Common::Connection *con, const std::string &name) throw (Net::Exception&); + void identifyClientConnection(Common::Connection *con) throw (Net::Exception&); + std::vector<Common::HostInfo> getDaemonList() const; }; diff --git a/src/Server/RequestHandlers/CMakeLists.txt b/src/Server/RequestHandlers/CMakeLists.txt index 2c6305d..e340a1e 100644 --- a/src/Server/RequestHandlers/CMakeLists.txt +++ b/src/Server/RequestHandlers/CMakeLists.txt @@ -3,7 +3,7 @@ include_directories(${INCLUDES}) add_library(ServerRequestHandlers DaemonCommandRequestHandler.cpp DaemonFSInfoRequestHandler.cpp DaemonListRequestHandler.cpp DaemonStatusRequestHandler.cpp - GSSAPIAuthRequestHandler.cpp IdentifyRequestHandler.cpp + IdentifyRequestHandler.cpp LogRequestHandler.cpp UserInfoRequestHandler.cpp UserListRequestHandler.cpp ) diff --git a/src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp b/src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp index 17a7e5d..e4b511e 100644 --- a/src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp +++ b/src/Server/RequestHandlers/DaemonCommandRequestHandler.cpp @@ -46,8 +46,8 @@ void DaemonCommandRequestHandler::handlePacket(const Common::XmlPacket &packet) std::string command = packet["command"]; try { - Common::Connection *daemonCon = ConnectionManager::get()->getDaemonConnection(packet["daemon"]); - Common::RequestManager::get()->sendRequest<Requests::CommandRequest>(daemonCon, + boost::shared_ptr<Common::Connection> daemonCon = ConnectionManager::get()->getDaemonConnection(packet["daemon"]); + Common::RequestManager::get()->sendRequest<Requests::CommandRequest>(daemonCon.get(), boost::bind(&DaemonCommandRequestHandler::requestFinished, this, _1), command == "reboot"); } catch(Net::Exception &e) { diff --git a/src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp b/src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp index df57a94..11a3f09 100644 --- a/src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp +++ b/src/Server/RequestHandlers/DaemonFSInfoRequestHandler.cpp @@ -44,8 +44,8 @@ void DaemonFSInfoRequestHandler::handlePacket(const Common::XmlPacket &packet) { // TODO Require authentication try { - Common::Connection *daemonCon = ConnectionManager::get()->getDaemonConnection(packet["daemon"]); - Common::RequestManager::get()->sendRequest<Common::Requests::FSInfoRequest>(daemonCon, + boost::shared_ptr<Common::Connection> daemonCon = ConnectionManager::get()->getDaemonConnection(packet["daemon"]); + Common::RequestManager::get()->sendRequest<Common::Requests::FSInfoRequest>(daemonCon.get(), boost::bind(&DaemonFSInfoRequestHandler::requestFinished, this, _1)); } catch(Net::Exception &e) { diff --git a/src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp b/src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp index 3d99c57..c1e2fcd 100644 --- a/src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp +++ b/src/Server/RequestHandlers/DaemonStatusRequestHandler.cpp @@ -46,8 +46,8 @@ void DaemonStatusRequestHandler::handlePacket(const Common::XmlPacket &packet) { std::string daemonName = packet["daemonName"]; try { - Common::Connection *daemonCon = ConnectionManager::get()->getDaemonConnection(daemonName); - Common::RequestManager::get()->sendRequest<Common::Requests::StatusRequest>(daemonCon, + boost::shared_ptr<Common::Connection> daemonCon = ConnectionManager::get()->getDaemonConnection(daemonName); + Common::RequestManager::get()->sendRequest<Common::Requests::StatusRequest>(daemonCon.get(), boost::bind(&DaemonStatusRequestHandler::requestFinished, this, _1)); } catch(Net::Exception &e) { diff --git a/src/Server/RequestHandlers/IdentifyRequestHandler.cpp b/src/Server/RequestHandlers/IdentifyRequestHandler.cpp index f69b3f5..b47b505 100644 --- a/src/Server/RequestHandlers/IdentifyRequestHandler.cpp +++ b/src/Server/RequestHandlers/IdentifyRequestHandler.cpp @@ -42,7 +42,10 @@ void IdentifyRequestHandler::handlePacket(const Common::XmlPacket &packet) { // TODO Require authentication try { - ConnectionManager::get()->identifyDaemonConnection(getConnection(), packet["hostname"]); + if(packet["hostname"].isEmpty()) + ConnectionManager::get()->identifyClientConnection(getConnection()); + else + ConnectionManager::get()->identifyDaemonConnection(getConnection(), packet["hostname"]); Common::XmlPacket ret; ret.setType("OK"); diff --git a/src/mad-server.conf b/src/mad-server.conf index bd4672c..45927cf 100644 --- a/src/mad-server.conf +++ b/src/mad-server.conf @@ -2,7 +2,7 @@ Logger Console Logger File "mad-server.log" -Listen * +#Listen * X509TrustFile ../Cert/ca-cert.pem diff --git a/src/mad-server.cpp b/src/mad-server.cpp index 6fb0e21..d55c049 100644 --- a/src/mad-server.cpp +++ b/src/mad-server.cpp @@ -22,18 +22,10 @@ #include "Net/ThreadManager.h" #include "Server/ConnectionManager.h" -#include <signal.h> - using namespace Mad; int main() { - sigset_t signals; - - sigemptyset(&signals); - sigaddset(&signals, SIGPIPE); - sigprocmask(SIG_BLOCK, &signals, 0); - Net::ThreadManager::get()->init(); Server::ConnectionManager::get()->init(); diff --git a/src/mad.cpp b/src/mad.cpp index 035c0be..22a3b44 100644 --- a/src/mad.cpp +++ b/src/mad.cpp @@ -17,8 +17,6 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#include "Net/FdManager.h" -#include "Net/IPAddress.h" #include "Net/ThreadManager.h" #include "Common/ConfigManager.h" #include "Common/LogManager.h" @@ -26,10 +24,10 @@ #include "Common/ModuleManager.h" #include "Common/RequestManager.h" #include "Common/ClientConnection.h" +#include "Common/Requests/IdentifyRequest.h" #include "Common/RequestHandlers/FSInfoRequestHandler.h" #include "Common/RequestHandlers/StatusRequestHandler.h" #include "Daemon/Backends/NetworkLogger.h" -#include "Daemon/Requests/IdentifyRequest.h" #include "Daemon/RequestHandlers/CommandRequestHandler.h" #include <unistd.h> @@ -57,7 +55,8 @@ int main() { Common::ClientConnection *connection = new Common::ClientConnection; try { - connection->connect(Net::IPAddress("127.0.0.1"), true); + connection->connect(boost::asio::ip::tcp::endpoint( + boost::asio::ip::address_v4::from_string("127.0.0.1"), 6666)); while(connection->isConnecting()) usleep(100000); @@ -71,7 +70,7 @@ int main() { //char hostname[256]; //gethostname(hostname, sizeof(hostname)); //Common::RequestManager::get()->sendRequest<Daemon::Requests::IdentifyRequest>(connection, sigc::ptr_fun(requestFinished), hostname); - Common::RequestManager::get()->sendRequest<Daemon::Requests::IdentifyRequest>(connection, &requestFinished, "test"); + Common::RequestManager::get()->sendRequest<Common::Requests::IdentifyRequest>(connection, &requestFinished, "test"); while(connection->isConnected()) usleep(100000); diff --git a/src/madc.cpp b/src/madc.cpp index 106617c..a902144 100644 --- a/src/madc.cpp +++ b/src/madc.cpp @@ -17,11 +17,10 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#include "Net/FdManager.h" -#include "Net/IPAddress.h" #include "Net/ThreadManager.h" #include "Common/ClientConnection.h" #include "Common/ConfigManager.h" +#include "Common/Requests/IdentifyRequest.h" #include "Common/Logger.h" #include "Common/RequestManager.h" #include "Client/CommandParser.h" @@ -35,37 +34,14 @@ using namespace Mad; -static void usage(const std::string &cmd) { - std::cerr << "Usage: " << cmd << " address[:port]" << std::endl; -} - -static void handleCommand(char *cmd) { - if(!cmd) - Client::CommandParser::get()->requestDisconnect(); - else if(!*cmd) - return; - else { - Client::CommandParser::get()->parse(cmd); - add_history(cmd); - } - - if(Client::CommandManager::get()->requestsActive()) { - rl_callback_handler_remove(); - Net::FdManager::get()->setFdEvents(STDIN_FILENO, 0); - } -} +static bool commandRunning = false; -static void charHandler(short events) { - if(events & POLLIN) - rl_callback_read_char(); +static void usage(const std::string &cmd) { + std::cerr << "Usage: " << cmd << " address" << std::endl; } -static void activateReadline() { - if(Client::CommandManager::get()->willDisconnect()) - return; - - rl_callback_handler_install("mad> ", handleCommand); - Net::FdManager::get()->setFdEvents(STDIN_FILENO, POLLIN); +static void commandFinished() { + commandRunning = false; } int main(int argc, char *argv[]) { @@ -82,17 +58,22 @@ int main(int argc, char *argv[]) { Common::ClientConnection *connection = new Common::ClientConnection; try { - connection->connect(Net::IPAddress(argv[1])); + connection->connect(boost::asio::ip::tcp::endpoint(boost::asio::ip::address::from_string(argv[1]), 6666)); std::cerr << "Connecting to " << argv[1] << "..." << std::flush; while(connection->isConnecting()) usleep(100000); - std::cerr << " connected." << std::endl; - Common::RequestManager::get()->registerConnection(connection); + commandRunning = true; + Common::RequestManager::get()->sendRequest<Common::Requests::IdentifyRequest>(connection, boost::bind(&commandFinished)); + while(commandRunning) { + usleep(100000); + } + + std::cerr << " connected." << std::endl; std::cerr << "Receiving host list..." << std::flush; Client::InformationManager::get()->updateDaemonList(connection); @@ -103,16 +84,24 @@ int main(int argc, char *argv[]) { std::cerr << " done." << std::endl << std::endl; Client::CommandParser::get()->setConnection(connection); - Client::CommandManager::get()->signalFinished().connect(&activateReadline); - - Net::FdManager::get()->registerFd(STDIN_FILENO,&charHandler); - - activateReadline(); - - while(connection->isConnected()) - usleep(100000); - - Net::FdManager::get()->unregisterFd(STDIN_FILENO); + Client::CommandManager::get()->signalFinished().connect(&commandFinished); + + while(connection->isConnected()) { + char *cmd = readline("mad> "); + + if(!cmd) { + commandRunning = true; + Client::CommandParser::get()->requestDisconnect(); + } + else if(*cmd) { + commandRunning = true; + Client::CommandParser::get()->parse(cmd); + add_history(cmd); + } + + while(Client::CommandManager::get()->requestsActive()) + usleep(100000); + } Common::RequestManager::get()->unregisterConnection(connection); } |