diff options
author | Matthias Schiffer <matthias@gamezock.de> | 2009-05-20 01:08:16 +0200 |
---|---|---|
committer | Matthias Schiffer <matthias@gamezock.de> | 2009-05-20 01:08:16 +0200 |
commit | 776377bb21ee1cfe0bcdbc000f7c6fa0be227226 (patch) | |
tree | a60e8bdd92678dece5fbb96e75535eba2b61b7da /src/Net | |
parent | badc0da3b74d99c90b7b28180d08cd6d08830254 (diff) | |
download | mad-776377bb21ee1cfe0bcdbc000f7c6fa0be227226.tar mad-776377bb21ee1cfe0bcdbc000f7c6fa0be227226.zip |
Netzwerk-Code auf boost::asio umgestellt
Diffstat (limited to 'src/Net')
-rw-r--r-- | src/Net/CMakeLists.txt | 6 | ||||
-rw-r--r-- | src/Net/ClientConnection.cpp | 87 | ||||
-rw-r--r-- | src/Net/ClientConnection.h | 23 | ||||
-rw-r--r-- | src/Net/Connection.cpp | 364 | ||||
-rw-r--r-- | src/Net/Connection.h | 144 | ||||
-rw-r--r-- | src/Net/Exception.cpp | 3 | ||||
-rw-r--r-- | src/Net/Exception.h | 2 | ||||
-rw-r--r-- | src/Net/FdManager.cpp | 174 | ||||
-rw-r--r-- | src/Net/FdManager.h | 77 | ||||
-rw-r--r-- | src/Net/IPAddress.cpp | 92 | ||||
-rw-r--r-- | src/Net/IPAddress.h | 58 | ||||
-rw-r--r-- | src/Net/Listener.cpp | 95 | ||||
-rw-r--r-- | src/Net/Listener.h | 41 | ||||
-rw-r--r-- | src/Net/ServerConnection.cpp | 90 | ||||
-rw-r--r-- | src/Net/ServerConnection.h | 57 | ||||
-rw-r--r-- | src/Net/ThreadManager.cpp | 14 | ||||
-rw-r--r-- | src/Net/ThreadManager.h | 6 |
17 files changed, 259 insertions, 1074 deletions
diff --git a/src/Net/CMakeLists.txt b/src/Net/CMakeLists.txt index aa3f857..fae358b 100644 --- a/src/Net/CMakeLists.txt +++ b/src/Net/CMakeLists.txt @@ -2,7 +2,7 @@ include_directories(${INCLUDES}) link_directories(${Boost_LIBRARY_DIRS}) add_library(Net - ClientConnection.cpp Connection.cpp Exception.cpp FdManager.cpp IPAddress.cpp - Listener.cpp Packet.cpp ServerConnection.cpp ThreadManager.cpp + ClientConnection.cpp Connection.cpp Exception.cpp Listener.cpp + Packet.cpp ThreadManager.cpp ) -target_link_libraries(Net ${Boost_LIBRARIES} ${GNUTLS_LIBRARIES}) +target_link_libraries(Net ${Boost_LIBRARIES} ${OPENSSL_LIBRARIES}) diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index 087d95f..9cdf796 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -18,99 +18,36 @@ */ #include "ClientConnection.h" -#include "FdManager.h" -#include "IPAddress.h" -#include <boost/thread/locks.hpp> - -#include <cstring> -#include <cerrno> -#include <sys/socket.h> -#include <fcntl.h> +#include <Common/Logger.h> namespace Mad { namespace Net { -// TODO Error handling -void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) { - if(length != sizeof(ConnectionHeader)) - // Error... disconnect - return; - - const ConnectionHeader *header = (const ConnectionHeader*)(data); - - if(header->m != 'M' || header->a != 'A' || header->d != 'D') - // Error... disconnect +void ClientConnection::handleConnect(const boost::system::error_code& error) { + if(error) { + // TODO Error handling + doDisconnect(); return; + } - if(header->protVerMin != 1) - // Unsupported protocol... disconnect - return; - - enterReceiveLoop(); -} - -void ClientConnection::connectionHeader() { - ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1}; + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - rawSend((uint8_t*)&header, sizeof(header)); - rawReceive(sizeof(ConnectionHeader), boost::bind(&ClientConnection::connectionHeaderReceiveHandler, this, _1, _2)); + socket.async_handshake(boost::asio::ssl::stream_base::client, boost::bind(&ClientConnection::handleHandshake, this, boost::asio::placeholders::error)); } -void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(Exception) { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - daemon = daemon0; +void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) throw(Exception) { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); if(_isConnected()) { return; // TODO Error } - sock = socket(PF_INET, SOCK_STREAM, 0); - if(sock < 0) { - throw Exception("socket()", Exception::INTERNAL_ERRNO, errno); - } - - if(peer) - delete peer; - peer = new IPAddress(address); - - if(::connect(sock, peer->getSockAddr(), peer->getSockAddrLength()) < 0) { - close(sock); - delete peer; - peer = 0; - - throw Exception("connect()", Exception::INTERNAL_ERRNO, errno); - } - - // Set non-blocking flag - int flags = fcntl(sock, F_GETFL, 0); - - if(flags < 0) { - close(sock); - - throw Exception("fcntl()", Exception::INTERNAL_ERRNO, errno); - } - - fcntl(sock, F_SETFL, flags | O_NONBLOCK); - - // Don't linger - struct linger linger = {1, 0}; - setsockopt(sock, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)); - - gnutls_init(&session, GNUTLS_CLIENT); - gnutls_set_default_priority(session); - gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred); - gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)sock); - - FdManager::get()->registerFd(sock, boost::bind(&ClientConnection::sendReceive, this, _1)); - + peer = address; state = CONNECT; - lock.unlock(); - - updateEvents(); + socket.lowest_layer().async_connect(address, boost::bind(&ClientConnection::handleConnect, this, boost::asio::placeholders::error)); } } diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h index bdd7872..f93c2fc 100644 --- a/src/Net/ClientConnection.h +++ b/src/Net/ClientConnection.h @@ -23,24 +23,27 @@ #include "Connection.h" #include "Exception.h" +#include <boost/utility/base_from_member.hpp> + + namespace Mad { namespace Net { class IPAddress; -class ClientConnection : public Connection { +class ClientConnection : private boost::base_from_member<boost::asio::ssl::context>, public Connection { private: - bool daemon; - - void connectionHeaderReceiveHandler(const void *data, unsigned long length); - - protected: - virtual void connectionHeader(); + void handleConnect(const boost::system::error_code& error); public: - ClientConnection() : daemon(0) {} - - void connect(const IPAddress &address, bool daemon0 = false) throw(Exception); + ClientConnection() + : boost::base_from_member<boost::asio::ssl::context>(boost::ref(Connection::ioService), boost::asio::ssl::context::sslv23), + Connection(member) + { + member.set_verify_mode(boost::asio::ssl::context::verify_none); + } + + void connect(const boost::asio::ip::tcp::endpoint &address) throw(Exception); }; } diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 4e5029f..b9691cb 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -18,361 +18,225 @@ */ #include "Connection.h" -#include "FdManager.h" -#include "IPAddress.h" #include "ThreadManager.h" -#include <cstring> -#include <sys/socket.h> +#include <Common/Logger.h> +#include <cstring> #include <boost/bind.hpp> namespace Mad { namespace Net { - -Connection::StaticInit Connection::staticInit; +boost::asio::io_service Connection::ioService; Connection::~Connection() { if(_isConnected()) doDisconnect(); - - if(transR.data) - delete [] transR.data; - - while(!_sendQueueEmpty()) { - delete [] transS.front().data; - transS.pop(); - } - - gnutls_certificate_free_credentials(x509_cred); - - if(peer) - delete peer; -} - -void Connection::handshake() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(state != CONNECT) - return; - - state = HANDSHAKE; - lock.unlock(); - - doHandshake(); } -void Connection::bye() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(state != DISCONNECT) - return; - - state = BYE; - lock.unlock(); +void Connection::handleHandshake(const boost::system::error_code& error) { + if(error) { + Common::Logger::logf("Error: %s", error.message().c_str()); - doBye(); -} - -void Connection::doHandshake() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); - if(state != HANDSHAKE) + // TODO Error handling + doDisconnect(); return; + } - int ret = gnutls_handshake(session); - if(ret < 0) { - lock.unlock(); + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); + state = CONNECTED; - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { - updateEvents(); - return; - } + receiving = false; + sending = 0; - // TODO: Error - doDisconnect(); - return; + received = 0; } - state = CONNECTION_HEADER; - lock.unlock(); + ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &connectedSignal)); - connectionHeader(); + enterReceiveLoop(); } -void Connection::doBye() { - if(state != BYE) - return; - - int ret = gnutls_bye(session, GNUTLS_SHUT_RDWR); - if(ret < 0) { - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { - updateEvents(); - return; - } +void Connection::handleShutdown(const boost::system::error_code& error) { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - // TODO: Error - doDisconnect(); - return; + if(error) { + // TODO Error } - doDisconnect(); + state = DISCONNECTED; + ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &disconnectedSignal)); } void Connection::enterReceiveLoop() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - if(!_isConnected() || _isDisconnecting()) - return; - - if(_isConnecting()) - ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &connectedSignal)); + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - state = PACKET_HEADER; - lock.unlock(); + if(!_isConnected() || _isDisconnecting()) + return; + } - rawReceive(sizeof(Packet::Data), boost::bind(&Connection::packetHeaderReceiveHandler, this, _1, _2)); + rawReceive(sizeof(Packet::Data), boost::bind(&Connection::handleHeaderReceive, this, _1)); } -void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) { - if(state != PACKET_HEADER) - return; +void Connection::handleHeaderReceive(const std::vector<boost::uint8_t> &data) { + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - if(length != sizeof(Packet::Data)) { - // TODO: Error - doDisconnect(); - return; + header = *reinterpret_cast<const Packet::Data*>(data.data()); } - header = *(const Packet::Data*)data; - if(header.length == 0) { ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, Packet(ntohs(header.requestId)))); enterReceiveLoop(); } else { - state = PACKET_DATA; - rawReceive(ntohs(header.length), boost::bind(&Connection::packetDataReceiveHandler, this, _1, _2)); + rawReceive(ntohs(header.length), boost::bind(&Connection::handleDataReceive, this, _1)); } } -void Connection::packetDataReceiveHandler(const void *data, unsigned long length) { - if(state != PACKET_DATA) - return; +void Connection::handleDataReceive(const std::vector<boost::uint8_t> &data) { + { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - if(length != ntohs(header.length)) { - // TODO: Error - doDisconnect(); - return; + Packet packet(ntohs(header.requestId), data.data(), ntohs(header.length)); + ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, packet)); } - ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, Packet(ntohs(header.requestId), data, length))); - enterReceiveLoop(); } -void Connection::doReceive() { - if(!isConnected()) - return; - - boost::unique_lock<boost::mutex> lock(receiveLock); - - if(_receiveComplete()) - return; - - ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted); +void Connection::handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify) { + if(error || (bytes_transferred+received) < length) { + Common::Logger::logf(Common::Logger::VERBOSE, "Read error: %s", error.message().c_str()); - if(ret < 0) { - lock.unlock(); - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) - return; - - // TODO: Error + // TODO Error doDisconnect(); return; } - transR.transmitted += ret; - - if(_receiveComplete()) { - // Save data pointer, as transR.notify might start a new reception - uint8_t *data = transR.data; - transR.data = 0; + std::vector<boost::uint8_t> buffer; - lock.unlock(); + { + boost::shared_lock<boost::shared_mutex> lock(connectionLock); - transR.notify(data, transR.length); + if(state != CONNECTED || !receiving) + return; - delete [] data; - } - else { - lock.unlock(); + buffer.insert(buffer.end(), receiveBuffer.data(), receiveBuffer.data()+length); } - updateEvents(); -} - -bool Connection::rawReceive(unsigned long length, - const boost::function2<void,const void*,unsigned long> ¬ify) -{ - if(!isConnected()) - return false; - - boost::unique_lock<boost::mutex> lock(receiveLock); - if(!_receiveComplete()) - return false; - - transR.data = new uint8_t[length]; - transR.length = length; - transR.transmitted = 0; - transR.notify = notify; + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - lock.unlock(); + receiving = false; + received = received + bytes_transferred - length; - updateEvents(); + if(received) + std::memmove(receiveBuffer.data(), receiveBuffer.data()+length, received); + } - return true; + notify(buffer); } -void Connection::doSend() { - if(!isConnected()) - return; +void Connection::rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify) { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - boost::unique_lock<boost::mutex> lock(sendLock); - while(!_sendQueueEmpty()) { - ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted, - transS.front().length-transS.front().transmitted); - - if(ret < 0) { - lock.unlock(); + if(!_isConnected()) + return; - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) - return; + { + boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock); - // TODO: Error - doDisconnect(); + if(receiving) return; - } - transS.front().transmitted += ret; + receiving = true; - if(transS.front().transmitted == transS.front().length) { - delete [] transS.front().data; - transS.pop(); + if(length > received) { + boost::asio::async_read(socket, boost::asio::buffer(receiveBuffer.data()+received, receiveBuffer.size()-received), boost::asio::transfer_at_least(length), + boost::bind(&Connection::handleRead, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred, + length, notify)); + + return; } } lock.unlock(); - updateEvents(); + handleRead(boost::system::error_code(), 0, length, notify); } -bool Connection::rawSend(const uint8_t *data, unsigned long length) { - if(!isConnected()) - return false; +void Connection::handleWrite(const boost::system::error_code& error, std::size_t) { + { + boost::unique_lock<boost::shared_mutex> lock(connectionLock); - Transmission trans = {length, 0, new uint8_t[length], boost::function2<void,const void*,unsigned long>()}; - std::memcpy(trans.data, data, length); + sending--; - sendLock.lock(); - transS.push(trans); - sendLock.unlock(); - - updateEvents(); + if(state == DISCONNECT && !sending) { + lock.unlock(); + doDisconnect(); + return; + } + } - return true; -} + if(error) { + Common::Logger::logf(Common::Logger::VERBOSE, "Write error: %s", error.message().c_str()); -void Connection::sendReceive(short events) { - if(events & POLLHUP || events & POLLERR) { + // TODO Error doDisconnect(); - return; } +} - switch(state) { - case CONNECT: - handshake(); - return; - case HANDSHAKE: - doHandshake(); - return; - case DISCONNECT: - if(!_sendQueueEmpty()) - break; +void Connection::rawSend(const uint8_t *data, std::size_t length) { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - bye(); - return; - case BYE: - doBye(); - return; - default: - break; - } + if(!_isConnected()) + return; - if(events & POLLIN) - doReceive(); + { + boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock); - if(events & POLLOUT) - doSend(); + sending++; + boost::asio::async_write(socket, Buffer(data, length), boost::bind(&Connection::handleWrite, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + } } bool Connection::send(const Packet &packet) { - stateLock.lock_shared(); - bool err = (!_isConnected() || _isConnecting() || _isDisconnecting()); - stateLock.unlock_shared(); - - if(err) - return false; + { + boost::shared_lock<boost::shared_mutex> lock(connectionLock); + if(!_isConnected() || _isConnecting() || _isDisconnecting()) + return false; + } - return rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength()); + rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength()); + return true; } void Connection::disconnect() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(!_isConnected() || _isDisconnecting()) - return; + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); + if(!_isConnected() || _isDisconnecting()) + return; - state = DISCONNECT; + state = DISCONNECT; - lock.unlock(); + if(sending) + return; + } - updateEvents(); + doDisconnect(); } void Connection::doDisconnect() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - if(_isConnected()) { - FdManager::get()->unregisterFd(sock); - - shutdown(sock, SHUT_RDWR); - close(sock); + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - gnutls_deinit(session); - - ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &disconnectedSignal)); - - state = DISCONNECTED; - } -} - -void Connection::updateEvents() { - receiveLock.lock(); - short events = (_receiveComplete() ? 0 : POLLIN); - receiveLock.unlock(); - - sendLock.lock(); - events |= (_sendQueueEmpty() ? 0 : POLLOUT); - sendLock.unlock(); - - stateLock.lock_shared(); - if(state == HANDSHAKE || state == BYE) - events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT); - else if(state == CONNECT || state == DISCONNECT) - events |= POLLOUT; - - FdManager::get()->setFdEvents(sock, events); - stateLock.unlock_shared(); + if(_isConnected()) + socket.async_shutdown(boost::bind(&Connection::handleShutdown, this, boost::asio::placeholders::error)); } } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index a0b95ea..303485d 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -24,155 +24,111 @@ #include "Packet.h" -#include <queue> -#include <string> -#include <gnutls/gnutls.h> -#include <poll.h> +#include <boost/asio.hpp> +#include <boost/asio/ssl.hpp> #include <boost/signal.hpp> -#include <boost/thread/mutex.hpp> -#include <boost/thread/shared_mutex.hpp> -#include <iostream> +#include <boost/thread/shared_mutex.hpp> namespace Mad { namespace Net { -class IPAddress; -class Packet; +class ThreadManager; -class Connection { +class Connection : boost::noncopyable { private: - class StaticInit { - public: - StaticInit() { - gnutls_global_init(); - } + friend class ThreadManager; - ~StaticInit() { - gnutls_global_deinit(); - } - }; - static StaticInit staticInit; + class Buffer { + public: + Buffer(const uint8_t *data0, std::size_t length) : data(new std::vector<uint8_t>(data0, data0+length)), buffer(boost::asio::buffer(*data)) {} - struct Transmission { - unsigned long length; - unsigned long transmitted; + typedef boost::asio::const_buffer value_type; + typedef const boost::asio::const_buffer* const_iterator; - uint8_t *data; + const boost::asio::const_buffer* begin() const { return &buffer; } + const boost::asio::const_buffer* end() const { return &buffer + 1; } - boost::function2<void,const void*,unsigned long> notify; + private: + boost::shared_ptr<std::vector<uint8_t> > data; + boost::asio::const_buffer buffer; }; - boost::mutex receiveLock; - Transmission transR; - boost::mutex sendLock; - std::queue<Transmission> transS; + std::vector<boost::uint8_t> receiveBuffer; + std::size_t received; Packet::Data header; - boost::signal1<void,const Packet&> receiveSignal; + boost::signal1<void, const Packet&> receiveSignal; boost::signal0<void> connectedSignal; boost::signal0<void> disconnectedSignal; - void doHandshake(); - - void packetHeaderReceiveHandler(const void *data, unsigned long length); - void packetDataReceiveHandler(const void *data, unsigned long length); + bool receiving; + unsigned long sending; - void doReceive(); - void doSend(); - - void doBye(); - - void doDisconnect(); + void enterReceiveLoop(); - bool _receiveComplete() const { - return (transR.length == transR.transmitted); - } + void handleHeaderReceive(const std::vector<boost::uint8_t> &data); + void handleDataReceive(const std::vector<boost::uint8_t> &data); - bool _sendQueueEmpty() const {return transS.empty();} + void handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify); + void handleWrite(const boost::system::error_code& error, std::size_t); - void bye(); + void handleShutdown(const boost::system::error_code& error); - // Prevent shallow copy - Connection(const Connection &o); - Connection& operator=(const Connection &o); + void rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify); + void rawSend(const uint8_t *data, std::size_t length); protected: - struct ConnectionHeader { - uint8_t m; - uint8_t a; - uint8_t d; - uint8_t type; - - uint8_t versionMajor; - uint8_t versionMinor; - uint8_t protVerMin; - uint8_t protVerMax; - }; + static boost::asio::io_service ioService; - boost::shared_mutex stateLock; + boost::shared_mutex connectionLock; enum State { - DISCONNECTED, CONNECT, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE + DISCONNECTED, CONNECT, CONNECTED, DISCONNECT } state; - int sock; - gnutls_session_t session; - gnutls_certificate_credentials_t x509_cred; - IPAddress *peer; + boost::asio::ssl::stream<boost::asio::ip::tcp::socket> socket; + boost::asio::ip::tcp::endpoint peer; - void handshake(); - - virtual void connectionHeader() = 0; - - bool rawReceive(unsigned long length, const boost::function2<void,const void*,unsigned long> ¬ify); - bool rawSend(const uint8_t *data, unsigned long length); - - void enterReceiveLoop(); - - void sendReceive(short events); + void handleHandshake(const boost::system::error_code& error); bool _isConnected() const {return (state != DISCONNECTED);} bool _isConnecting() const { - return (state == CONNECT || state == HANDSHAKE || state == CONNECTION_HEADER); + return (state == CONNECT); } bool _isDisconnecting() const { - return (state == DISCONNECT || state == BYE); + return (state == DISCONNECT); } - void updateEvents(); - - public: - Connection() : state(DISCONNECTED), peer(0) { - transR.length = transR.transmitted = 0; - transR.data = 0; + void doDisconnect(); - gnutls_certificate_allocate_credentials(&x509_cred); - } + Connection(boost::asio::ssl::context &sslContext) : + receiveBuffer(1024*1024), state(DISCONNECTED), socket(ioService, sslContext) {} + public: virtual ~Connection(); bool isConnected() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); + boost::shared_lock<boost::shared_mutex> lock(connectionLock); return _isConnected(); } bool isConnecting() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); + boost::shared_lock<boost::shared_mutex> lock(connectionLock); return _isConnecting(); } bool isDisconnecting() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); + boost::shared_lock<boost::shared_mutex> lock(connectionLock); return _isDisconnecting(); } - const gnutls_datum_t* getCertificate() const { + /*const gnutls_datum_t* getCertificate() const { // TODO Thread-safeness return gnutls_certificate_get_ours(session); } @@ -181,16 +137,18 @@ class Connection { // TODO Thread-safeness unsigned int n; return gnutls_certificate_get_peers(session, &n); - } + }*/ - // TODO Thread-safeness - const IPAddress* getPeer() const {return peer;} + boost::asio::ip::tcp::endpoint getPeer() { + boost::shared_lock<boost::shared_mutex> lock(connectionLock); + return peer; + } void disconnect(); bool send(const Packet &packet); - boost::signal1<void,const Packet&>& signalReceive() {return receiveSignal;} + boost::signal1<void, const Packet&>& signalReceive() {return receiveSignal;} boost::signal0<void>& signalConnected() {return connectedSignal;} boost::signal0<void>& signalDisconnected() {return disconnectedSignal;} }; diff --git a/src/Net/Exception.cpp b/src/Net/Exception.cpp index 34b8033..e082948 100644 --- a/src/Net/Exception.cpp +++ b/src/Net/Exception.cpp @@ -20,7 +20,6 @@ #include "Exception.h" #include <cstring> -#include <gnutls/gnutls.h> namespace Mad { namespace Net { @@ -46,8 +45,6 @@ std::string Exception::strerror() const { return ret + "Not implemented"; case INTERNAL_ERRNO: return ret + std::strerror(subCode); - case INTERNAL_GNUTLS: - return ret + "GnuTLS error: " + gnutls_strerror(subCode); case INVALID_ADDRESS: return ret + "Invalid address"; case ALREADY_IDENTIFIED: diff --git a/src/Net/Exception.h b/src/Net/Exception.h index 48e86d1..8522528 100644 --- a/src/Net/Exception.h +++ b/src/Net/Exception.h @@ -29,7 +29,7 @@ class Exception { public: enum ErrorCode { SUCCESS = 0x0000, UNEXPECTED_PACKET = 0x0001, INVALID_ACTION = 0x0002, NOT_AVAILABLE = 0x0003, NOT_FINISHED = 0x0004, NOT_IMPLEMENTED = 0x0005, - INTERNAL_ERRNO = 0x0010, INTERNAL_GNUTLS = 0x0011, + INTERNAL_ERRNO = 0x0010, INVALID_ADDRESS = 0x0020, ALREADY_IDENTIFIED = 0x0030, UNKNOWN_DAEMON = 0x0031 }; diff --git a/src/Net/FdManager.cpp b/src/Net/FdManager.cpp deleted file mode 100644 index d8faef4..0000000 --- a/src/Net/FdManager.cpp +++ /dev/null @@ -1,174 +0,0 @@ -/* - * FdManager.cpp - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#include "FdManager.h" -#include "ThreadManager.h" - -#include <signal.h> -#include <unistd.h> -#include <sys/fcntl.h> - - -namespace Mad { -namespace Net { - -FdManager FdManager::fdManager; - - -FdManager::FdManager() : running(false) { - pipe(interruptPipe); - - int flags = fcntl(interruptPipe[0], F_GETFL, 0); - fcntl(interruptPipe[0], F_SETFL, flags | O_NONBLOCK); - - flags = fcntl(interruptPipe[1], F_GETFL, 0); - fcntl(interruptPipe[1], F_SETFL, flags | O_NONBLOCK); - - registerFd(interruptPipe[0], boost::bind(&FdManager::readInterrupt, this), POLLIN); -} - -FdManager::~FdManager() { - unregisterFd(interruptPipe[0]); - - close(interruptPipe[0]); - close(interruptPipe[1]); -} - - -bool FdManager::registerFd(int fd, const boost::function1<void, short> &handler, short events) { - struct pollfd pollfd = {fd, events, 0}; - - boost::lock(handlerLock, eventLock); - pollfds.insert(std::make_pair(fd, pollfd)); - - bool ret = handlers.insert(std::make_pair(fd, handler)).second; - - eventLock.unlock(); - handlerLock.unlock(); - - interrupt(); - - return ret; -} - -bool FdManager::unregisterFd(int fd) { - boost::lock(handlerLock, eventLock); - pollfds.erase(fd); - bool ret = handlers.erase(fd); - eventLock.unlock(); - handlerLock.unlock(); - - interrupt(); - - return ret; -} - -bool FdManager::setFdEvents(int fd, short events) { - boost::unique_lock<boost::shared_mutex> lock(eventLock); - - std::map<int, struct pollfd>::iterator pollfd = pollfds.find(fd); - - if(pollfd == pollfds.end()) - return false; - - if(pollfd->second.events != events) { - pollfd->second.events = events; - interrupt(); - } - - return true; -} - -short FdManager::getFdEvents(int fd) { - boost::shared_lock<boost::shared_mutex> lock(eventLock); - - std::map<int, struct pollfd>::const_iterator pollfd = pollfds.find(fd); - - if(pollfd == pollfds.end()) - return -1; - - return pollfd->second.events; -} - -void FdManager::readInterrupt() { - char buf[20]; - - while(read(interruptPipe[0], buf, sizeof(buf)) > 0) {} -} - -void FdManager::interrupt() { - char buf = 0; - - write(interruptPipe[1], &buf, sizeof(buf)); -} - -void FdManager::ioThread() { - runLock.lock(); - running = true; - runLock.unlock_and_lock_shared(); - - while(running) { - runLock.unlock_shared(); - - handlerLock.lock_shared(); - eventLock.lock_shared(); - readInterrupt(); - - size_t count = pollfds.size(); - struct pollfd *fdarray = new struct pollfd[count]; - - std::map<int, struct pollfd>::iterator pollfd = pollfds.begin(); - - for(size_t n = 0; n < count; ++n) { - fdarray[n] = pollfd->second; - ++pollfd; - } - - eventLock.unlock_shared(); - handlerLock.unlock_shared(); - - if(poll(fdarray, count, -1) > 0) { - handlerLock.lock_shared(); - - std::queue<boost::function0<void> > calls; - - for(size_t n = 0; n < count; ++n) { - if(fdarray[n].revents) - calls.push(boost::bind(handlers[fdarray[n].fd], fdarray[n].revents)); - } - - handlerLock.unlock_shared(); - - while(!calls.empty()) { - calls.front()(); - calls.pop(); - } - - } - - delete [] fdarray; - - runLock.lock_shared(); - } - - runLock.unlock_shared(); -} - -} -} diff --git a/src/Net/FdManager.h b/src/Net/FdManager.h deleted file mode 100644 index 1cb95bc..0000000 --- a/src/Net/FdManager.h +++ /dev/null @@ -1,77 +0,0 @@ -/* - * FdManager.h - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef MAD_NET_FDMANAGER_H_ -#define MAD_NET_FDMANAGER_H_ - -#include <map> -#include <poll.h> - -#include <boost/function.hpp> -#include <boost/thread/shared_mutex.hpp> - -namespace Mad { -namespace Net { - -class ThreadManager; - -class FdManager { - private: - friend class ThreadManager; - - static FdManager fdManager; - - boost::shared_mutex runLock, handlerLock, eventLock; - bool running; - - std::map<int, struct pollfd> pollfds; - std::map<int, boost::function1<void, short> > handlers; - - int interruptPipe[2]; - - void readInterrupt(); - void interrupt(); - - FdManager(); - - void ioThread(); - void stopIOThread() { - runLock.lock(); - running = false; - runLock.unlock(); - - interrupt(); - } - - public: - virtual ~FdManager(); - - static FdManager *get() {return &fdManager;} - - bool registerFd(int fd, const boost::function1<void, short> &handler, short events = 0); - bool unregisterFd(int fd); - - bool setFdEvents(int fd, short events); - short getFdEvents(int fd); -}; - -} -} - -#endif /* MAD_NET_FDMANAGER_H_ */ diff --git a/src/Net/IPAddress.cpp b/src/Net/IPAddress.cpp deleted file mode 100644 index eb9d3be..0000000 --- a/src/Net/IPAddress.cpp +++ /dev/null @@ -1,92 +0,0 @@ -/* - * IPAddress.cpp - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#include "IPAddress.h" - -#include <cstdlib> - -namespace Mad { -namespace Net { - -IPAddress::IPAddress(uint16_t port0) : addr(INADDR_ANY), port(port0) { - sa.sin_family = AF_INET; - sa.sin_port = htons(port); - sa.sin_addr.s_addr = INADDR_ANY; -} - -IPAddress::IPAddress(uint32_t address, uint16_t port0) : addr(address), port(port0) { - sa.sin_family = AF_INET; - sa.sin_port = htons(port); - sa.sin_addr.s_addr = htonl(addr); -} - -IPAddress::IPAddress(const std::string &address) throw(Exception) { - std::string ip; - size_t pos = address.find_first_of(':'); - - if(pos == std::string::npos) { - ip = address; - // TODO Default port - port = 6666; - } - else { - ip = address.substr(0, pos); - - char *endptr; - port = std::strtol(address.substr(pos+1).c_str(), &endptr, 10); - if(*endptr != 0 || port == 0) - throw Exception(Exception::INVALID_ADDRESS); - } - - sa.sin_family = AF_INET; - sa.sin_port = htons(port); - - if(ip == "*") - sa.sin_addr.s_addr = INADDR_ANY; - else if(!inet_pton(AF_INET, ip.c_str(), &sa.sin_addr)) - throw Exception(Exception::INVALID_ADDRESS); - - addr = ntohl(sa.sin_addr.s_addr); -} - -IPAddress::IPAddress(const std::string &address, uint16_t port0) throw(Exception) : port(port0) { - sa.sin_family = AF_INET; - sa.sin_port = htons(port); - - if(!inet_pton(AF_INET, address.c_str(), &sa.sin_addr)) - throw Exception(Exception::INVALID_ADDRESS); - - addr = ntohl(sa.sin_addr.s_addr); -} - -IPAddress::IPAddress(const struct sockaddr_in &address) : sa(address) { - port = ntohs(sa.sin_port); - addr = ntohl(sa.sin_addr.s_addr); -} - -std::string IPAddress::getAddressString() const { - char buf[INET_ADDRSTRLEN]; - uint32_t address = htonl(addr); - - inet_ntop(AF_INET, &address, buf, sizeof(buf)); - return std::string(buf); -} - -} -} diff --git a/src/Net/IPAddress.h b/src/Net/IPAddress.h deleted file mode 100644 index 3541891..0000000 --- a/src/Net/IPAddress.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * IPAddress.h - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef MAD_NET_IPADDRESS_H_ -#define MAD_NET_IPADDRESS_H_ - -#include "Exception.h" - -#include <string> -#include <arpa/inet.h> -#include <stdint.h> - -namespace Mad { -namespace Net { - -class IPAddress { - private: - uint32_t addr; - uint16_t port; - struct sockaddr_in sa; - - public: - // TODO Default port - IPAddress(uint16_t port0 = 6666); - IPAddress(uint32_t address, uint16_t port0); - IPAddress(const std::string &address) throw(Exception); - IPAddress(const std::string &address, uint16_t port0) throw(Exception); - IPAddress(const struct sockaddr_in &address); - - uint32_t getAddress() const {return addr;} - uint16_t getPort() const {return port;} - - std::string getAddressString() const; - - struct sockaddr* getSockAddr() {return (struct sockaddr*)&sa;} - socklen_t getSockAddrLength() const {return sizeof(sa);} -}; - -} -} - -#endif /*MAD_NET_IPADDRESS_H_*/ diff --git a/src/Net/Listener.cpp b/src/Net/Listener.cpp index 11cbaf5..6f49a74 100644 --- a/src/Net/Listener.cpp +++ b/src/Net/Listener.cpp @@ -18,8 +18,6 @@ */ #include "Listener.h" -#include "FdManager.h" -#include "ServerConnection.h" #include <Common/Logger.h> @@ -30,26 +28,29 @@ namespace Mad { namespace Net { -void Listener::acceptHandler(int) { - int sd; - struct sockaddr_in sa; - socklen_t addrlen = sizeof(sa); +void Listener::handleAccept(const boost::system::error_code &error, boost::shared_ptr<ServerConnection> con) { + if(error) + return; + + { + boost::lock_guard<boost::shared_mutex> lock(con->connectionLock); + con->state = ServerConnection::CONNECT; - while((sd = accept(sock, (struct sockaddr*)&sa, &addrlen)) >= 0) { - ServerConnection *con = new ServerConnection(sd, IPAddress(sa), dh_params, x905CertFile, x905KeyFile); - boost::signals::connection con1 = con->signalConnected().connect(boost::bind(&Listener::connectHandler, this, con)); - boost::signals::connection con2 = con->signalDisconnected().connect(boost::bind(&Listener::disconnectHandler, this, con)); + boost::signals::connection con1 = con->signalConnected().connect(boost::bind(&Listener::handleConnect, this, con)); + boost::signals::connection con2 = con->signalDisconnected().connect(boost::bind(&Listener::handleDisconnect, this, con)); connections.insert(std::make_pair(con, std::make_pair(con1, con2))); - addrlen = sizeof(sa); + con->socket.async_handshake(boost::asio::ssl::stream_base::server, boost::bind(&ServerConnection::handleHandshake, con, boost::asio::placeholders::error)); } -} + con.reset(new ServerConnection(sslContext)); + acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); +} -void Listener::connectHandler(ServerConnection *con) { - std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con); +void Listener::handleConnect(boost::shared_ptr<ServerConnection> con) { + std::map<boost::shared_ptr<ServerConnection>, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con); if(it == connections.end()) return; @@ -62,67 +63,33 @@ void Listener::connectHandler(ServerConnection *con) { signal(con); } -void Listener::disconnectHandler(ServerConnection *con) { - std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> >::iterator it = connections.find(con); - - if(it == connections.end()) - return; - - delete it->first; - connections.erase(it); +void Listener::handleDisconnect(boost::shared_ptr<ServerConnection> con) { + connections.erase(con); } -Listener::Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, const IPAddress &address0) throw(Exception) -: x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0) { - gnutls_dh_params_init(&dh_params); - gnutls_dh_params_generate2(dh_params, 768); +Listener::Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, + const boost::asio::ip::tcp::endpoint &address0) throw(Exception) +: x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0), acceptor(Connection::ioService, address), +sslContext(Connection::ioService, boost::asio::ssl::context::sslv23) +{ + sslContext.set_options(boost::asio::ssl::context::default_workarounds + | boost::asio::ssl::context::no_sslv2 + | boost::asio::ssl::context::single_dh_use); + sslContext.use_certificate_chain_file(x905CertFile0); + sslContext.use_private_key_file(x905KeyFile0, boost::asio::ssl::context::pem); - sock = socket(PF_INET, SOCK_STREAM, 0); - if(sock < 0) - throw Exception("socket()", Exception::INTERNAL_ERRNO, errno); - // Set non-blocking flag - int flags = fcntl(sock, F_GETFL, 0); - - if(flags < 0) { - close(sock); - - throw Exception("fcntl()", Exception::INTERNAL_ERRNO, errno); - } - - fcntl(sock, F_SETFL, flags | O_NONBLOCK); - - // Don't linger - struct linger linger = {1, 0}; - setsockopt(sock, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)); - - if(bind(sock, address.getSockAddr(), address.getSockAddrLength()) < 0) { - close(sock); - - throw Exception("bind()", Exception::INTERNAL_ERRNO, errno); - } - - if(listen(sock, 64) < 0) { - close(sock); - - throw Exception("listen()", Exception::INTERNAL_ERRNO, errno); - } - - FdManager::get()->registerFd(sock, boost::bind(&Listener::acceptHandler, this, _1), POLLIN); + boost::shared_ptr<ServerConnection> con(new ServerConnection(sslContext)); + acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); } Listener::~Listener() { - for(std::map<ServerConnection*,std::pair<boost::signals::connection, boost::signals::connection> >::iterator con = connections.begin(); con != connections.end(); ++con) { + for(std::map<boost::shared_ptr<ServerConnection>,std::pair<boost::signals::connection, boost::signals::connection> >::iterator con = connections.begin(); con != connections.end(); ++con) { con->first->disconnect(); - delete con->first; + // TODO wait... } - - shutdown(sock, SHUT_RDWR); - close(sock); - - gnutls_dh_params_deinit(dh_params); } } diff --git a/src/Net/Listener.h b/src/Net/Listener.h index 26dffab..0833cdf 100644 --- a/src/Net/Listener.h +++ b/src/Net/Listener.h @@ -20,46 +20,45 @@ #ifndef MAD_NET_LISTENER_H_ #define MAD_NET_LISTENER_H_ -#include "IPAddress.h" - -#include <gnutls/gnutls.h> #include <map> #include <string> -#include <boost/signal.hpp> +#include "Connection.h" +#include "Exception.h" namespace Mad { namespace Net { -class ServerConnection; - // TODO XXX Thread-safeness XXX -class Listener { +class Listener : boost::noncopyable { private: - std::string x905CertFile, x905KeyFile; - IPAddress address; - int sock; + class ServerConnection : public Connection { + public: + friend class Listener; - gnutls_dh_params_t dh_params; + ServerConnection(boost::asio::ssl::context &sslContext) : Connection(sslContext) {} + }; - std::map<ServerConnection*, std::pair<boost::signals::connection, boost::signals::connection> > connections; + std::string x905CertFile, x905KeyFile; + boost::asio::ip::tcp::endpoint address; + boost::asio::ip::tcp::acceptor acceptor; + boost::asio::ssl::context sslContext; - boost::signal1<void, ServerConnection*> signal; + std::map<boost::shared_ptr<ServerConnection>, std::pair<boost::signals::connection, boost::signals::connection> > connections; - void acceptHandler(int); + boost::signal1<void, boost::shared_ptr<Connection> > signal; - void connectHandler(ServerConnection *con); - void disconnectHandler(ServerConnection *con); + void handleAccept(const boost::system::error_code &error, boost::shared_ptr<ServerConnection> con); - // Prevent shallow copy - Listener(const Listener &o); - Listener& operator=(const Listener &o); + void handleConnect(boost::shared_ptr<ServerConnection> con); + void handleDisconnect(boost::shared_ptr<ServerConnection> con); public: - Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, const IPAddress &address0 = IPAddress()) throw(Exception); + Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, + const boost::asio::ip::tcp::endpoint &address0 = boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 6666)) throw(Exception); virtual ~Listener(); - boost::signal1<void, ServerConnection*>& signalNewConnection() {return signal;} + boost::signal1<void, boost::shared_ptr<Connection> >& signalNewConnection() {return signal;} }; } diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp deleted file mode 100644 index 1f01ce5..0000000 --- a/src/Net/ServerConnection.cpp +++ /dev/null @@ -1,90 +0,0 @@ -/* - * ServerConnection.cpp - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#include "ServerConnection.h" -#include "FdManager.h" -#include "IPAddress.h" - -#include <boost/thread/locks.hpp> - -#include <cstring> -#include <cerrno> -#include <sys/socket.h> -#include <fcntl.h> - -namespace Mad { -namespace Net { - -void ServerConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) { - if(length != sizeof(ConnectionHeader)) - // Error... disconnect - return; - - const ConnectionHeader *header = (const ConnectionHeader*)data; - - if(header->m != 'M' || header->a != 'A' || header->d != 'D') - // Error... disconnect - return; - - if(header->protVerMin > 1 || header->protVerMax < 1) - // Unsupported protocol... disconnect - return; - - if(header->type == 'C') - daemon = false; - else if(header->type == 'D') - daemon = true; - else - // Error... disconnect - return; - - ConnectionHeader header2 = {'M', 'A', 'D', 0, 0, 1, 1, 0}; - - enterReceiveLoop(); - - rawSend((uint8_t*)&header2, sizeof(header2)); -} - -ServerConnection::ServerConnection(int sock0, const IPAddress &address, gnutls_dh_params_t dh_params, const std::string &x905CertFile, const std::string &x905KeyFile) -: daemon(false) { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - sock = sock0; - - peer = new IPAddress(address); - - gnutls_certificate_set_dh_params(x509_cred, dh_params); - gnutls_certificate_set_x509_key_file(x509_cred, x905CertFile.c_str(), x905KeyFile.c_str(), GNUTLS_X509_FMT_PEM); - - gnutls_init(&session, GNUTLS_SERVER); - gnutls_set_default_priority(session); - gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred); - gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)sock); - - FdManager::get()->registerFd(sock, boost::bind(&ServerConnection::sendReceive, this, _1)); - - state = CONNECT; - - lock.unlock(); - - updateEvents(); -} - -} -} diff --git a/src/Net/ServerConnection.h b/src/Net/ServerConnection.h deleted file mode 100644 index d52cd7c..0000000 --- a/src/Net/ServerConnection.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * ServerConnection.h - * - * Copyright (C) 2008 Matthias Schiffer <matthias@gamezock.de> - * - * This program is free software: you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, but - * WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * See the GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef MAD_NET_SERVERCONNECTION_H_ -#define MAD_NET_SERVERCONNECTION_H_ - -#include "Connection.h" -#include <string> - -namespace Mad { -namespace Net { - -class Listener; - -class ServerConnection : public Connection { - friend class Listener; - - private: - IPAddress *peer; - - bool daemon; - - gnutls_anon_server_credentials_t anoncred; - - void connectionHeaderReceiveHandler(const void *data, unsigned long length); - - protected: - ServerConnection(int sock0, const IPAddress &address, gnutls_dh_params_t dh_params, const std::string &x905certFile, const std::string &x905keyFile); - - virtual void connectionHeader() { - rawReceive(sizeof(ConnectionHeader), boost::bind(&ServerConnection::connectionHeaderReceiveHandler, this, _1, _2)); - } - - public: - bool isDaemonConnection() const {return daemon;} -}; - -} -} - -#endif /*MAD_NET_SERVERCONNECTION_H_*/ diff --git a/src/Net/ThreadManager.cpp b/src/Net/ThreadManager.cpp index 71a754e..0fb0716 100644 --- a/src/Net/ThreadManager.cpp +++ b/src/Net/ThreadManager.cpp @@ -18,11 +18,13 @@ */ #include "ThreadManager.h" -#include "FdManager.h" +#include "Connection.h" #include <Common/Logger.h> #include <Common/LogManager.h> +#include <boost/bind.hpp> + #include <fcntl.h> namespace Mad { @@ -97,10 +99,12 @@ void ThreadManager::doInit() { threadLock.lock(); + ioWorker.reset(new boost::asio::io_service::work(Connection::ioService)); + mainThreadId = boost::this_thread::get_id(); - workerThread = new boost::thread(std::mem_fun(&ThreadManager::workerFunc), this); - loggerThread = new boost::thread(std::mem_fun(&Common::LogManager::loggerThread), Common::LogManager::get()); - ioThread = new boost::thread(std::mem_fun(&FdManager::ioThread), FdManager::get()); + workerThread = new boost::thread(&ThreadManager::workerFunc, this); + loggerThread = new boost::thread(&Common::LogManager::loggerThread, Common::LogManager::get()); + ioThread = new boost::thread((std::size_t(boost::asio::io_service::*)())&boost::asio::io_service::run, &Connection::ioService); threadLock.unlock(); } @@ -128,7 +132,7 @@ void ThreadManager::doDeinit() { threads.join_all(); // IO thread is next - FdManager::get()->stopIOThread(); + ioWorker.reset(); ioThread->join(); delete ioThread; diff --git a/src/Net/ThreadManager.h b/src/Net/ThreadManager.h index fd903af..2c57747 100644 --- a/src/Net/ThreadManager.h +++ b/src/Net/ThreadManager.h @@ -27,7 +27,8 @@ #include <queue> #include <set> -#include <boost/function.hpp> +#include <boost/asio.hpp> + #include <boost/thread/thread.hpp> #include <boost/thread/condition_variable.hpp> #include <boost/thread/locks.hpp> @@ -50,11 +51,14 @@ class ThreadManager : public Common::Initializable { boost::condition_variable workCond; std::queue<boost::function0<void> > work; + boost::scoped_ptr<boost::asio::io_service::work> ioWorker; + static ThreadManager threadManager; ThreadManager() {} void workerFunc(); + void ioFunc(); void threadFinished(boost::thread *thread) { threadLock.lock(); |