diff options
-rw-r--r-- | src/Common/AuthManager.cpp | 3 | ||||
-rw-r--r-- | src/Common/ClientConnection.cpp | 2 | ||||
-rw-r--r-- | src/Common/ClientConnection.h | 3 | ||||
-rw-r--r-- | src/Net/ClientConnection.cpp | 4 | ||||
-rw-r--r-- | src/Net/ClientConnection.h | 13 | ||||
-rw-r--r-- | src/Net/Connection.cpp | 30 | ||||
-rw-r--r-- | src/Net/Connection.h | 45 | ||||
-rw-r--r-- | src/Net/Listener.cpp | 24 | ||||
-rw-r--r-- | src/Net/Listener.h | 1 | ||||
-rw-r--r-- | src/Server/ConnectionManager.cpp | 16 | ||||
-rw-r--r-- | src/Server/ConnectionManager.h | 2 |
11 files changed, 82 insertions, 61 deletions
diff --git a/src/Common/AuthManager.cpp b/src/Common/AuthManager.cpp index 82fbda5..d21909d 100644 --- a/src/Common/AuthManager.cpp +++ b/src/Common/AuthManager.cpp @@ -18,9 +18,10 @@ */ #include "AuthManager.h" - #include "AuthBackend.h" +#include <boost/thread/locks.hpp> + namespace Mad { namespace Common { diff --git a/src/Common/ClientConnection.cpp b/src/Common/ClientConnection.cpp index 51b23fa..4182806 100644 --- a/src/Common/ClientConnection.cpp +++ b/src/Common/ClientConnection.cpp @@ -24,7 +24,7 @@ namespace Mad { namespace Common { -ClientConnection::ClientConnection(Core::Application *application) : Connection(application), connection(new Net::ClientConnection(application)) { +ClientConnection::ClientConnection(Core::Application *application) : Connection(application), connection(Net::ClientConnection::create(application)) { connection->connectSignalReceive(boost::bind(&ClientConnection::receive, this, _1)); } diff --git a/src/Common/ClientConnection.h b/src/Common/ClientConnection.h index b50f163..a02c461 100644 --- a/src/Common/ClientConnection.h +++ b/src/Common/ClientConnection.h @@ -37,14 +37,13 @@ namespace Common { class MAD_COMMON_EXPORT ClientConnection : public Connection { private: - Net::ClientConnection *connection; + boost::shared_ptr<Net::ClientConnection> connection; protected: virtual bool send(const Net::Packet &packet); public: ClientConnection(Core::Application *application); - virtual ~ClientConnection() {} void connect(const boost::asio::ip::tcp::endpoint &address) throw(Core::Exception); diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index 2228b7a..10a03f1 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -31,7 +31,7 @@ void ClientConnection::handleConnect(const boost::system::error_code& error) { boost::lock_guard<boost::shared_mutex> lock(connectionLock); - socket->async_handshake(boost::asio::ssl::stream_base::client, boost::bind(&ClientConnection::handleHandshake, this, boost::asio::placeholders::error)); + socket.async_handshake(boost::asio::ssl::stream_base::client, boost::bind(&ClientConnection::handleHandshake, thisPtr.lock(), boost::asio::placeholders::error)); } void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) throw(Core::Exception) { @@ -45,7 +45,7 @@ void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) th peer = address; _setState(CONNECT); - socket->lowest_layer().async_connect(address, boost::bind(&ClientConnection::handleConnect, this, boost::asio::placeholders::error)); + socket.lowest_layer().async_connect(address, boost::bind(&ClientConnection::handleConnect, boost::dynamic_pointer_cast<ClientConnection>(thisPtr.lock()), boost::asio::placeholders::error)); } } diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h index d29d6ae..64203d7 100644 --- a/src/Net/ClientConnection.h +++ b/src/Net/ClientConnection.h @@ -35,9 +35,18 @@ class MAD_NET_EXPORT ClientConnection : public Connection { private: void handleConnect(const boost::system::error_code& error); + ClientConnection(Core::Application *application, boost::shared_ptr<boost::asio::ssl::context> context) : Connection(application, context) {} + public: - ClientConnection(Core::Application *application) : Connection(application) { - context.set_verify_mode(boost::asio::ssl::context::verify_none); + static boost::shared_ptr<ClientConnection> create(Core::Application *application) { + boost::shared_ptr<boost::asio::ssl::context> context(new boost::asio::ssl::context(application->getIOService(), boost::asio::ssl::context::sslv23)); + context->set_verify_mode(boost::asio::ssl::context::verify_none); + + boost::shared_ptr<ClientConnection> connection(new ClientConnection(application, context)); + + connection->thisPtr = connection; + + return connection; } void connect(const boost::asio::ip::tcp::endpoint &address) throw(Core::Exception); diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 036c3d8..fc917d4 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -29,8 +29,8 @@ namespace Net { Connection::~Connection() { - if(_isConnected()) - doDisconnect(); + doDisconnect(); + waitWhileConnected(); } void Connection::handleHandshake(const boost::system::error_code& error) { @@ -77,7 +77,7 @@ void Connection::enterReceiveLoop() { return; } - rawReceive(sizeof(Packet::Header), boost::bind(&Connection::handleHeaderReceive, this, _1)); + rawReceive(sizeof(Packet::Header), boost::bind(&Connection::handleHeaderReceive, thisPtr.lock(), _1)); } void Connection::handleHeaderReceive(const boost::shared_array<boost::uint8_t> &data) { @@ -93,7 +93,7 @@ void Connection::handleHeaderReceive(const boost::shared_array<boost::uint8_t> & enterReceiveLoop(); } else { - rawReceive(ntohs(header.length), boost::bind(&Connection::handleDataReceive, this, _1)); + rawReceive(ntohs(header.length), boost::bind(&Connection::handleDataReceive, thisPtr.lock(), _1)); } } @@ -109,7 +109,10 @@ void Connection::handleDataReceive(const boost::shared_array<boost::uint8_t> &da void Connection::handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const boost::shared_array<boost::uint8_t>& > ¬ify) { if(error || (bytes_transferred+received) < length) { - application->logf(Core::LoggerBase::LOG_VERBOSE, "Read error: %s", error.message().c_str()); + if(error == boost::system::errc::operation_canceled) + return; + + application->logf(Core::LoggerBase::LOG_DEFAULT, "Read error: %s", error.message().c_str()); // TODO Error doDisconnect(); @@ -155,8 +158,8 @@ void Connection::rawReceive(std::size_t length, const boost::function1<void, con receiving = true; if(length > received) { - boost::asio::async_read(*socket, boost::asio::buffer(receiveBuffer->data()+received, receiveBuffer->size()-received), boost::asio::transfer_at_least(length), - boost::bind(&Connection::handleRead, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred, + boost::asio::async_read(socket, boost::asio::buffer(receiveBuffer->data()+received, receiveBuffer->size()-received), boost::asio::transfer_at_least(length), + boost::bind(&Connection::handleRead, thisPtr.lock(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred, length, notify)); return; @@ -169,6 +172,9 @@ void Connection::rawReceive(std::size_t length, const boost::function1<void, con } void Connection::handleWrite(const boost::system::error_code& error, std::size_t) { + if(error) + application->logf(Core::LoggerBase::LOG_VERBOSE, "Write error: %s", error.message().c_str()); + { boost::unique_lock<boost::shared_mutex> lock(connectionLock); @@ -182,8 +188,6 @@ void Connection::handleWrite(const boost::system::error_code& error, std::size_t } if(error) { - application->logf(Core::LoggerBase::LOG_VERBOSE, "Write error: %s", error.message().c_str()); - // TODO Error doDisconnect(); } @@ -199,7 +203,7 @@ void Connection::rawSend(const boost::uint8_t *data, std::size_t length) { boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock); sending++; - boost::asio::async_write(*socket, Buffer(data, length), boost::bind(&Connection::handleWrite, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + boost::asio::async_write(socket, Buffer(data, length), boost::bind(&Connection::handleWrite, thisPtr.lock(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); } } @@ -229,11 +233,5 @@ void Connection::disconnect() { doDisconnect(); } -void Connection::doDisconnect() { - boost::lock_guard<boost::shared_mutex> lock(connectionLock); - - socket->async_shutdown(boost::bind(&Connection::handleShutdown, this, boost::asio::placeholders::error)); -} - } } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index 19ee826..51f40b0 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -38,15 +38,13 @@ namespace Mad { namespace Net { class Listener; -class ThreadManager; class MAD_NET_EXPORT Connection : boost::noncopyable { protected: friend class Listener; - friend class ThreadManager; enum State { - DISCONNECTED, CONNECT, CONNECTED, DISCONNECT + DISCONNECTED, CONNECT, CONNECTED, DISCONNECT, SHUTDOWN }; private: @@ -83,10 +81,6 @@ class MAD_NET_EXPORT Connection : boost::noncopyable { bool receiving; unsigned long sending; - void _initSocket() { - socket.reset(new boost::asio::ssl::stream<boost::asio::ip::tcp::socket>(application->getIOService(), context)); - } - void enterReceiveLoop(); void handleHeaderReceive(const boost::shared_array<boost::uint8_t> &data); @@ -101,10 +95,12 @@ class MAD_NET_EXPORT Connection : boost::noncopyable { void rawSend(const boost::uint8_t *data, std::size_t length); protected: + boost::weak_ptr<Connection> thisPtr; + boost::shared_mutex connectionLock; - boost::asio::ssl::context context; - boost::scoped_ptr<boost::asio::ssl::stream<boost::asio::ip::tcp::socket> > socket; + boost::shared_ptr<boost::asio::ssl::context> context; + boost::asio::ssl::stream<boost::asio::ip::tcp::socket> socket; boost::asio::ip::tcp::endpoint peer; void handleHandshake(const boost::system::error_code& error); @@ -115,25 +111,38 @@ class MAD_NET_EXPORT Connection : boost::noncopyable { } bool _isDisconnecting() const { - return (state == DISCONNECT); + return (state == DISCONNECT || state == SHUTDOWN); } void _setState(State newState) { state = newState; - if(_isConnected() && !socket.get()) - _initSocket(); - else if(!_isConnected() && socket.get()) - socket.reset(); - stateChanged.notify_all(); } - void doDisconnect(); + void doDisconnect() { + boost::unique_lock<boost::shared_mutex> lock(connectionLock); + + if(_isConnected() && state != SHUTDOWN) { + _setState(SHUTDOWN); + boost::system::error_code error; + socket.lowest_layer().cancel(error); - Connection(Core::Application *application0) : + socket.async_shutdown(boost::bind(&Connection::handleShutdown, thisPtr.lock(), boost::asio::placeholders::error)); + } + } + + Connection(Core::Application *application0, boost::shared_ptr<boost::asio::ssl::context> context0) : application(application0), state(DISCONNECTED), receiveBuffer(new boost::array<boost::uint8_t, 1024*1024>), receiveSignal(application), connectedSignal(application), - disconnectedSignal(application), context(application->getIOService(), boost::asio::ssl::context::sslv23) {} + disconnectedSignal(application), context(context0), socket(application->getIOService(), *context) {} + + static boost::shared_ptr<Connection> create(Core::Application *application, boost::shared_ptr<boost::asio::ssl::context> context) { + boost::shared_ptr<Connection> connection(new Connection(application, context)); + + connection->thisPtr = connection; + + return connection; + } public: virtual ~Connection(); diff --git a/src/Net/Listener.cpp b/src/Net/Listener.cpp index 6187a1e..b3974d1 100644 --- a/src/Net/Listener.cpp +++ b/src/Net/Listener.cpp @@ -27,17 +27,9 @@ namespace Mad { namespace Net { void Listener::accept() { - boost::shared_ptr<Connection> con(new Connection(application)); - - con->context.set_options(boost::asio::ssl::context::default_workarounds - | boost::asio::ssl::context::no_sslv2 - | boost::asio::ssl::context::single_dh_use); - con->context.use_certificate_chain_file(x905CertFile); - con->context.use_private_key_file(x905KeyFile, boost::asio::ssl::context::pem); - - con->_initSocket(); + boost::shared_ptr<Connection> con(Connection::create(application, context)); - acceptor.async_accept(con->socket->lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); + acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); } void Listener::handleAccept(const boost::system::error_code &error, boost::shared_ptr<Connection> con) { @@ -56,7 +48,7 @@ void Listener::handleAccept(const boost::system::error_code &error, boost::share connections.insert(std::make_pair(con, std::make_pair(con1, con2))); - con->socket->async_handshake(boost::asio::ssl::stream_base::server, boost::bind(&Connection::handleHandshake, con, boost::asio::placeholders::error)); + con->socket.async_handshake(boost::asio::ssl::stream_base::server, boost::bind(&Connection::handleHandshake, con, boost::asio::placeholders::error)); } accept(); @@ -89,7 +81,15 @@ void Listener::handleDisconnect(boost::shared_ptr<Connection> con) { Listener::Listener(Core::Application *application0, const std::string &x905CertFile0, const std::string &x905KeyFile0, const boost::asio::ip::tcp::endpoint &address0) throw(Core::Exception) : application(application0), x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0), -acceptor(application->getIOService(), address), signal(application) { +context(new boost::asio::ssl::context(application->getIOService(), boost::asio::ssl::context::sslv23)), +acceptor(application->getIOService(), address), signal(application) +{ + context->set_options(boost::asio::ssl::context::default_workarounds + | boost::asio::ssl::context::no_sslv2 + | boost::asio::ssl::context::single_dh_use); + context->use_certificate_chain_file(x905CertFile); + context->use_private_key_file(x905KeyFile, boost::asio::ssl::context::pem); + accept(); } diff --git a/src/Net/Listener.h b/src/Net/Listener.h index 64572c0..598ea2f 100644 --- a/src/Net/Listener.h +++ b/src/Net/Listener.h @@ -39,6 +39,7 @@ class MAD_NET_EXPORT Listener : private boost::noncopyable { std::string x905CertFile, x905KeyFile; boost::asio::ip::tcp::endpoint address; + boost::shared_ptr<boost::asio::ssl::context> context; boost::asio::ip::tcp::acceptor acceptor; std::map<boost::shared_ptr<Connection>, std::pair<Core::Signals::Connection, Core::Signals::Connection> > connections; diff --git a/src/Server/ConnectionManager.cpp b/src/Server/ConnectionManager.cpp index b9490bc..27b65c2 100644 --- a/src/Server/ConnectionManager.cpp +++ b/src/Server/ConnectionManager.cpp @@ -194,19 +194,23 @@ void ConnectionManager::configFinished() { void ConnectionManager::handleNewConnection(boost::shared_ptr<Net::Connection> con) { boost::shared_ptr<ServerConnection> connection(new ServerConnection(application, con)); - con->connectSignalDisconnected(boost::bind(&ConnectionManager::handleDisconnect, this, connection)); + con->connectSignalDisconnected(boost::bind(&ConnectionManager::handleDisconnect, this, boost::weak_ptr<ServerConnection>(connection))); connections.insert(connection); application->getRequestManager()->registerConnection(connection.get()); } -void ConnectionManager::handleDisconnect(boost::shared_ptr<ServerConnection> con) { - if(con->getHostInfo()) - updateState(con->getHostInfo(), Common::HostInfo::INACTIVE); +void ConnectionManager::handleDisconnect(boost::weak_ptr<ServerConnection> con) { + boost::shared_ptr<ServerConnection> connection = con.lock(); + if(!connection) + return; + + if(connection->getHostInfo()) + updateState(connection->getHostInfo(), Common::HostInfo::INACTIVE); - connections.erase(con); + connections.erase(connection); - application->getRequestManager()->unregisterConnection(con.get()); + application->getRequestManager()->unregisterConnection(connection.get()); } ConnectionManager::ConnectionManager(Application *application0) : application(application0), diff --git a/src/Server/ConnectionManager.h b/src/Server/ConnectionManager.h index 9638f38..057e73e 100644 --- a/src/Server/ConnectionManager.h +++ b/src/Server/ConnectionManager.h @@ -127,7 +127,7 @@ class MAD_SERVER_EXPORT ConnectionManager : public Core::Configurable, private b void updateState(Common::HostInfo *hostInfo, Common::HostInfo::State state); void handleNewConnection(boost::shared_ptr<Net::Connection> con); - void handleDisconnect(boost::shared_ptr<ServerConnection> con); + void handleDisconnect(boost::weak_ptr<ServerConnection> con); ConnectionManager(Application *application0); ~ConnectionManager(); |