From c8d469cc3de8ef2fb95f7b47355ebf5318a4c22f Mon Sep 17 00:00:00 2001 From: Matthias Schiffer Date: Fri, 15 May 2009 17:30:40 +0200 Subject: Einfache (ziemlich kaputte) Multithreaded IO --- src/Common/ClientConnection.cpp | 6 +- src/Common/ClientConnection.h | 2 +- src/Common/Connection.cpp | 6 +- src/Common/Connection.h | 4 +- src/Daemon/Requests/IdentifyRequest.cpp | 2 + src/Net/ClientConnection.cpp | 22 ++++- src/Net/Connection.cpp | 170 ++++++++++++++++++++++++-------- src/Net/Connection.h | 104 ++++++++++++++----- src/Net/FdManager.cpp | 105 ++++++++++++++++---- src/Net/FdManager.h | 25 ++++- src/Net/Listener.cpp | 66 +++++++------ src/Net/Listener.h | 14 ++- src/Net/ServerConnection.cpp | 14 ++- src/Net/ThreadManager.cpp | 58 +---------- src/Net/ThreadManager.h | 23 +++-- src/Server/ConnectionManager.cpp | 76 +++++++------- src/Server/ConnectionManager.h | 12 ++- src/mad-server.cpp | 4 +- src/mad.cpp | 9 +- src/madc.cpp | 16 +-- 20 files changed, 463 insertions(+), 275 deletions(-) (limited to 'src') diff --git a/src/Common/ClientConnection.cpp b/src/Common/ClientConnection.cpp index 0e080fd..7b33fb2 100644 --- a/src/Common/ClientConnection.cpp +++ b/src/Common/ClientConnection.cpp @@ -20,6 +20,8 @@ #include "ClientConnection.h" #include +#include "Logger.h" + namespace Mad { namespace Common { @@ -27,8 +29,8 @@ ClientConnection::ClientConnection() : connection(new Net::ClientConnection) { connection->signalReceive().connect(sigc::mem_fun(this, &ClientConnection::receive)); } -void ClientConnection::send(const Net::Packet &packet) { - connection->send(packet); +bool ClientConnection::send(const Net::Packet &packet) { + return connection->send(packet); } void ClientConnection::connect(const Net::IPAddress &address, bool daemon) throw(Net::Exception) { diff --git a/src/Common/ClientConnection.h b/src/Common/ClientConnection.h index 28a7016..09ca4db 100644 --- a/src/Common/ClientConnection.h +++ b/src/Common/ClientConnection.h @@ -37,7 +37,7 @@ class ClientConnection : public Connection { Net::ClientConnection *connection; protected: - virtual void send(const Net::Packet &packet); + virtual bool send(const Net::Packet &packet); public: ClientConnection(); diff --git a/src/Common/Connection.cpp b/src/Common/Connection.cpp index cde3fc2..510731f 100644 --- a/src/Common/Connection.cpp +++ b/src/Common/Connection.cpp @@ -26,13 +26,11 @@ namespace Mad { namespace Common { void Connection::receive(const Net::Packet &packet) { - // receive() will be called by FdManager (main thread) - // -> let the ThreadManager call the handler in the worker thread signal(XmlPacket(packet), packet.getRequestId()); } -void Connection::sendPacket(const XmlPacket &packet, uint16_t requestId) { - send(packet.encode(requestId)); +bool Connection::sendPacket(const XmlPacket &packet, uint16_t requestId) { + return send(packet.encode(requestId)); } } diff --git a/src/Common/Connection.h b/src/Common/Connection.h index 7bcb92b..860c044 100644 --- a/src/Common/Connection.h +++ b/src/Common/Connection.h @@ -50,12 +50,12 @@ class Connection { void receive(const Net::Packet &packet); - virtual void send(const Net::Packet &packet) = 0; + virtual bool send(const Net::Packet &packet) = 0; public: virtual ~Connection() {} - void sendPacket(const XmlPacket &packet, uint16_t requestId); + bool sendPacket(const XmlPacket &packet, uint16_t requestId); sigc::signal signalReceive() const { return signal; diff --git a/src/Daemon/Requests/IdentifyRequest.cpp b/src/Daemon/Requests/IdentifyRequest.cpp index ba0adef..dd035bf 100644 --- a/src/Daemon/Requests/IdentifyRequest.cpp +++ b/src/Daemon/Requests/IdentifyRequest.cpp @@ -19,6 +19,8 @@ #include "IdentifyRequest.h" +#include + namespace Mad { namespace Daemon { namespace Requests { diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index 0162a86..e4de735 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -29,6 +29,7 @@ namespace Mad { namespace Net { +// TODO Error handling void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) { if(length != sizeof(ConnectionHeader)) // Error... disconnect @@ -55,15 +56,21 @@ void ClientConnection::connectionHeader() { } void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(Exception) { + gl_rwlock_wrlock(stateLock); + daemon = daemon0; - if(isConnected()) + if(_isConnected()) { + gl_rwlock_unlock(stateLock); return; // TODO Error + } sock = socket(PF_INET, SOCK_STREAM, 0); - if(sock < 0) + if(sock < 0) { + gl_rwlock_unlock(stateLock); throw Exception("socket()", Exception::INTERNAL_ERRNO, errno); + } if(peer) delete peer; @@ -73,6 +80,8 @@ void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(Exc close(sock); delete peer; peer = 0; + + gl_rwlock_unlock(stateLock); throw Exception("connect()", Exception::INTERNAL_ERRNO, errno); } @@ -82,6 +91,7 @@ void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(Exc if(flags < 0) { close(sock); + gl_rwlock_unlock(stateLock); throw Exception("fcntl()", Exception::INTERNAL_ERRNO, errno); } @@ -96,9 +106,13 @@ void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(Exc gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)sock); - FdManager::get()->registerFd(sock, sigc::mem_fun(this, &Connection::sendReceive)); + FdManager::get()->registerFd(sock, sigc::mem_fun(this, &ClientConnection::sendReceive)); + + state = CONNECT; + + gl_rwlock_unlock(stateLock); - handshake(); + updateEvents(); } } diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index 0984f0a..2ccfddb 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -20,66 +20,92 @@ #include "Connection.h" #include "FdManager.h" #include "IPAddress.h" +#include "ThreadManager.h" + #include #include +#include + namespace Mad { namespace Net { + +Connection::StaticInit Connection::staticInit; + + Connection::~Connection() { - if(isConnected()) + if(_isConnected()) doDisconnect(); if(transR.data) delete [] transR.data; - while(!sendQueueEmpty()) { + while(!_sendQueueEmpty()) { delete [] transS.front().data; transS.pop(); } gnutls_certificate_free_credentials(x509_cred); + gl_rwlock_destroy(stateLock); + gl_lock_destroy(sendLock); + gl_lock_destroy(receiveLock); + if(peer) delete peer; } void Connection::handshake() { - if(isConnected()) + gl_rwlock_wrlock(stateLock); + if(state != CONNECT) { + gl_rwlock_unlock(stateLock); return; + } state = HANDSHAKE; + gl_rwlock_unlock(stateLock); doHandshake(); } void Connection::bye() { - if(state != DISCONNECT) + gl_rwlock_wrlock(stateLock); + if(state != DISCONNECT) { + gl_rwlock_unlock(stateLock); return; + } state = BYE; + gl_rwlock_unlock(stateLock); doBye(); } void Connection::doHandshake() { - if(state != HANDSHAKE) + gl_rwlock_rdlock(stateLock); + if(state != HANDSHAKE) { + gl_rwlock_unlock(stateLock); return; + } int ret = gnutls_handshake(session); if(ret < 0) { + gl_rwlock_unlock(stateLock); + if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { updateEvents(); return; } - // TODO: Error doDisconnect(); return; } state = CONNECTION_HEADER; + gl_rwlock_unlock(stateLock); + connectionHeader(); } @@ -102,13 +128,21 @@ void Connection::doBye() { doDisconnect(); } -bool Connection::enterReceiveLoop() { - if(!isConnected() || isDisconnecting()) - return false; +void Connection::enterReceiveLoop() { + gl_rwlock_wrlock(stateLock); + + if(!_isConnected() || _isDisconnecting()) { + gl_rwlock_unlock(stateLock); + return; + } + + if(_isConnecting()) + ThreadManager::get()->pushWork(sigc::mem_fun(connectedSignal, &sigc::signal::emit)); state = PACKET_HEADER; + gl_rwlock_unlock(stateLock); - return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler)); + rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler)); } void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) { @@ -124,7 +158,7 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng header = *(const Packet::Data*)data; if(header.length == 0) { - signal(Packet(ntohs(header.requestId))); + ThreadManager::get()->pushWork(sigc::bind(sigc::mem_fun(receiveSignal, &sigc::signal::emit), Packet(ntohs(header.requestId)))); enterReceiveLoop(); } @@ -144,7 +178,7 @@ void Connection::packetDataReceiveHandler(const void *data, unsigned long length return; } - signal(Packet(ntohs(header.requestId), data, length)); + ThreadManager::get()->pushWork(sigc::bind(sigc::mem_fun(receiveSignal, &sigc::signal::emit), Packet(ntohs(header.requestId), data, length))); enterReceiveLoop(); } @@ -153,12 +187,17 @@ void Connection::doReceive() { if(!isConnected()) return; - if(receiveComplete()) + gl_lock_lock(receiveLock); + + if(_receiveComplete()) { + gl_lock_unlock(receiveLock); return; + } ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted); if(ret < 0) { + gl_lock_unlock(receiveLock); if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; @@ -169,15 +208,20 @@ void Connection::doReceive() { transR.transmitted += ret; - if(receiveComplete()) { + if(_receiveComplete()) { // Save data pointer, as transR.notify might start a new reception uint8_t *data = transR.data; transR.data = 0; + gl_lock_unlock(receiveLock); + transR.notify(data, transR.length); delete [] data; } + else { + gl_lock_unlock(receiveLock); + } updateEvents(); } @@ -188,14 +232,19 @@ bool Connection::rawReceive(unsigned long length, if(!isConnected()) return false; - if(!receiveComplete()) + gl_lock_lock(receiveLock); + if(!_receiveComplete()) { + gl_lock_unlock(receiveLock); return false; + } transR.data = new uint8_t[length]; transR.length = length; transR.transmitted = 0; transR.notify = notify; + gl_lock_unlock(receiveLock); + updateEvents(); return true; @@ -205,11 +254,14 @@ void Connection::doSend() { if(!isConnected()) return; - while(!sendQueueEmpty()) { + gl_lock_lock(sendLock); + while(!_sendQueueEmpty()) { ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted, transS.front().length-transS.front().transmitted); if(ret < 0) { + gl_lock_unlock(sendLock); + if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) return; @@ -226,6 +278,8 @@ void Connection::doSend() { } } + gl_lock_unlock(sendLock); + updateEvents(); } @@ -235,7 +289,10 @@ bool Connection::rawSend(const uint8_t *data, unsigned long length) { Transmission trans = {length, 0, new uint8_t[length], sigc::slot()}; std::memcpy(trans.data, data, length); + + gl_lock_lock(sendLock); transS.push(trans); + gl_lock_unlock(sendLock); updateEvents(); @@ -248,14 +305,24 @@ void Connection::sendReceive(short events) { return; } - if(state == HANDSHAKE) { - doHandshake(); - return; - } + switch(state) { + case CONNECT: + handshake(); + return; + case HANDSHAKE: + doHandshake(); + return; + case DISCONNECT: + if(!_sendQueueEmpty()) + break; - if(state == BYE) { - doBye(); - return; + bye(); + return; + case BYE: + doBye(); + return; + default: + break; } if(events & POLLIN) @@ -263,48 +330,69 @@ void Connection::sendReceive(short events) { if(events & POLLOUT) doSend(); - - if(state == DISCONNECT && sendQueueEmpty()) - bye(); } bool Connection::send(const Packet &packet) { - if(!isConnected() || isConnecting() || isDisconnecting()) + gl_rwlock_rdlock(stateLock); + bool err = (!_isConnected() || _isConnecting() || _isDisconnecting()); + gl_rwlock_unlock(stateLock); + + if(err) return false; return rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength()); } void Connection::disconnect() { - if(isConnected() && !isDisconnecting()) { - state = DISCONNECT; - - if(sendQueueEmpty()) - bye(); + gl_rwlock_wrlock(stateLock); + if(!_isConnected() || _isDisconnecting()) { + gl_rwlock_unlock(stateLock); + return; } + + state = DISCONNECT; + + gl_rwlock_unlock(stateLock); + + updateEvents(); } void Connection::doDisconnect() { - if(!isConnected()) - return; + gl_rwlock_wrlock(stateLock); + + if(_isConnected()) { + FdManager::get()->unregisterFd(sock); - FdManager::get()->unregisterFd(sock); + shutdown(sock, SHUT_RDWR); + close(sock); - shutdown(sock, SHUT_RDWR); - close(sock); + gnutls_deinit(session); - gnutls_deinit(session); + ThreadManager::get()->pushWork(sigc::mem_fun(disconnectedSignal, &sigc::signal::emit)); - state = DISCONNECTED; + state = DISCONNECTED; + } + + gl_rwlock_unlock(stateLock); } -void Connection::updateEvents() const { - short events = (receiveComplete() ? 0 : POLLIN) | (sendQueueEmpty() ? 0 : POLLOUT); +void Connection::updateEvents() { + gl_lock_lock(receiveLock); + short events = (_receiveComplete() ? 0 : POLLIN); + gl_lock_unlock(receiveLock); + + gl_lock_lock(sendLock); + events |= (_sendQueueEmpty() ? 0 : POLLOUT); + gl_lock_unlock(sendLock); + gl_rwlock_rdlock(stateLock); if(state == HANDSHAKE || state == BYE) events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT); + else if(state == CONNECT || state == DISCONNECT) + events |= POLLOUT; FdManager::get()->setFdEvents(sock, events); + gl_rwlock_unlock(stateLock); } } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index 1c15a95..1176f92 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -20,6 +20,8 @@ #ifndef MAD_NET_CONNECTION_H_ #define MAD_NET_CONNECTION_H_ +#include + #include "Packet.h" #include @@ -28,6 +30,10 @@ #include #include +#include "glthread/lock.h" + +#include + namespace Mad { namespace Net { @@ -36,6 +42,18 @@ class Packet; class Connection { private: + class StaticInit { + public: + StaticInit() { + gnutls_global_init(); + } + + ~StaticInit() { + gnutls_global_deinit(); + } + }; + static StaticInit staticInit; + struct Transmission { unsigned long length; unsigned long transmitted; @@ -45,16 +63,17 @@ class Connection { sigc::slot notify; }; - enum State { - DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE - } state; - + gl_lock_t receiveLock; Transmission transR; + + gl_lock_t sendLock; std::queue transS; Packet::Data header; - sigc::signal signal; + sigc::signal receiveSignal; + sigc::signal connectedSignal; + sigc::signal disconnectedSignal; void doHandshake(); @@ -68,13 +87,13 @@ class Connection { void doDisconnect(); - bool receiveComplete() const { + bool _receiveComplete() const { return (transR.length == transR.transmitted); } - void bye(); + bool _sendQueueEmpty() const {return transS.empty();} - void updateEvents() const; + void bye(); // Prevent shallow copy Connection(const Connection &o); @@ -93,6 +112,12 @@ class Connection { uint8_t protVerMax; }; + gl_rwlock_t stateLock; + + enum State { + DISCONNECTED, CONNECT, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA, DISCONNECT, BYE + } state; + int sock; gnutls_session_t session; gnutls_certificate_credentials_t x509_cred; @@ -106,55 +131,80 @@ class Connection { bool rawReceive(unsigned long length, const sigc::slot ¬ify); bool rawSend(const uint8_t *data, unsigned long length); - bool enterReceiveLoop(); + void enterReceiveLoop(); + + void sendReceive(short events); + + bool _isConnected() const {return (state != DISCONNECTED);} + bool _isConnecting() const { + return (state == CONNECT || state == HANDSHAKE || state == CONNECTION_HEADER); + } + + bool _isDisconnecting() const { + return (state == DISCONNECT || state == BYE); + } + + void updateEvents(); public: Connection() : state(DISCONNECTED), peer(0) { transR.length = transR.transmitted = 0; transR.data = 0; + gl_rwlock_init(stateLock); + gl_lock_init(sendLock); + gl_lock_init(receiveLock); + gnutls_certificate_allocate_credentials(&x509_cred); } virtual ~Connection(); - bool isConnected() const {return (state != DISCONNECTED);} - bool isConnecting() const { - return (state == HANDSHAKE || state == CONNECTION_HEADER); + bool isConnected() { + gl_rwlock_rdlock(stateLock); + bool ret = _isConnected(); + gl_rwlock_unlock(stateLock); + + return ret; } - bool isDisconnecting() const { - return (state == DISCONNECT || state == BYE); + bool isConnecting() { + gl_rwlock_rdlock(stateLock); + bool ret = _isConnecting(); + gl_rwlock_unlock(stateLock); + + return ret; } + /*bool isDisconnecting() { + gl_rwlock_rdlock(stateLock); + bool ret = (state == DISCONNECT || state == BYE); + gl_rwlock_unlock(stateLock); + + return ret; + }*/ + const gnutls_datum_t* getCertificate() const { + // TODO Thread-safeness return gnutls_certificate_get_ours(session); } const gnutls_datum_t* getPeerCertificate() const { + // TODO Thread-safeness unsigned int n; return gnutls_certificate_get_peers(session, &n); } + // TODO Thread-safeness const IPAddress* getPeer() const {return peer;} void disconnect(); bool send(const Packet &packet); - void sendReceive(short events); - - bool sendQueueEmpty() const {return transS.empty();} - - sigc::signal signalReceive() const {return signal;} - - static void init() { - gnutls_global_init(); - } - - static void deinit() { - gnutls_global_deinit(); - } + sigc::signal signalReceive() const {return receiveSignal;} + sigc::signal signalConnected() const {return connectedSignal;} + sigc::signal signalDisconnected() const {return disconnectedSignal;} }; } diff --git a/src/Net/FdManager.cpp b/src/Net/FdManager.cpp index ffa4d8b..499ad62 100644 --- a/src/Net/FdManager.cpp +++ b/src/Net/FdManager.cpp @@ -18,11 +18,14 @@ */ #include "FdManager.h" +#include "ThreadManager.h" + #include #include #include #include +#include namespace Mad { @@ -31,7 +34,11 @@ namespace Net { FdManager FdManager::fdManager; -FdManager::FdManager() { +FdManager::FdManager() : running(false) { + gl_rwlock_init(handlerLock); + gl_rwlock_init(eventLock); + gl_rwlock_init(runLock); + pipe(interruptPipe); int flags = fcntl(interruptPipe[0], F_GETFL, 0); @@ -48,43 +55,73 @@ FdManager::~FdManager() { close(interruptPipe[0]); close(interruptPipe[1]); + + gl_rwlock_destroy(runLock); + gl_rwlock_destroy(eventLock); + gl_rwlock_destroy(handlerLock); } bool FdManager::registerFd(int fd, const sigc::slot &handler, short events) { struct pollfd pollfd = {fd, events, 0}; + gl_rwlock_wrlock(handlerLock); + gl_rwlock_wrlock(eventLock); pollfds.insert(std::make_pair(fd, pollfd)); - return handlers.insert(std::make_pair(fd, handler)).second; + bool ret = handlers.insert(std::make_pair(fd, handler)).second; + gl_rwlock_unlock(eventLock); + gl_rwlock_unlock(handlerLock); + + interrupt(); + + return ret; } bool FdManager::unregisterFd(int fd) { + gl_rwlock_wrlock(handlerLock); + gl_rwlock_wrlock(eventLock); pollfds.erase(fd); - return handlers.erase(fd); + bool ret = handlers.erase(fd); + gl_rwlock_unlock(eventLock); + gl_rwlock_unlock(handlerLock); + + interrupt(); + + return ret; } bool FdManager::setFdEvents(int fd, short events) { + gl_rwlock_wrlock(eventLock); std::map::iterator pollfd = pollfds.find(fd); - if(pollfd == pollfds.end()) + if(pollfd == pollfds.end()) { + gl_rwlock_unlock(eventLock); return false; + } if(pollfd->second.events != events) { pollfd->second.events = events; interrupt(); } + gl_rwlock_unlock(eventLock); + return true; } -short FdManager::getFdEvents(int fd) const { +short FdManager::getFdEvents(int fd) { + gl_rwlock_rdlock(eventLock); + std::map::const_iterator pollfd = pollfds.find(fd); if(pollfd == pollfds.end()) return -1; - return pollfd->second.events; + short ret = pollfd->second.events; + gl_rwlock_unlock(eventLock); + + return ret; } void FdManager::readInterrupt() { @@ -99,27 +136,57 @@ void FdManager::interrupt() { write(interruptPipe[1], &buf, sizeof(buf)); } -void FdManager::run() { - readInterrupt(); +void FdManager::ioThread() { + gl_rwlock_wrlock(runLock); + running = true; + gl_rwlock_unlock(runLock); - size_t count = pollfds.size(); - struct pollfd *fdarray = new struct pollfd[count]; + gl_rwlock_rdlock(runLock); + while(running) { + gl_rwlock_unlock(runLock); - std::map::iterator pollfd = pollfds.begin(); + gl_rwlock_rdlock(handlerLock); + gl_rwlock_rdlock(eventLock); + readInterrupt(); - for(size_t n = 0; n < count; ++n) { - fdarray[n] = pollfd->second; - ++pollfd; - } + size_t count = pollfds.size(); + struct pollfd *fdarray = new struct pollfd[count]; + + std::map::iterator pollfd = pollfds.begin(); - if(poll(fdarray, count, -1) > 0) { for(size_t n = 0; n < count; ++n) { - if(fdarray[n].revents) - handlers[fdarray[n].fd](fdarray[n].revents); + fdarray[n] = pollfd->second; + ++pollfd; + } + + gl_rwlock_unlock(eventLock); + gl_rwlock_unlock(handlerLock); + + if(poll(fdarray, count, -1) > 0) { + gl_rwlock_rdlock(handlerLock); + + std::queue > calls; + + for(size_t n = 0; n < count; ++n) { + if(fdarray[n].revents) + calls.push(sigc::bind(handlers[fdarray[n].fd], fdarray[n].revents)); + } + + gl_rwlock_unlock(handlerLock); + + while(!calls.empty()) { + calls.front()(); + calls.pop(); + } + } + + delete [] fdarray; + + gl_rwlock_rdlock(runLock); } - delete [] fdarray; + gl_rwlock_unlock(runLock); } } diff --git a/src/Net/FdManager.h b/src/Net/FdManager.h index 2e6fc66..8c2ec12 100644 --- a/src/Net/FdManager.h +++ b/src/Net/FdManager.h @@ -20,17 +20,29 @@ #ifndef MAD_NET_FDMANAGER_H_ #define MAD_NET_FDMANAGER_H_ +#include + #include #include #include +#include "glthread/lock.h" +#include "glthread/cond.h" + namespace Mad { namespace Net { +class ThreadManager; + class FdManager { private: + friend class ThreadManager; + static FdManager fdManager; + gl_rwlock_t runLock, handlerLock, eventLock; + bool running; + std::map pollfds; std::map > handlers; @@ -41,6 +53,15 @@ class FdManager { FdManager(); + void ioThread(); + void stopIOThread() { + gl_rwlock_wrlock(runLock); + running = false; + gl_rwlock_unlock(runLock); + + interrupt(); + } + public: virtual ~FdManager(); @@ -50,9 +71,7 @@ class FdManager { bool unregisterFd(int fd); bool setFdEvents(int fd, short events); - short getFdEvents(int fd) const; - - void run(); + short getFdEvents(int fd); }; } diff --git a/src/Net/Listener.cpp b/src/Net/Listener.cpp index 95147fa..5bcc353 100644 --- a/src/Net/Listener.cpp +++ b/src/Net/Listener.cpp @@ -21,10 +21,14 @@ #include "FdManager.h" #include "ServerConnection.h" +#include + #include #include #include +#include + namespace Mad { namespace Net { @@ -35,12 +39,42 @@ void Listener::acceptHandler(int) { while((sd = accept(sock, (struct sockaddr*)&sa, &addrlen)) >= 0) { - connections.push_back(new ServerConnection(sd, IPAddress(sa), dh_params, x905CertFile, x905KeyFile)); + ServerConnection *con = new ServerConnection(sd, IPAddress(sa), dh_params, x905CertFile, x905KeyFile); + sigc::connection con1 = con->signalConnected().connect(sigc::bind(sigc::mem_fun(this, &Listener::connectHandler), con)); + sigc::connection con2 = con->signalDisconnected().connect(sigc::bind(sigc::mem_fun(this, &Listener::disconnectHandler), con)); + + connections.insert(std::make_pair(con, std::make_pair(con1, con2))); addrlen = sizeof(sa); } } + +void Listener::connectHandler(ServerConnection *con) { + std::map >::iterator it = connections.find(con); + + if(it == connections.end()) + return; + + // Disconnect signal handlers + it->second.first.disconnect(); + it->second.second.disconnect(); + connections.erase(it); + + signal(con); +} + +void Listener::disconnectHandler(ServerConnection *con) { + std::map >::iterator it = connections.find(con); + + if(it == connections.end()) + return; + + delete it->first; + connections.erase(it); +} + + Listener::Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, const IPAddress &address0) throw(Exception) : x905CertFile(x905CertFile0), x905KeyFile(x905KeyFile0), address(address0) { gnutls_dh_params_init(&dh_params); @@ -82,9 +116,9 @@ Listener::Listener(const std::string &x905CertFile0, const std::string &x905KeyF } Listener::~Listener() { - for(std::list::iterator con = connections.begin(); con != connections.end(); ++con) { - (*con)->disconnect(); - delete *con; + for(std::map >::iterator con = connections.begin(); con != connections.end(); ++con) { + con->first->disconnect(); + delete con->first; } shutdown(sock, SHUT_RDWR); @@ -93,29 +127,5 @@ Listener::~Listener() { gnutls_dh_params_deinit(dh_params); } -ServerConnection* Listener::getConnection() { - // TODO: Logging - - for(std::list::iterator con = connections.begin(); con != connections.end();) { - if(!(*con)->isConnected()) { - delete *con; - connections.erase(con++); // Erase unincremented iterator - - continue; - } - - if(!(*con)->isConnecting()) { - ServerConnection *connection = *con; - connections.erase(con); - - return connection; - } - - ++con; - } - - return 0; -} - } } diff --git a/src/Net/Listener.h b/src/Net/Listener.h index ca19947..3805403 100644 --- a/src/Net/Listener.h +++ b/src/Net/Listener.h @@ -23,16 +23,17 @@ #include "IPAddress.h" #include -#include -#include #include #include +#include + namespace Mad { namespace Net { class ServerConnection; +// TODO XXX Thread-safeness XXX class Listener { private: std::string x905CertFile, x905KeyFile; @@ -41,10 +42,15 @@ class Listener { gnutls_dh_params_t dh_params; - std::list connections; + std::map > connections; + + sigc::signal signal; void acceptHandler(int); + void connectHandler(ServerConnection *con); + void disconnectHandler(ServerConnection *con); + // Prevent shallow copy Listener(const Listener &o); Listener& operator=(const Listener &o); @@ -53,7 +59,7 @@ class Listener { Listener(const std::string &x905CertFile0, const std::string &x905KeyFile0, const IPAddress &address0 = IPAddress()) throw(Exception); virtual ~Listener(); - ServerConnection* getConnection(); + sigc::signal signalNewConnection() const {return signal;} }; } diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp index 97e2ae4..aa8042e 100644 --- a/src/Net/ServerConnection.cpp +++ b/src/Net/ServerConnection.cpp @@ -54,13 +54,15 @@ void ServerConnection::connectionHeaderReceiveHandler(const void *data, unsigned ConnectionHeader header2 = {'M', 'A', 'D', 0, 0, 1, 1, 0}; - rawSend((uint8_t*)&header2, sizeof(header2)); - enterReceiveLoop(); + + rawSend((uint8_t*)&header2, sizeof(header2)); } ServerConnection::ServerConnection(int sock0, const IPAddress &address, gnutls_dh_params_t dh_params, const std::string &x905CertFile, const std::string &x905KeyFile) : daemon(false) { + gl_rwlock_wrlock(stateLock); + sock = sock0; peer = new IPAddress(address); @@ -73,9 +75,13 @@ ServerConnection::ServerConnection(int sock0, const IPAddress &address, gnutls_d gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, x509_cred); gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t)sock); - FdManager::get()->registerFd(sock, sigc::mem_fun(this, &Connection::sendReceive)); + FdManager::get()->registerFd(sock, sigc::mem_fun(this, &ServerConnection::sendReceive)); + + state = CONNECT; + + gl_rwlock_unlock(stateLock); - handshake(); + updateEvents(); } } diff --git a/src/Net/ThreadManager.cpp b/src/Net/ThreadManager.cpp index 9eb965d..3495196 100644 --- a/src/Net/ThreadManager.cpp +++ b/src/Net/ThreadManager.cpp @@ -96,40 +96,6 @@ void ThreadManager::pushWork(const sigc::slot &newWork) { gl_lock_unlock(workLock); } -void ThreadManager::pushIO(const sigc::slot &newIO) { - gl_lock_lock(ioLock); - - ioQueue.push(newIO); - - if(!hasIO) { - hasIO = true; - ignore_value(write(ioNotifyPipe[1], "", 1)); - } - - gl_lock_unlock(ioLock); -} - -void ThreadManager::runIO() { - gl_lock_lock(ioLock); - - // Empty the pipe - char buf; - ignore_value(read(ioNotifyPipe[0], &buf, 1)); - hasIO = false; - - while(!ioQueue.empty()) { - sigc::slot handler = ioQueue.front(); - ioQueue.pop(); - gl_lock_unlock(ioLock); - - handler(); - - gl_lock_lock(ioLock); - } - - gl_lock_unlock(ioLock); -} - void ThreadManager::doInit() { gl_lock_init(threadLock); @@ -138,17 +104,6 @@ void ThreadManager::doInit() { gl_lock_init(workLock); gl_cond_init(workCond); - gl_lock_init(ioLock); - hasIO = false; - - // TODO Error handling - pipe(ioNotifyPipe); - - fcntl(ioNotifyPipe[0], F_SETFL, fcntl(ioNotifyPipe[0], F_GETFL) | O_NONBLOCK); - fcntl(ioNotifyPipe[1], F_SETFL, fcntl(ioNotifyPipe[1], F_GETFL) | O_NONBLOCK); - - Net::FdManager::get()->registerFd(ioNotifyPipe[0], sigc::hide(sigc::mem_fun(this, &ThreadManager::runIO)), POLLIN); - running = true; gl_lock_lock(threadLock); @@ -156,6 +111,7 @@ void ThreadManager::doInit() { mainThread = (gl_thread_t)gl_thread_self(); workerThread = gl_thread_create(&ThreadManager::workerStart, 0); loggerThread = gl_thread_create(&ThreadManager::loggerStart, 0); + ioThread = gl_thread_create(&ThreadManager::ioStart, 0); gl_lock_unlock(threadLock); } @@ -192,18 +148,14 @@ void ThreadManager::doDeinit() { } gl_lock_unlock(threadLock); + // IO thread is next + FdManager::get()->stopIOThread(); + gl_thread_join(ioThread, 0); + // Finally, the logger thread has to die Common::LogManager::get()->stopLoggerThread(); gl_thread_join(loggerThread, 0); - // And then we clean everything up - Net::FdManager::get()->unregisterFd(ioNotifyPipe[0]); - - close(ioNotifyPipe[0]); - close(ioNotifyPipe[1]); - - gl_lock_destroy(ioLock); - gl_cond_destroy(workCond); gl_lock_destroy(workLock); diff --git a/src/Net/ThreadManager.h b/src/Net/ThreadManager.h index 9e2b3d3..327ba67 100644 --- a/src/Net/ThreadManager.h +++ b/src/Net/ThreadManager.h @@ -17,11 +17,13 @@ * with this program. If not, see . */ -#ifndef MAD_COMMON_THREADMANAGER_H_ -#define MAD_COMMON_THREADMANAGER_H_ +#ifndef MAD_NET_THREADMANAGER_H_ +#define MAD_NET_THREADMANAGER_H_ #include +#include "FdManager.h" + #include #include @@ -38,7 +40,7 @@ namespace Net { class ThreadManager : public Common::Initializable { private: - gl_thread_t mainThread, workerThread, loggerThread; + gl_thread_t mainThread, workerThread, loggerThread, ioThread; std::set threads; gl_lock_t threadLock; @@ -50,11 +52,6 @@ class ThreadManager : public Common::Initializable { gl_cond_t workCond; std::queue > work; - gl_lock_t ioLock; - bool hasIO; - int ioNotifyPipe[2]; - std::queue > ioQueue; - static ThreadManager threadManager; ThreadManager() {} @@ -69,6 +66,11 @@ class ThreadManager : public Common::Initializable { return 0; } + static void* ioStart(void*) { + FdManager::get()->ioThread(); + return 0; + } + void workerFunc(); void threadFinished(gl_thread_t thread) { @@ -98,9 +100,6 @@ class ThreadManager : public Common::Initializable { void detach(); void pushWork(const sigc::slot &newWork); - void pushIO(const sigc::slot &newIO); - - void runIO(); static ThreadManager* get() { return &threadManager; @@ -110,4 +109,4 @@ class ThreadManager : public Common::Initializable { } } -#endif /* MAD_COMMON_THREADMANAGER_H_ */ +#endif /* MAD_NET_THREADMANAGER_H_ */ diff --git a/src/Server/ConnectionManager.cpp b/src/Server/ConnectionManager.cpp index 9235700..99f5680 100644 --- a/src/Server/ConnectionManager.cpp +++ b/src/Server/ConnectionManager.cpp @@ -41,17 +41,19 @@ #include #include +#include + namespace Mad { namespace Server { ConnectionManager ConnectionManager::connectionManager; -void ConnectionManager::Connection::send(const Net::Packet &packet) { - connection->send(packet); +bool ConnectionManager::Connection::send(const Net::Packet &packet) { + return connection->send(packet); } -ConnectionManager::Connection::Connection(Net::ServerConnection *connection0, ConnectionType type0) -: connection(connection0), type(type0), hostInfo(0) { +ConnectionManager::Connection::Connection(Net::ServerConnection *connection0) +: connection(connection0), type(connection0->isDaemonConnection() ? DAEMON : CLIENT), hostInfo(0) { connection->signalReceive().connect(sigc::mem_fun(this, &Connection::receive)); } @@ -86,7 +88,7 @@ void* ConnectionManager::Connection::getPeerCertificate(size_t *size) const { void ConnectionManager::updateState(Common::HostInfo *hostInfo, Common::HostInfo::State state) { hostInfo->setState(state); - for(std::list::iterator con = connections.begin(); con != connections.end(); ++con) { + for(std::set::iterator con = connections.begin(); con != connections.end(); ++con) { if((*con)->getConnectionType() == Connection::CLIENT) Common::RequestManager::get()->sendRequest(*con, Common::Request::slot_type(), hostInfo->getName(), state); } @@ -147,7 +149,9 @@ bool ConnectionManager::handleConfigEntry(const Common::ConfigEntry &entry, bool void ConnectionManager::configFinished() { if(listenerAddresses.empty()) { try { - listeners.push_back(new Net::Listener(x509CertFile, x509KeyFile)); + Net::Listener *listener = new Net::Listener(x509CertFile, x509KeyFile); + listener->signalNewConnection().connect(sigc::mem_fun(this, &ConnectionManager::newConnectionHandler)); + listeners.push_back(listener); } catch(Net::Exception &e) { // TODO Log error @@ -156,7 +160,9 @@ void ConnectionManager::configFinished() { else { for(std::vector::const_iterator address = listenerAddresses.begin(); address != listenerAddresses.end(); ++address) { try { - listeners.push_back(new Net::Listener(x509CertFile, x509KeyFile, *address)); + Net::Listener *listener = new Net::Listener(x509CertFile, x509KeyFile, *address); + listener->signalNewConnection().connect(sigc::mem_fun(this, &ConnectionManager::newConnectionHandler)); + listeners.push_back(listener); } catch(Net::Exception &e) { // TODO Log error @@ -165,11 +171,27 @@ void ConnectionManager::configFinished() { } } +void ConnectionManager::newConnectionHandler(Net::ServerConnection *con) { + Connection *connection = new Connection(con); + con->signalDisconnected().connect(sigc::bind(sigc::mem_fun(this, &ConnectionManager::disconnectHandler), connection)); + connections.insert(connection); + + Common::RequestManager::get()->registerConnection(connection); +} + +void ConnectionManager::disconnectHandler(Connection *con) { + if(con->isIdentified()) + updateState(con->getHostInfo(), Common::HostInfo::INACTIVE); + + connections.erase(con); + + Common::RequestManager::get()->unregisterConnection(con); + delete con; +} + void ConnectionManager::doInit() { Common::RequestManager::get()->setServer(true); - Net::Connection::init(); - Common::RequestManager::get()->registerPacketType("AuthGSSAPI"); Common::RequestManager::get()->registerPacketType("DaemonCommand"); Common::RequestManager::get()->registerPacketType("DaemonFSInfo"); @@ -184,7 +206,7 @@ void ConnectionManager::doInit() { } void ConnectionManager::doDeinit() { - for(std::list::iterator con = connections.begin(); con != connections.end(); ++con) + for(std::set::iterator con = connections.begin(); con != connections.end(); ++con) delete *con; @@ -199,38 +221,6 @@ void ConnectionManager::doDeinit() { Common::RequestManager::get()->unregisterPacketType("GetUserInfo"); Common::RequestManager::get()->unregisterPacketType("ListUsers"); Common::RequestManager::get()->unregisterPacketType("Log"); - - Net::Connection::deinit(); -} - -void ConnectionManager::run() { - // TODO Logging - - Net::FdManager::get()->run(); - - for(std::list::iterator con = connections.begin(); con != connections.end();) { - if(!(*con)->isConnected()) { - if((*con)->isIdentified()) - updateState((*con)->getHostInfo(), Common::HostInfo::INACTIVE); - - Common::RequestManager::get()->unregisterConnection(*con); - delete *con; - connections.erase(con++); - } - else - ++con; - } - - for(std::list::iterator listener = listeners.begin(); listener != listeners.end(); ++listener) { - Net::ServerConnection *con; - - while((con = (*listener)->getConnection()) != 0) { - Connection *connection = new Connection(con, - con->isDaemonConnection() ? Connection::DAEMON : Connection::CLIENT); - connections.push_back(connection); - Common::RequestManager::get()->registerConnection(connection); - } - } } Common::Connection* ConnectionManager::getDaemonConnection(const std::string &name) const throw (Net::Exception&) { @@ -244,7 +234,7 @@ Common::Connection* ConnectionManager::getDaemonConnection(const std::string &na } if(hostInfo->getState() != Common::HostInfo::INACTIVE) { - for(std::list::const_iterator it = connections.begin(); it != connections.end(); ++it) { + for(std::set::const_iterator it = connections.begin(); it != connections.end(); ++it) { if((*it)->getHostInfo() == hostInfo) { return *it; } diff --git a/src/Server/ConnectionManager.h b/src/Server/ConnectionManager.h index 62ecc37..691d51f 100644 --- a/src/Server/ConnectionManager.h +++ b/src/Server/ConnectionManager.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -55,10 +56,10 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa Common::HostInfo *hostInfo; protected: - virtual void send(const Net::Packet &packet); + virtual bool send(const Net::Packet &packet); public: - Connection(Net::ServerConnection *connection0, ConnectionType type0); + Connection(Net::ServerConnection *connection0); virtual ~Connection(); bool isConnected() const; @@ -91,7 +92,7 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa std::vector listenerAddresses; std::list listeners; - std::list connections; + std::set connections; std::map daemonInfo; @@ -103,6 +104,9 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa ConnectionManager() {} + void newConnectionHandler(Net::ServerConnection *con); + void disconnectHandler(Connection *con); + protected: virtual bool handleConfigEntry(const Common::ConfigEntry &entry, bool handled); virtual void configFinished(); @@ -115,8 +119,6 @@ class ConnectionManager : public Common::Configurable, public Common::Initializa return &connectionManager; } - void run(); - Common::Connection* getDaemonConnection(const std::string &name) const throw (Net::Exception&); std::string getDaemonName(const Common::Connection *con) const throw (Net::Exception&); diff --git a/src/mad-server.cpp b/src/mad-server.cpp index 59d8949..6fb0e21 100644 --- a/src/mad-server.cpp +++ b/src/mad-server.cpp @@ -18,8 +18,6 @@ */ #include "Common/ConfigManager.h" -#include "Common/LogManager.h" -#include "Common/Logger.h" #include "Common/ModuleManager.h" #include "Net/ThreadManager.h" #include "Server/ConnectionManager.h" @@ -49,7 +47,7 @@ int main() { Common::ConfigManager::get()->finish(); while(true) - Server::ConnectionManager::get()->run(); + sleep(1000); Common::Initializable::deinit(); diff --git a/src/mad.cpp b/src/mad.cpp index bc150e2..1a504f6 100644 --- a/src/mad.cpp +++ b/src/mad.cpp @@ -17,7 +17,6 @@ * with this program. If not, see . */ -#include "Net/Connection.h" #include "Net/FdManager.h" #include "Net/IPAddress.h" #include "Net/ThreadManager.h" @@ -43,8 +42,6 @@ static void requestFinished(const Common::Request&) { } int main() { - Net::Connection::init(); - Net::ThreadManager::get()->init(); Common::ModuleManager::get()->loadModule("FileLogger"); @@ -63,7 +60,7 @@ int main() { connection->connect(Net::IPAddress("127.0.0.1"), true); while(connection->isConnecting()) - Net::FdManager::get()->run(); + usleep(100000); Common::RequestManager::get()->registerConnection(connection); @@ -77,7 +74,7 @@ int main() { Common::RequestManager::get()->sendRequest(connection, sigc::ptr_fun(requestFinished), "test"); while(connection->isConnected()) - Net::FdManager::get()->run(); + usleep(100000); Common::LogManager::get()->unregisterLogger(networkLogger); @@ -95,7 +92,5 @@ int main() { Common::Initializable::deinit(); - Net::Connection::deinit(); - return 0; } diff --git a/src/madc.cpp b/src/madc.cpp index 46455b0..e6fc3f0 100644 --- a/src/madc.cpp +++ b/src/madc.cpp @@ -17,22 +17,16 @@ * with this program. If not, see . */ -#include "Net/Connection.h" #include "Net/FdManager.h" #include "Net/IPAddress.h" #include "Net/ThreadManager.h" #include "Common/ClientConnection.h" #include "Common/ConfigManager.h" -#include "Common/LogManager.h" -#include "Common/Logger.h" #include "Common/RequestManager.h" #include "Client/CommandParser.h" #include "Client/InformationManager.h" #include -#include -#include -#include #include #include @@ -79,8 +73,6 @@ int main(int argc, char *argv[]) { std::exit(1); } - Net::Connection::init(); - Net::ThreadManager::get()->init(); Client::InformationManager::get()->init(); @@ -94,7 +86,7 @@ int main(int argc, char *argv[]) { std::cerr << "Connecting to " << argv[1] << "..." << std::flush; while(connection->isConnecting()) - Net::FdManager::get()->run(); + usleep(100000); std::cerr << " connected." << std::endl; @@ -105,7 +97,7 @@ int main(int argc, char *argv[]) { Client::InformationManager::get()->updateDaemonList(connection); while(Client::InformationManager::get()->isUpdating()) - Net::FdManager::get()->run(); + usleep(100000); std::cerr << " done." << std::endl << std::endl; @@ -117,7 +109,7 @@ int main(int argc, char *argv[]) { activateReadline(); while(connection->isConnected()) - Net::FdManager::get()->run(); + usleep(100000); Net::FdManager::get()->unregisterFd(STDIN_FILENO); @@ -131,7 +123,5 @@ int main(int argc, char *argv[]) { Common::Initializable::deinit(); - Net::Connection::deinit(); - return 0; } -- cgit v1.2.3