From 09b8df5200de1c8c20ea2856a8c6aa76b0811bd1 Mon Sep 17 00:00:00 2001 From: Matthias Schiffer Date: Fri, 11 Sep 2009 23:13:23 +0200 Subject: Connection: Allow setting a receive limit --- src/Net/Connection.cpp | 16 +++++++++++----- src/Net/Connection.h | 26 ++++++++++++++------------ 2 files changed, 25 insertions(+), 17 deletions(-) (limited to 'src/Net') diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index f1beb35..256bbfe 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -35,7 +35,7 @@ Connection::~Connection() { void Connection::handleHandshake(const boost::system::error_code& error) { if(error) { - application->logf("Error: %s", error.message().c_str()); + application->logf(Core::Logger::LOG_NETWORK, "Error: %s", error.message().c_str()); // TODO Error handling doDisconnect(); @@ -64,7 +64,7 @@ void Connection::handleShutdown(const boost::system::error_code& error) { boost::lock_guard lock(connectionLock); if(error) { - application->logf(Core::Logger::LOG_VERBOSE, "Shutdown error: %s", error.message().c_str()); + application->logf(Core::Logger::LOG_NETWORK, Core::Logger::LOG_VERBOSE, "Shutdown error: %s", error.message().c_str()); } _setState(DISCONNECTED); @@ -90,11 +90,17 @@ void Connection::handleHeaderReceive(const boost::shared_array & header = *reinterpret_cast(data.get()); } - if(header.length == 0) { + boost::uint32_t length = ntohl(header.length); + + if(length == 0) { receiveSignal.emit(boost::shared_ptr(new Packet(ntohs(header.requestId)))); enterReceiveLoop(); } + else if(length > receiveLimit) { + application->log(Core::Logger::LOG_NETWORK, Core::Logger::LOG_WARNING, "Packet size limit exceeded. Disconnecting."); + doDisconnect(); + } else { rawReceive(ntohl(header.length), boost::bind(&Connection::handleDataReceive, thisPtr.lock(), _1)); } @@ -115,7 +121,7 @@ void Connection::handleRead(const boost::system::error_code& error, std::size_t if(error == boost::system::errc::operation_canceled) return; - application->logf(Core::Logger::LOG_DEFAULT, "Read error: %s", error.message().c_str()); + application->logf(Core::Logger::LOG_NETWORK, "Read error: %s", error.message().c_str()); // TODO Error doDisconnect(); @@ -176,7 +182,7 @@ void Connection::rawReceive(std::size_t length, const boost::function1logf(Core::Logger::LOG_VERBOSE, "Write error: %s", error.message().c_str()); + application->logf(Core::Logger::LOG_NETWORK, Core::Logger::LOG_VERBOSE, "Write error: %s", error.message().c_str()); { boost::unique_lock lock(connectionLock); diff --git a/src/Net/Connection.h b/src/Net/Connection.h index add10b7..64b12c6 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -80,6 +80,8 @@ class MAD_NET_EXPORT Connection : boost::noncopyable { Core::Signals::Signal0 connectedSignal; Core::Signals::Signal0 disconnectedSignal; + boost::uint32_t receiveLimit; + bool receiving; unsigned long sending; @@ -138,7 +140,8 @@ class MAD_NET_EXPORT Connection : boost::noncopyable { application(application0), state(DISCONNECTED), dontStart(false), receiveBuffer(new boost::array), receiveSignal(application), connectedSignal(application), - disconnectedSignal(application), context(context0), socket(application->getIOService(), *context) {} + disconnectedSignal(application), receiveLimit(0xFFFF) /* 64K */, receiving(false), sending(0), + context(context0), socket(application->getIOService(), *context) {} static boost::shared_ptr create(Core::Application *application, boost::shared_ptr context) { boost::shared_ptr connection(new Connection(application, context)); @@ -180,17 +183,6 @@ class MAD_NET_EXPORT Connection : boost::noncopyable { stateChanged.wait(lock); } - /*const gnutls_datum_t* getCertificate() const { - // TODO Thread-safeness - return gnutls_certificate_get_ours(session); - } - - const gnutls_datum_t* getPeerCertificate() const { - // TODO Thread-safeness - unsigned int n; - return gnutls_certificate_get_peers(session, &n); - }*/ - boost::asio::ip::tcp::endpoint getPeer() { boost::shared_lock lock(connectionLock); return peer; @@ -206,6 +198,16 @@ class MAD_NET_EXPORT Connection : boost::noncopyable { setStart(false); } + boost::uint32_t getReceiveLimit() { + boost::shared_lock lock(connectionLock); + return receiveLimit; + } + + void setReceiveLimit(boost::uint32_t limit) { + boost::lock_guard lock(connectionLock); + receiveLimit = limit; + } + void startReceive() { { boost::lock_guard lock(connectionLock); -- cgit v1.2.3