diff options
-rw-r--r-- | src/Common/Util.h | 6 | ||||
-rw-r--r-- | src/Core/ConnectionManager.cpp | 44 | ||||
-rw-r--r-- | src/Core/ConnectionManager.h | 2 | ||||
-rw-r--r-- | src/Net/Connection.cpp | 8 | ||||
-rw-r--r-- | src/Net/Packet.h | 15 |
5 files changed, 30 insertions, 45 deletions
diff --git a/src/Common/Util.h b/src/Common/Util.h index 19e1eeb..04b666c 100644 --- a/src/Common/Util.h +++ b/src/Common/Util.h @@ -17,8 +17,8 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef UTIL_H_ -#define UTIL_H_ +#ifndef MAD_COMMON_UTIL_H_ +#define MAD_COMMON_UTIL_H_ #include <string> #include <locale> @@ -59,4 +59,4 @@ class Util { } } -#endif /* UTIL_H_ */ +#endif /* MAD_COMMON_UTIL_H_ */ diff --git a/src/Core/ConnectionManager.cpp b/src/Core/ConnectionManager.cpp index a537539..62eed53 100644 --- a/src/Core/ConnectionManager.cpp +++ b/src/Core/ConnectionManager.cpp @@ -62,7 +62,7 @@ ConnectionManager::ConnectionManager(const ConfigManager& configManager) : reque listeners.push_back(new Net::Listener(configManager.getX509CertFile(), configManager.getX509KeyFile())); } catch(Net::Exception &e) { - // TODO: Log error + // TODO Log error } } else { @@ -71,10 +71,11 @@ ConnectionManager::ConnectionManager(const ConfigManager& configManager) : reque listeners.push_back(new Net::Listener(configManager.getX509CertFile(), configManager.getX509KeyFile(), *address)); } catch(Net::Exception &e) { - // TODO: Log error + // TODO Log error } } } + refreshPollfds(); } @@ -86,10 +87,8 @@ ConnectionManager::~ConnectionManager() { delete *con; } -void ConnectionManager::run() { - // TODO: Logging - - for(std::list<Net::ServerConnection*>::iterator con = daemonConnections.begin(); con != daemonConnections.end();) { +void ConnectionManager::handleConnections(std::list<Net::ServerConnection*>& connections) { + for(std::list<Net::ServerConnection*>::iterator con = connections.begin(); con != connections.end();) { if((*con)->isConnected()) { std::map<int,const short*>::iterator events = pollfdMap.find((*con)->getSocket()); @@ -103,40 +102,23 @@ void ConnectionManager::run() { else { requestManager.unregisterConnection(*con); delete *con; - daemonConnections.erase(con++); + connections.erase(con++); } } +} - for(std::list<Net::ServerConnection*>::iterator con = clientConnections.begin(); con != clientConnections.end();) { - if((*con)->isConnected()) { - std::map<int,const short*>::iterator events = pollfdMap.find((*con)->getSocket()); - - if(events != pollfdMap.end()) - (*con)->sendReceive(*events->second); - else - (*con)->sendReceive(); +void ConnectionManager::run() { + // TODO Logging - ++con; - } - else { - requestManager.unregisterConnection(*con); - delete *con; - clientConnections.erase(con++); - } - } + handleConnections(daemonConnections); + handleConnections(clientConnections); for(std::list<Net::Listener*>::iterator listener = listeners.begin(); listener != listeners.end(); ++listener) { Net::ServerConnection *con; while((con = (*listener)->getConnection(pollfdMap)) != 0) { - if(con->isDaemonConnection()) { - daemonConnections.push_back(con); - requestManager.registerConnection(con); - } - else { - clientConnections.push_back(con); - requestManager.registerConnection(con); - } + (con->isDaemonConnection() ? daemonConnections : clientConnections).push_back(con); + requestManager.registerConnection(con); } } diff --git a/src/Core/ConnectionManager.h b/src/Core/ConnectionManager.h index 7429a44..6161c3f 100644 --- a/src/Core/ConnectionManager.h +++ b/src/Core/ConnectionManager.h @@ -57,6 +57,8 @@ class ConnectionManager { void refreshPollfds(); + void handleConnections(std::list<Net::ServerConnection*>& connections); + public: ConnectionManager(const ConfigManager& configManager); virtual ~ConnectionManager(); diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index ac3121d..6a30d11 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -77,13 +77,13 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng header = *reinterpret_cast<const Packet::Data*>(data); if(header.length == 0) { - signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId)); + signal(this, Packet((Packet::Type)ntohs(header.type), ntohs(header.requestId))); enterReceiveLoop(); } else { state = PACKET_DATA; - rawReceive(header.length, sigc::mem_fun(this, &Connection::packetDataReceiveHandler)); + rawReceive(ntohs(header.length), sigc::mem_fun(this, &Connection::packetDataReceiveHandler)); } } @@ -91,13 +91,13 @@ void Connection::packetDataReceiveHandler(const void *data, unsigned long length if(state != PACKET_DATA) return; - if(length != header.length) { + if(length != ntohs(header.length)) { // TODO: Error doDisconnect(); return; } - signal(this, Packet(static_cast<Packet::Type>(header.type), header.requestId, data, length)); + signal(this, Packet((Packet::Type)ntohs(header.type), ntohs(header.requestId), data, length)); enterReceiveLoop(); } diff --git a/src/Net/Packet.h b/src/Net/Packet.h index cd734a4..08bc9be 100644 --- a/src/Net/Packet.h +++ b/src/Net/Packet.h @@ -22,6 +22,7 @@ #include <cstdlib> #include <cstring> +#include <netinet/in.h> namespace Mad { namespace Net { @@ -48,10 +49,10 @@ class Packet { Packet(Type type, unsigned short requestId, const void *data = NULL, unsigned short length = 0) { rawData = (Data*)std::malloc(sizeof(Data)+length); - rawData->type = type; - rawData->requestId = requestId; + rawData->type = htons(type); + rawData->requestId = htons(requestId); rawData->reserved = 0; - rawData->length = length; + rawData->length = htons(length); if(length) std::memcpy(rawData->data, data, length); @@ -79,15 +80,15 @@ class Packet { } Type getType() const { - return (Type)rawData->type; + return (Type)ntohs(rawData->type); } unsigned short getRequestId() const { - return rawData->requestId; + return ntohs(rawData->requestId); } unsigned short getLength() const { - return rawData->length; + return ntohs(rawData->length); } const unsigned char* getData() const { @@ -99,7 +100,7 @@ class Packet { } unsigned long getRawDataLength() const { - return sizeof(Data) + rawData->length; + return sizeof(Data) + ntohs(rawData->length); } }; |