diff options
Diffstat (limited to 'src/Net')
-rw-r--r-- | src/Net/CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/Net/ClientConnection.cpp | 4 | ||||
-rw-r--r-- | src/Net/ClientConnection.h | 13 | ||||
-rw-r--r-- | src/Net/Connection.cpp | 39 | ||||
-rw-r--r-- | src/Net/Connection.h | 50 | ||||
-rw-r--r-- | src/Net/Listener.cpp | 34 | ||||
-rw-r--r-- | src/Net/Listener.h | 8 | ||||
-rw-r--r-- | src/Net/Packet.cpp | 8 | ||||
-rw-r--r-- | src/Net/Packet.h | 40 | ||||
-rw-r--r-- | src/Net/export.h | 22 |
10 files changed, 134 insertions, 88 deletions
diff --git a/src/Net/CMakeLists.txt b/src/Net/CMakeLists.txt index ceaf6bd..7d38d7a 100644 --- a/src/Net/CMakeLists.txt +++ b/src/Net/CMakeLists.txt @@ -1,6 +1,8 @@ include_directories(${INCLUDES}) -add_library(Net STATIC +mad_library(Net + export.h + ClientConnection.cpp ClientConnection.h Connection.cpp Connection.h Listener.cpp Listener.h diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index d0c7aa4..2228b7a 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -31,7 +31,7 @@ void ClientConnection::handleConnect(const boost::system::error_code& error) { boost::lock_guard<boost::shared_mutex> lock(connectionLock); - socket.async_handshake(boost::asio::ssl::stream_base::client, boost::bind(&ClientConnection::handleHandshake, this, boost::asio::placeholders::error)); + socket->async_handshake(boost::asio::ssl::stream_base::client, boost::bind(&ClientConnection::handleHandshake, this, boost::asio::placeholders::error)); } void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) throw(Core::Exception) { @@ -45,7 +45,7 @@ void ClientConnection::connect(const boost::asio::ip::tcp::endpoint &address) th peer = address; _setState(CONNECT); - socket.lowest_layer().async_connect(address, boost::bind(&ClientConnection::handleConnect, this, boost::asio::placeholders::error)); + socket->lowest_layer().async_connect(address, boost::bind(&ClientConnection::handleConnect, this, boost::asio::placeholders::error)); } } diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h index b17c208..d29d6ae 100644 --- a/src/Net/ClientConnection.h +++ b/src/Net/ClientConnection.h @@ -20,27 +20,24 @@ #ifndef MAD_NET_CLIENTCONNECTION_H_ #define MAD_NET_CLIENTCONNECTION_H_ +#include "export.h" + #include "Connection.h" #include <Core/Exception.h> -#include <boost/utility/base_from_member.hpp> - namespace Mad { namespace Net { class IPAddress; -class ClientConnection : private boost::base_from_member<boost::asio::ssl::context>, public Connection { +class MAD_NET_EXPORT ClientConnection : public Connection { private: void handleConnect(const boost::system::error_code& error); public: - ClientConnection(Core::Application *application) - : boost::base_from_member<boost::asio::ssl::context>(boost::ref(application->getIOService()), boost::asio::ssl::context::sslv23), - Connection(application, member) - { - member.set_verify_mode(boost::asio::ssl::context::verify_none); + ClientConnection(Core::Application *application) : Connection(application) { + context.set_verify_mode(boost::asio::ssl::context::verify_none); } void connect(const boost::asio::ip::tcp::endpoint &address) throw(Core::Exception); diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 305e1de..036c3d8 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -61,7 +61,7 @@ void Connection::handleShutdown(const boost::system::error_code& error) { boost::lock_guard<boost::shared_mutex> lock(connectionLock); if(error) { - application->logf(Core::LoggerBase::VERBOSE, "Shutdown error: %s", error.message().c_str()); + application->logf(Core::LoggerBase::LOG_VERBOSE, "Shutdown error: %s", error.message().c_str()); } _setState(DISCONNECTED); @@ -77,14 +77,14 @@ void Connection::enterReceiveLoop() { return; } - rawReceive(sizeof(Packet::Data), boost::bind(&Connection::handleHeaderReceive, this, _1)); + rawReceive(sizeof(Packet::Header), boost::bind(&Connection::handleHeaderReceive, this, _1)); } -void Connection::handleHeaderReceive(const std::vector<boost::uint8_t> &data) { +void Connection::handleHeaderReceive(const boost::shared_array<boost::uint8_t> &data) { { boost::lock_guard<boost::shared_mutex> lock(connectionLock); - header = *reinterpret_cast<const Packet::Data*>(data.data()); + header = *reinterpret_cast<const Packet::Header*>(data.get()); } if(header.length == 0) { @@ -97,35 +97,34 @@ void Connection::handleHeaderReceive(const std::vector<boost::uint8_t> &data) { } } -void Connection::handleDataReceive(const std::vector<boost::uint8_t> &data) { +void Connection::handleDataReceive(const boost::shared_array<boost::uint8_t> &data) { { boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); - Packet packet(); - receiveSignal.emit(boost::shared_ptr<Packet>(new Packet(ntohs(header.requestId), data.data(), ntohs(header.length)))); + receiveSignal.emit(boost::shared_ptr<Packet>(new Packet(ntohs(header.requestId), data.get(), ntohs(header.length)))); } enterReceiveLoop(); } -void Connection::handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify) { +void Connection::handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const boost::shared_array<boost::uint8_t>& > ¬ify) { if(error || (bytes_transferred+received) < length) { - application->logf(Core::LoggerBase::VERBOSE, "Read error: %s", error.message().c_str()); + application->logf(Core::LoggerBase::LOG_VERBOSE, "Read error: %s", error.message().c_str()); // TODO Error doDisconnect(); return; } - std::vector<boost::uint8_t> buffer; + boost::shared_array<boost::uint8_t> buffer(new boost::uint8_t[length]); { boost::shared_lock<boost::shared_mutex> lock(connectionLock); if(state != CONNECTED || !receiving) return; - - buffer.insert(buffer.end(), receiveBuffer.data(), receiveBuffer.data()+length); + + std::memcpy(buffer.get(), receiveBuffer->data(), length); } { @@ -135,13 +134,13 @@ void Connection::handleRead(const boost::system::error_code& error, std::size_t received = received + bytes_transferred - length; if(received) - std::memmove(receiveBuffer.data(), receiveBuffer.data()+length, received); + std::memmove(receiveBuffer->data(), receiveBuffer->data()+length, received); } notify(buffer); } -void Connection::rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify) { +void Connection::rawReceive(std::size_t length, const boost::function1<void, const boost::shared_array<boost::uint8_t>& > ¬ify) { boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); if(!_isConnected()) @@ -156,7 +155,7 @@ void Connection::rawReceive(std::size_t length, const boost::function1<void, con receiving = true; if(length > received) { - boost::asio::async_read(socket, boost::asio::buffer(receiveBuffer.data()+received, receiveBuffer.size()-received), boost::asio::transfer_at_least(length), + boost::asio::async_read(*socket, boost::asio::buffer(receiveBuffer->data()+received, receiveBuffer->size()-received), boost::asio::transfer_at_least(length), boost::bind(&Connection::handleRead, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred, length, notify)); @@ -183,14 +182,14 @@ void Connection::handleWrite(const boost::system::error_code& error, std::size_t } if(error) { - application->logf(Core::LoggerBase::VERBOSE, "Write error: %s", error.message().c_str()); + application->logf(Core::LoggerBase::LOG_VERBOSE, "Write error: %s", error.message().c_str()); // TODO Error doDisconnect(); } } -void Connection::rawSend(const uint8_t *data, std::size_t length) { +void Connection::rawSend(const boost::uint8_t *data, std::size_t length) { boost::upgrade_lock<boost::shared_mutex> lock(connectionLock); if(!_isConnected()) @@ -200,7 +199,7 @@ void Connection::rawSend(const uint8_t *data, std::size_t length) { boost::upgrade_to_unique_lock<boost::shared_mutex> upgradeLock(lock); sending++; - boost::asio::async_write(socket, Buffer(data, length), boost::bind(&Connection::handleWrite, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + boost::asio::async_write(*socket, Buffer(data, length), boost::bind(&Connection::handleWrite, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); } } @@ -211,7 +210,7 @@ bool Connection::send(const Packet &packet) { return false; } - rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength()); + rawSend((const boost::uint8_t*)packet.getRawData(), packet.getRawDataLength()); return true; } @@ -233,7 +232,7 @@ void Connection::disconnect() { void Connection::doDisconnect() { boost::lock_guard<boost::shared_mutex> lock(connectionLock); - socket.async_shutdown(boost::bind(&Connection::handleShutdown, this, boost::asio::placeholders::error)); + socket->async_shutdown(boost::bind(&Connection::handleShutdown, this, boost::asio::placeholders::error)); } } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index 3070282..19ee826 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -20,10 +20,15 @@ #ifndef MAD_NET_CONNECTION_H_ #define MAD_NET_CONNECTION_H_ +#include "export.h" + #include "Packet.h" #include <Core/Signals.h> #include <Core/ThreadManager.h> +#include <boost/array.hpp> +#include <boost/shared_array.hpp> + #include <boost/asio.hpp> #include <boost/asio/ssl.hpp> @@ -35,7 +40,7 @@ namespace Net { class Listener; class ThreadManager; -class Connection : boost::noncopyable { +class MAD_NET_EXPORT Connection : boost::noncopyable { protected: friend class Listener; friend class ThreadManager; @@ -45,9 +50,9 @@ class Connection : boost::noncopyable { }; private: - class Buffer { + class MAD_NET_EXPORT Buffer { public: - Buffer(const uint8_t *data0, std::size_t length) : data(new std::vector<uint8_t>(data0, data0+length)), buffer(boost::asio::buffer(*data)) {} + Buffer(const boost::uint8_t *data0, std::size_t length) : data(new std::vector<boost::uint8_t>(data0, data0+length)), buffer(boost::asio::buffer(*data)) {} typedef boost::asio::const_buffer value_type; typedef const boost::asio::const_buffer* const_iterator; @@ -56,7 +61,7 @@ class Connection : boost::noncopyable { const boost::asio::const_buffer* end() const { return &buffer + 1; } private: - boost::shared_ptr<std::vector<uint8_t> > data; + boost::shared_ptr<std::vector<boost::uint8_t> > data; boost::asio::const_buffer buffer; }; @@ -66,10 +71,10 @@ class Connection : boost::noncopyable { State state; - std::vector<boost::uint8_t> receiveBuffer; + boost::scoped_ptr<boost::array<boost::uint8_t, 1024*1024> > receiveBuffer; std::size_t received; - Packet::Data header; + Packet::Header header; Core::Signals::Signal1<boost::shared_ptr<Packet> > receiveSignal; Core::Signals::Signal0 connectedSignal; @@ -77,24 +82,29 @@ class Connection : boost::noncopyable { bool receiving; unsigned long sending; - + + void _initSocket() { + socket.reset(new boost::asio::ssl::stream<boost::asio::ip::tcp::socket>(application->getIOService(), context)); + } + void enterReceiveLoop(); - void handleHeaderReceive(const std::vector<boost::uint8_t> &data); - void handleDataReceive(const std::vector<boost::uint8_t> &data); + void handleHeaderReceive(const boost::shared_array<boost::uint8_t> &data); + void handleDataReceive(const boost::shared_array<boost::uint8_t> &data); - void handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify); + void handleRead(const boost::system::error_code& error, std::size_t bytes_transferred, std::size_t length, const boost::function1<void, const boost::shared_array<boost::uint8_t>& > ¬ify); void handleWrite(const boost::system::error_code& error, std::size_t); void handleShutdown(const boost::system::error_code& error); - void rawReceive(std::size_t length, const boost::function1<void, const std::vector<boost::uint8_t>& > ¬ify); - void rawSend(const uint8_t *data, std::size_t length); + void rawReceive(std::size_t length, const boost::function1<void, const boost::shared_array<boost::uint8_t>& > ¬ify); + void rawSend(const boost::uint8_t *data, std::size_t length); protected: boost::shared_mutex connectionLock; - - boost::asio::ssl::stream<boost::asio::ip::tcp::socket> socket; + + boost::asio::ssl::context context; + boost::scoped_ptr<boost::asio::ssl::stream<boost::asio::ip::tcp::socket> > socket; boost::asio::ip::tcp::endpoint peer; void handleHandshake(const boost::system::error_code& error); @@ -110,14 +120,20 @@ class Connection : boost::noncopyable { void _setState(State newState) { state = newState; + + if(_isConnected() && !socket.get()) + _initSocket(); + else if(!_isConnected() && socket.get()) + socket.reset(); + stateChanged.notify_all(); } void doDisconnect(); - Connection(Core::Application *application0, boost::asio::ssl::context &sslContext) : - application(application0), state(DISCONNECTED), receiveBuffer(1024*1024), receiveSignal(application), connectedSignal(application), - disconnectedSignal(application), socket(application->getIOService(), sslContext) {} + Connection(Core::Application *application0) : + application(application0), state(DISCONNECTED), receiveBuffer(new boost::array<boost::uint8_t, 1024*1024>), receiveSignal(application), connectedSignal(application), + disconnectedSignal(application), context(application->getIOService(), boost::asio::ssl::context::sslv23) {} public: virtual ~Connection(); diff --git a/src/Net/Listener.cpp b/src/Net/Listener.cpp index 89ea399..6187a1e 100644 --- a/src/Net/Listener.cpp +++ b/src/Net/Listener.cpp @@ -26,6 +26,20 @@ namespace Mad { namespace Net { +void Listener::accept() { + boost::shared_ptr<Connection> con(new Connection(application)); + + con->context.set_options(boost::asio::ssl::context::default_workarounds + | boost::asio::ssl::context::no_sslv2 + | boost::asio::ssl::context::single_dh_use); + con->context.use_certificate_chain_file(x905CertFile); + con->context.use_private_key_file(x905KeyFile, boost::asio::ssl::context::pem); + + con->_initSocket(); + + acceptor.async_accept(con->socket->lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); +} + void Listener::handleAccept(const boost::system::error_code &error, boost::shared_ptr<Connection> con) { if(error) return; @@ -42,11 +56,10 @@ void Listener::handleAccept(const boost::system::error_code &error, boost::share connections.insert(std::make_pair(con, std::make_pair(con1, con2))); - con->socket.async_handshake(boost::asio::ssl::stream_base::server, boost::bind(&Connection::handleHandshake, con, boost::asio::placeholders::error)); + con->socket->async_handshake(boost::asio::ssl::stream_base::server, boost::bind(&Connection::handleHandshake, con, boost::asio::placeholders::error)); } - con.reset(new Connection(application, sslContext)); - acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); + accept(); } void Listener::handleConnect(boost::shared_ptr<Connection> con) { @@ -76,19 +89,8 @@ void Listener::handleDisconnect(boost::shared_ptr<Connection> con) { Listener::Listener(Core::Application *application0, const std::string &x905CertFile0, const std::string &x905KeyFile0, const boost::asio::ip::tcp::endpoint &address0) throw(Core::Exception) : application(application0), x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0), -acceptor(application->getIOService(), address), sslContext(application->getIOService(), boost::asio::ssl::context::sslv23), -signal(application) -{ - sslContext.set_options(boost::asio::ssl::context::default_workarounds - | boost::asio::ssl::context::no_sslv2 - | boost::asio::ssl::context::single_dh_use); - sslContext.use_certificate_chain_file(x905CertFile0); - sslContext.use_private_key_file(x905KeyFile0, boost::asio::ssl::context::pem); - - - - boost::shared_ptr<Connection> con(new Connection(application, sslContext)); - acceptor.async_accept(con->socket.lowest_layer(), boost::bind(&Listener::handleAccept, this, boost::asio::placeholders::error, con)); +acceptor(application->getIOService(), address), signal(application) { + accept(); } Listener::~Listener() { diff --git a/src/Net/Listener.h b/src/Net/Listener.h index 2addafd..64572c0 100644 --- a/src/Net/Listener.h +++ b/src/Net/Listener.h @@ -20,6 +20,8 @@ #ifndef MAD_NET_LISTENER_H_ #define MAD_NET_LISTENER_H_ +#include "export.h" + #include <map> #include <string> @@ -29,7 +31,7 @@ namespace Mad { namespace Net { -class Listener : private boost::noncopyable { +class MAD_NET_EXPORT Listener : private boost::noncopyable { private: Core::Application *application; @@ -38,12 +40,12 @@ class Listener : private boost::noncopyable { std::string x905CertFile, x905KeyFile; boost::asio::ip::tcp::endpoint address; boost::asio::ip::tcp::acceptor acceptor; - boost::asio::ssl::context sslContext; std::map<boost::shared_ptr<Connection>, std::pair<Core::Signals::Connection, Core::Signals::Connection> > connections; Core::Signals::Signal1<boost::shared_ptr<Connection> > signal; - + + void accept(); void handleAccept(const boost::system::error_code &error, boost::shared_ptr<Connection> con); void handleConnect(boost::shared_ptr<Connection> con); diff --git a/src/Net/Packet.cpp b/src/Net/Packet.cpp index d2c2c70..3f32357 100644 --- a/src/Net/Packet.cpp +++ b/src/Net/Packet.cpp @@ -22,14 +22,14 @@ namespace Mad { namespace Net { -Packet::Packet(uint16_t requestId, const void *data, uint16_t length) { - rawData = (Data*)std::malloc(sizeof(Data)+length); +Packet::Packet(boost::uint16_t requestId, const void *data, boost::uint16_t length) { + rawData = reinterpret_cast<Header*>(std::malloc(sizeof(Header)+length)); rawData->requestId = htons(requestId); rawData->length = htons(length); if(length) - std::memcpy(rawData->data, data, length); + std::memcpy(reinterpret_cast<boost::uint8_t*>(rawData)+sizeof(Header), data, length); } Packet& Packet::operator=(const Packet &p) { @@ -38,7 +38,7 @@ Packet& Packet::operator=(const Packet &p) { std::free(rawData); - rawData = (Data*)std::malloc(p.getRawDataLength()); + rawData = reinterpret_cast<Header*>(std::malloc(p.getRawDataLength())); std::memcpy(rawData, p.rawData, p.getRawDataLength()); return *this; diff --git a/src/Net/Packet.h b/src/Net/Packet.h index dc62cb7..567c9c3 100644 --- a/src/Net/Packet.h +++ b/src/Net/Packet.h @@ -20,30 +20,36 @@ #ifndef MAD_NET_PACKET_H_ #define MAD_NET_PACKET_H_ +#include "export.h" + #include <cstdlib> #include <cstring> -#include <netinet/in.h> -#include <stdint.h> +#include <boost/cstdint.hpp> + +#ifdef _WIN32 +# include <winsock2.h> +#else +# include <netinet/in.h> +#endif namespace Mad { namespace Net { -class Packet { +class MAD_NET_EXPORT Packet { public: - struct Data { - uint16_t requestId; - uint16_t length; - uint8_t data[0]; + struct Header { + boost::uint16_t requestId; + boost::uint16_t length; }; protected: - Data *rawData; + Header *rawData; public: - Packet(uint16_t requestId, const void *data = 0, uint16_t length = 0); + Packet(boost::uint16_t requestId, const void *data = 0, boost::uint16_t length = 0); Packet(const Packet &p) { - rawData = (Data*)std::malloc(p.getRawDataLength()); + rawData = reinterpret_cast<Header*>(std::malloc(p.getRawDataLength())); std::memcpy(rawData, p.rawData, p.getRawDataLength()); } @@ -53,24 +59,24 @@ class Packet { Packet& operator=(const Packet &p); - uint16_t getRequestId() const { + boost::uint16_t getRequestId() const { return ntohs(rawData->requestId); } - uint16_t getLength() const { + boost::uint16_t getLength() const { return ntohs(rawData->length); } - const uint8_t* getData() const { - return rawData->data; + const boost::uint8_t* getData() const { + return reinterpret_cast<boost::uint8_t*>(rawData)+sizeof(Header); } - const Data* getRawData() const { - return rawData; + const boost::uint8_t* getRawData() const { + return reinterpret_cast<boost::uint8_t*>(rawData); } unsigned long getRawDataLength() const { - return sizeof(Data) + ntohs(rawData->length); + return sizeof(Header) + ntohs(rawData->length); } }; diff --git a/src/Net/export.h b/src/Net/export.h new file mode 100644 index 0000000..c0a5f6a --- /dev/null +++ b/src/Net/export.h @@ -0,0 +1,22 @@ +#ifndef MAD_NET_EXPORT +# ifdef _WIN32 +# ifdef MAD_NET_EXPORTS +# define MAD_NET_EXPORT _declspec(dllexport) +# else +# define MAD_NET_EXPORT _declspec(dllimport) +# endif +# else +# define MAD_NET_EXPORT +# endif + +# ifdef MAD_NET_EXPORTS +# ifndef MAD_CORE_EXPORTS +# define MAD_CORE_EXPORTS +# endif +# else +# undef MAD_CORE_EXPORTS +# endif + +#include <Core/export.h> + +#endif |