diff options
Diffstat (limited to 'src/Net/Connection.cpp')
-rw-r--r-- | src/Net/Connection.cpp | 364 |
1 files changed, 114 insertions, 250 deletions
diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 4e5029f..b9691cb 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -18,361 +18,225 @@ */ #include "Connection.h" -#include "FdManager.h" -#include "IPAddress.h" #include "ThreadManager.h" -#include <cstring> -#include <sys/socket.h> +#include <Common/Logger.h> +#include <cstring> #include <boost/bind.hpp> namespace Mad { namespace Net { - -Connection::StaticInit Connection::staticInit; +boost::asio::io_service Connection::ioService; Connection::~Connection() { if(_isConnected()) doDisconnect(); - - if(transR.data) - delete [] transR.data; - - while(!_sendQueueEmpty()) { - delete [] transS.front().data; - transS.pop(); - } - - gnutls_certificate_free_credentials(x509_cred); - - if(peer) - delete peer; -} - -void Connection::handshake() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(state != CONNECT) - return; - - state = HANDSHAKE; - lock.unlock(); - - doHandshake(); } -void Connection::bye() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(state != DISCONNECT) - return; - - state = BYE; - lock.unlock(); +void Connection::handleHandshake(const boost::system::error_code& error) { + if(error) { + Common::Logger::logf("Error: %s", error.message().c_str()); - doBye(); -} - -void Connection::doHandshake() { - boost::shared_lock<boost::shared_mutex> lock(stateLock); - if(state != HANDSHAKE) + // TODO Error handling + doDisconnect(); return; + } - int ret = gnutls_handshake(session); - if(ret < 0) { - lock.unlock(); + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); + state = CONNECTED; - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { - updateEvents(); - return; - } + receiving = false; + sending = 0; - // TODO: Error - doDisconnect(); - return; + received = 0; } - state = CONNECTION_HEADER; - lock.unlock(); + ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &connectedSignal)); - connectionHeader(); + enterReceiveLoop(); } -void Connection::doBye() { - if(state != BYE) - return; - - int ret = gnutls_bye(session, GNUTLS_SHUT_RDWR); - if(ret < 0) { - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { - updateEvents(); - return; - } +void Connection::handleShutdown(const boost::system::error_code& error) { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - // TODO: Error - doDisconnect(); - return; + if(error) { + // TODO Error } - doDisconnect(); + state = DISCONNECTED; + ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &disconnectedSignal)); } void Connection::enterReceiveLoop() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - if(!_isConnected() || _isDisconnecting()) - return; - - if(_isConnecting()) - ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &connectedSignal)); + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - state = PACKET_HEADER; - lock.unlock(); + if(!_isConnected() || _isDisconnecting()) + return; + } - rawReceive(sizeof(Packet::Data), boost::bind(&Connection::packetHeaderReceiveHandler, this, _1, _2)); + rawReceive(sizeof(Packet::Data), boost::bind(&Connection::handleHeaderReceive, this, _1)); } -void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) { - if(state != PACKET_HEADER) - return; +void Connection::handleHeaderReceive(const std::vector<boost::uint8_t> &data) { + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - if(length != sizeof(Packet::Data)) { - // TODO: Error - doDisconnect(); - return; + header = *reinterpret_cast<const Packet::Data*>(data.data()); } - header = *(const Packet::Data*)data; - if(header.length == 0) { ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, Packet(ntohs(header.requestId)))); enterReceiveLoop(); } else { - state = PACKET_DATA; - rawReceive(ntohs(header.length), boost::bind(&Connection::packetDataReceiveHandler, this, _1, _2)); + rawReceive(ntohs(header.length), boost::bind(&Connection::handleDataReceive, this, _1)); } } -void Connection::packetDataReceiveHandler(const void *data, unsigned long length) { - if(state != PACKET_DATA) - return; +void Connection::handleDataReceive(const std::vector<boost::uint8_t> &data) { + { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - if(length != ntohs(header.length)) { - // TODO: Error - doDisconnect(); - return; + Packet packet(ntohs(header.requestId), data.data(), ntohs(header.length)); + ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, packet)); } - ThreadManager::get()->pushWork(boost::bind((void (boost::signal1<void, const Packet&>::*)(const Packet&))&boost::signal1<void, const Packet&>::operator(), &receiveSignal, Packet(ntohs(header.requestId), data, length))); - enterReceiveLoop(); } -void Connection::doReceive() { - if(!isConnected()) - return; - - boost::unique_lock<boost::mutex> lock(receiveLock); - - if(_receiveComplete()) - return; - - ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted); +void Connection::handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify) { + if(error || (bytes_transferred+received) < length) { + Common::Logger::logf(Common::Logger::VERBOSE, "Read error: %s", error.message().c_str()); - if(ret < 0) { - lock.unlock(); - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) - return; - - // TODO: Error + // TODO Error doDisconnect(); return; } - transR.transmitted += ret; - - if(_receiveComplete()) { - // Save data pointer, as transR.notify might start a new reception - uint8_t *data = transR.data; - transR.data = 0; + std::vector<boost::uint8_t> buffer; - lock.unlock(); + { + boost::shared_lock<boost::shared_mutex> lock(connectionLock); - transR.notify(data, transR.length); + if(state != CONNECTED || !receiving) + return; - delete [] data; - } - else { - lock.unlock(); + buffer.insert(buffer.end(), receiveBuffer.data(), receiveBuffer.data()+length); } - updateEvents(); -} - -bool Connection::rawReceive(unsigned long length, - const boost::function2<void,const void*,unsigned long> ¬ify) -{ - if(!isConnected()) - return false; - - boost::unique_lock<boost::mutex> lock(receiveLock); - if(!_receiveComplete()) - return false; - - transR.data = new uint8_t[length]; - transR.length = length; - transR.transmitted = 0; - transR.notify = notify; + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - lock.unlock(); + receiving = false; + received = received + bytes_transferred - length; - updateEvents(); + if(received) + std::memmove(receiveBuffer.data(), receiveBuffer.data()+length, received); + } - return true; + notify(buffer); } -void Connection::doSend() { - if(!isConnected()) - return; +void Connection::rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify) { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - boost::unique_lock<boost::mutex> lock(sendLock); - while(!_sendQueueEmpty()) { - ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted, - transS.front().length-transS.front().transmitted); - - if(ret < 0) { - lock.unlock(); + if(!_isConnected()) + return; - if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) - return; + { + boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock); - // TODO: Error - doDisconnect(); + if(receiving) return; - } - transS.front().transmitted += ret; + receiving = true; - if(transS.front().transmitted == transS.front().length) { - delete [] transS.front().data; - transS.pop(); + if(length > received) { + boost::asio::async_read(socket, boost::asio::buffer(receiveBuffer.data()+received, receiveBuffer.size()-received), boost::asio::transfer_at_least(length), + boost::bind(&Connection::handleRead, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred, + length, notify)); + + return; } } lock.unlock(); - updateEvents(); + handleRead(boost::system::error_code(), 0, length, notify); } -bool Connection::rawSend(const uint8_t *data, unsigned long length) { - if(!isConnected()) - return false; +void Connection::handleWrite(const boost::system::error_code& error, std::size_t) { + { + boost::unique_lock<boost::shared_mutex> lock(connectionLock); - Transmission trans = {length, 0, new uint8_t[length], boost::function2<void,const void*,unsigned long>()}; - std::memcpy(trans.data, data, length); + sending--; - sendLock.lock(); - transS.push(trans); - sendLock.unlock(); - - updateEvents(); + if(state == DISCONNECT && !sending) { + lock.unlock(); + doDisconnect(); + return; + } + } - return true; -} + if(error) { + Common::Logger::logf(Common::Logger::VERBOSE, "Write error: %s", error.message().c_str()); -void Connection::sendReceive(short events) { - if(events & POLLHUP || events & POLLERR) { + // TODO Error doDisconnect(); - return; } +} - switch(state) { - case CONNECT: - handshake(); - return; - case HANDSHAKE: - doHandshake(); - return; - case DISCONNECT: - if(!_sendQueueEmpty()) - break; +void Connection::rawSend(const uint8_t *data, std::size_t length) { + boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - bye(); - return; - case BYE: - doBye(); - return; - default: - break; - } + if(!_isConnected()) + return; - if(events & POLLIN) - doReceive(); + { + boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock); - if(events & POLLOUT) - doSend(); + sending++; + boost::asio::async_write(socket, Buffer(data, length), boost::bind(&Connection::handleWrite, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + } } bool Connection::send(const Packet &packet) { - stateLock.lock_shared(); - bool err = (!_isConnected() || _isConnecting() || _isDisconnecting()); - stateLock.unlock_shared(); - - if(err) - return false; + { + boost::shared_lock<boost::shared_mutex> lock(connectionLock); + if(!_isConnected() || _isConnecting() || _isDisconnecting()) + return false; + } - return rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength()); + rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength()); + return true; } void Connection::disconnect() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - if(!_isConnected() || _isDisconnecting()) - return; + { + boost::lock_guard<boost::shared_mutex> lock(connectionLock); + if(!_isConnected() || _isDisconnecting()) + return; - state = DISCONNECT; + state = DISCONNECT; - lock.unlock(); + if(sending) + return; + } - updateEvents(); + doDisconnect(); } void Connection::doDisconnect() { - boost::unique_lock<boost::shared_mutex> lock(stateLock); - - if(_isConnected()) { - FdManager::get()->unregisterFd(sock); - - shutdown(sock, SHUT_RDWR); - close(sock); + boost::lock_guard<boost::shared_mutex> lock(connectionLock); - gnutls_deinit(session); - - ThreadManager::get()->pushWork(boost::bind((void (boost::signal0<void>::*)())&boost::signal0<void>::operator(), &disconnectedSignal)); - - state = DISCONNECTED; - } -} - -void Connection::updateEvents() { - receiveLock.lock(); - short events = (_receiveComplete() ? 0 : POLLIN); - receiveLock.unlock(); - - sendLock.lock(); - events |= (_sendQueueEmpty() ? 0 : POLLOUT); - sendLock.unlock(); - - stateLock.lock_shared(); - if(state == HANDSHAKE || state == BYE) - events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT); - else if(state == CONNECT || state == DISCONNECT) - events |= POLLOUT; - - FdManager::get()->setFdEvents(sock, events); - stateLock.unlock_shared(); + if(_isConnected()) + socket.async_shutdown(boost::bind(&Connection::handleShutdown, this, boost::asio::placeholders::error)); } } |