From 0e38acdaff7ef753f1d4e140eec9dbaec6f7a047 Mon Sep 17 00:00:00 2001 From: Matthias Schiffer Date: Wed, 25 Jun 2008 23:31:48 +0200 Subject: Nicht-blockierende E/A benutzen --- src/Net/ClientConnection.cpp | 76 +++++++++++++++----------------------------- src/Net/ClientConnection.h | 29 +++-------------- src/Net/Connection.cpp | 44 ++++++++++++++++++++++--- src/Net/Connection.h | 40 +++++++++++++++++------ src/Net/ServerConnection.cpp | 58 ++++++++------------------------- src/Net/ServerConnection.h | 22 ++----------- 6 files changed, 116 insertions(+), 153 deletions(-) diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index 1523b44..ca12f9f 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -22,18 +22,11 @@ #include #include #include -#include +#include namespace Mad { namespace Net { -void ClientConnection::sendConnectionHeader(bool daemon) { - ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1}; - - rawSend(reinterpret_cast(&header), sizeof(header)); - rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ClientConnection::connectionHeaderReceiveHandler)); -} - void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned long length) { if(length != sizeof(ConnectionHeader)) // Error... disconnect @@ -49,14 +42,20 @@ void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned // Unsupported protocol... disconnect return; - connecting = false; enterReceiveLoop(); } -void ClientConnection::connect(const IPAddress &address, bool daemon) throw(ConnectionException) { - const int kx_list[] = {GNUTLS_KX_ANON_DH, 0}; +void ClientConnection::connectionHeader() { + ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1}; - if(connected) + rawSend(reinterpret_cast(&header), sizeof(header)); + rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ClientConnection::connectionHeaderReceiveHandler)); +} + +void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(ConnectionException) { + daemon = daemon0; + + if(isConnected()) disconnect(); sock = socket(PF_INET, SOCK_STREAM, 0); @@ -68,58 +67,33 @@ void ClientConnection::connect(const IPAddress &address, bool daemon) throw(Conn if(::connect(sock, peer->getSockAddr(), peer->getSockAddrLength()) < 0) { close(sock); delete peer; + peer = 0; throw ConnectionException("connect()", std::strerror(errno)); } - connected = true; - connecting = true; - - gnutls_init(&session, GNUTLS_CLIENT); + // Set non-blocking flag + int flags = fcntl(sock, F_GETFL, 0); - gnutls_set_default_priority(session); - gnutls_kx_set_priority(session, kx_list); - - gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred); - - gnutls_transport_set_ptr(session, reinterpret_cast(sock)); - - int ret = gnutls_handshake(session); - if(ret < 0) { - disconnect(); + if(flags < 0) { + close(sock); - throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret)); + throw ConnectionException("fcntl()", std::strerror(errno)); } - sendConnectionHeader(daemon); -} - -void ClientConnection::disconnect() { - if(!connected) - return; - - gnutls_bye(session, GNUTLS_SHUT_RDWR); + fcntl(sock, F_SETFL, flags | O_NONBLOCK); - shutdown(sock, SHUT_RDWR); - close(sock); - - gnutls_deinit(session); + gnutls_init(&session, GNUTLS_CLIENT); - delete peer; + gnutls_set_default_priority(session); - connected = false; -} - -bool ClientConnection::dataPending() const { - if(!connected) - return false; + const int kx_list[] = {GNUTLS_KX_ANON_DH, 0}; + gnutls_kx_set_priority(session, kx_list); - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); + gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred); - struct timeval timeout = {0, 0}; + gnutls_transport_set_ptr(session, reinterpret_cast(sock)); - return (select(sock + 1, &fds, NULL, NULL, &timeout) == 1); + handshake(); } } diff --git a/src/Net/ClientConnection.h b/src/Net/ClientConnection.h index c738b1e..18b1a02 100644 --- a/src/Net/ClientConnection.h +++ b/src/Net/ClientConnection.h @@ -30,44 +30,25 @@ class IPAddress; class ClientConnection : public Connection { private: - bool connected; - IPAddress *peer; - - bool connecting; - - int sock; - gnutls_session_t session; gnutls_anon_client_credentials_t anoncred; - void sendConnectionHeader(bool daemon); + bool daemon; + void connectionHeaderReceiveHandler(const void *data, unsigned long length); protected: - virtual gnutls_session_t& getSession() { - return session; - } + virtual void connectionHeader(); public: - ClientConnection() : connected(false), connecting(false) { + ClientConnection() : daemon(0) { gnutls_anon_allocate_client_credentials(&anoncred); } virtual ~ClientConnection() { - if(connected) - disconnect(); - gnutls_anon_free_client_credentials(anoncred); } - void connect(const IPAddress &address, bool daemon = false) throw(ConnectionException); - void disconnect(); - - virtual bool dataPending() const; - - virtual bool isConnected() const {return connected;} - virtual const IPAddress* getPeer() const {return peer;} - - virtual bool isConnecting() const {return connecting;} + void connect(const IPAddress &address, bool daemon0 = false) throw(ConnectionException); }; } diff --git a/src/Net/Connection.cpp b/src/Net/Connection.cpp index eb7a55f..e7be313 100644 --- a/src/Net/Connection.cpp +++ b/src/Net/Connection.cpp @@ -18,11 +18,29 @@ */ #include "Connection.h" +#include "IPAddress.h" #include +#include namespace Mad { namespace Net { +void Connection::handshake() { + state = HANDSHAKE; + + int ret = gnutls_handshake(session); + if(ret < 0) { + if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) + return; + + // Error... disconnect + return; + } + + state = CONNECTION_HEADER; + connectionHeader(); +} + void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) { if(length != sizeof(Packet::Data)) return; // Error... disconnect? @@ -35,6 +53,7 @@ void Connection::packetHeaderReceiveHandler(const void *data, unsigned long leng enterReceiveLoop(); } else { + state = PACKET_DATA; rawReceive(header.length, sigc::mem_fun(this, &Connection::packetDataReceiveHandler)); } } @@ -52,13 +71,10 @@ void Connection::doReceive() { if(!isConnected()) return; - if(!dataPending()) - return; - if(receiveComplete()) return; - ssize_t ret = gnutls_record_recv(getSession(), transR.data+transR.transmitted, transR.length-transR.transmitted); + ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted); if(ret < 0) { if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) @@ -103,7 +119,7 @@ void Connection::doSend() { return; while(!sendQueueEmpty()) { - ssize_t ret = gnutls_record_send(getSession(), transS.front().data+transS.front().transmitted, + ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted, transS.front().length-transS.front().transmitted); if(ret < 0) { @@ -134,5 +150,23 @@ bool Connection::rawSend(const unsigned char *data, unsigned long length) { return true; } +void Connection::disconnect() { + if(!isConnected()) + return; + + gnutls_bye(session, GNUTLS_SHUT_RDWR); + + shutdown(sock, SHUT_RDWR); + close(sock); + + gnutls_deinit(session); + + if(peer) + delete peer; + peer = 0; + + state = DISCONNECTED; +} + } } diff --git a/src/Net/Connection.h b/src/Net/Connection.h index 381ed17..a5bda5a 100644 --- a/src/Net/Connection.h +++ b/src/Net/Connection.h @@ -42,6 +42,10 @@ class Connection { sigc::slot notify; }; + enum State { + DISCONNECTED, HANDSHAKE, CONNECTION_HEADER, PACKET_HEADER, PACKET_DATA + } state; + Transmission transR; std::queue transS; @@ -72,23 +76,33 @@ class Connection { unsigned char protVerMax; }; - virtual gnutls_session_t& getSession() = 0; + int sock; + gnutls_session_t session; + + IPAddress *peer; - bool rawReceive(unsigned long length, - const sigc::slot ¬ify); + void handshake(); + virtual void connectionHeader() = 0; + + bool rawReceive(unsigned long length, const sigc::slot ¬ify); bool rawSend(const unsigned char *data, unsigned long length); bool enterReceiveLoop() { + state = PACKET_HEADER; + return rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler)); } public: - Connection() { + Connection() : state(DISCONNECTED), peer(0) { transR.length = transR.transmitted = 0; transR.data = 0; } virtual ~Connection() { + if(isConnected()) + disconnect(); + if(transR.data) delete [] transR.data; @@ -96,24 +110,30 @@ class Connection { delete [] transS.front().data; transS.pop(); } - } - virtual bool isConnected() const = 0; - virtual const IPAddress* getPeer() const = 0; + bool isConnected() const {return (state != DISCONNECTED);} + bool isConnecting() { + return (state == HANDSHAKE || state == CONNECTION_HEADER); + } - virtual bool isConnecting() const = 0; + const IPAddress* getPeer() {return peer;} - virtual bool dataPending() const = 0; + void disconnect(); bool send(const Packet &packet) { - if(isConnecting()) + if(!isConnected() || isConnecting()) return false; return rawSend(reinterpret_cast(packet.getRawData()), packet.getRawDataLength()); } void sendReceive() { + if(state == HANDSHAKE) { + handshake(); + return; + } + doReceive(); doSend(); } diff --git a/src/Net/ServerConnection.cpp b/src/Net/ServerConnection.cpp index dd25af1..4ad6215 100644 --- a/src/Net/ServerConnection.cpp +++ b/src/Net/ServerConnection.cpp @@ -22,7 +22,7 @@ #include #include #include -#include +#include namespace Mad { namespace Net { @@ -54,14 +54,13 @@ void ServerConnection::connectionHeaderReceiveHandler(const void *data, unsigned rawSend(reinterpret_cast(&header2), sizeof(header2)); - connecting = false; enterReceiveLoop(); } void ServerConnection::listen(const IPAddress &address) throw(ConnectionException) { const int kx_list[] = {GNUTLS_KX_ANON_DH, 0}; - if(connected) + if(isConnected()) disconnect(); int listen_sock = socket(PF_INET, SOCK_STREAM, 0); @@ -97,10 +96,18 @@ void ServerConnection::listen(const IPAddress &address) throw(ConnectionExceptio close(listen_sock); - *peer = IPAddress(address); + // Set non-blocking flag + int flags = fcntl(sock, F_GETFL, 0); + + if(flags < 0) { + close(sock); + + throw ConnectionException("fcntl()", std::strerror(errno)); + } + + fcntl(sock, F_SETFL, flags | O_NONBLOCK); - connected = true; - connecting = true; + *peer = IPAddress(address); gnutls_init(&session, GNUTLS_SERVER); @@ -110,45 +117,8 @@ void ServerConnection::listen(const IPAddress &address) throw(ConnectionExceptio gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred); gnutls_transport_set_ptr(session, reinterpret_cast(sock)); - - int ret = gnutls_handshake(session); - if(ret < 0) { - disconnect(); - - throw ConnectionException("gnutls_handshake()", gnutls_strerror(ret)); - } - - rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ServerConnection::connectionHeaderReceiveHandler)); -} - -void ServerConnection::disconnect() { - if(!connected) - return; - - gnutls_bye(session, GNUTLS_SHUT_RDWR); - - shutdown(sock, SHUT_RDWR); - close(sock); - - gnutls_deinit(session); - - delete peer; - - connected = false; - connecting = false; -} - -bool ServerConnection::dataPending() const { - if(!connected) - return false; - - fd_set fds; - FD_ZERO(&fds); - FD_SET(sock, &fds); - - struct timeval timeout = {0, 0}; - return (select(sock + 1, &fds, NULL, NULL, &timeout) == 1); + handshake(); } } diff --git a/src/Net/ServerConnection.h b/src/Net/ServerConnection.h index cd983e0..0fd5f0b 100644 --- a/src/Net/ServerConnection.h +++ b/src/Net/ServerConnection.h @@ -29,27 +29,22 @@ namespace Net { class ServerConnection : public Connection { private: - bool connected; IPAddress *peer; - bool connecting; - bool daemon; - int sock; - gnutls_session_t session; gnutls_anon_server_credentials_t anoncred; gnutls_dh_params_t dh_params; void connectionHeaderReceiveHandler(const void *data, unsigned long length); protected: - virtual gnutls_session_t& getSession() { - return session; + virtual void connectionHeader() { + rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ServerConnection::connectionHeaderReceiveHandler)); } public: - ServerConnection() : connected(false), connecting(false), daemon(false) { + ServerConnection() : daemon(false) { gnutls_anon_allocate_server_credentials(&anoncred); gnutls_dh_params_init(&dh_params); @@ -58,22 +53,11 @@ class ServerConnection : public Connection { } virtual ~ServerConnection() { - if(connected) - disconnect(); - gnutls_dh_params_deinit(dh_params); gnutls_anon_free_server_credentials(anoncred); } void listen(const IPAddress &address) throw(ConnectionException); - void disconnect(); - - virtual bool dataPending() const; - - virtual bool isConnected() const {return connected;} - virtual const IPAddress* getPeer() const {return peer;} - - virtual bool isConnecting() const {return connecting;} bool isDaemonConnection() const {return daemon;} }; -- cgit v1.2.3