diff options
Diffstat (limited to 'src/Net')
-rw-r--r-- | src/Net/ClientConnection.cpp | 42 | ||||
-rw-r--r-- | src/Net/Listener.cpp | 62 |
2 files changed, 54 insertions, 50 deletions
diff --git a/src/Net/ClientConnection.cpp b/src/Net/ClientConnection.cpp index 7fdfd68..a3447a0 100644 --- a/src/Net/ClientConnection.cpp +++ b/src/Net/ClientConnection.cpp @@ -31,69 +31,73 @@ void ClientConnection::connectionHeaderReceiveHandler(const void *data, unsigned if(length != sizeof(ConnectionHeader)) // Error... disconnect return; - + const ConnectionHeader *header = reinterpret_cast<const ConnectionHeader*>(data); - + if(header->m != 'M' || header->a != 'A' || header->d != 'D') // Error... disconnect return; - + if(header->protVerMin != 1) // Unsupported protocol... disconnect return; - + enterReceiveLoop(); } void ClientConnection::connectionHeader() { ConnectionHeader header = {'M', 'A', 'D', daemon ? 'D' : 'C', 0, 1, 1, 1}; - + rawSend(reinterpret_cast<unsigned char*>(&header), sizeof(header)); rawReceive(sizeof(ConnectionHeader), sigc::mem_fun(this, &ClientConnection::connectionHeaderReceiveHandler)); } void ClientConnection::connect(const IPAddress &address, bool daemon0) throw(ConnectionException) { daemon = daemon0; - + if(isConnected()) disconnect(); - + sock = socket(PF_INET, SOCK_STREAM, 0); if(sock < 0) throw ConnectionException("socket()", std::strerror(errno)); - + peer = new IPAddress(address); - + if(::connect(sock, peer->getSockAddr(), peer->getSockAddrLength()) < 0) { close(sock); delete peer; peer = 0; throw ConnectionException("connect()", std::strerror(errno)); } - + // Set non-blocking flag int flags = fcntl(sock, F_GETFL, 0); - + if(flags < 0) { close(sock); - + throw ConnectionException("fcntl()", std::strerror(errno)); } - + fcntl(sock, F_SETFL, flags | O_NONBLOCK); - + + // Don't linger + struct linger linger = {1, 0}; + setsockopt(sock, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)); + gnutls_init(&session, GNUTLS_CLIENT); - + gnutls_set_default_priority(session); - + const int kx_list[] = {GNUTLS_KX_ANON_DH, 0}; gnutls_kx_set_priority(session, kx_list); - + gnutls_credentials_set(session, GNUTLS_CRD_ANON, anoncred); - + gnutls_transport_set_lowat(session, 0); gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(sock)); - + handshake(); } diff --git a/src/Net/Listener.cpp b/src/Net/Listener.cpp index 981b3c7..3b2e3d6 100644 --- a/src/Net/Listener.cpp +++ b/src/Net/Listener.cpp @@ -32,36 +32,36 @@ Listener::Listener(const IPAddress &address0) throw(ConnectionException) : address(address0) { gnutls_dh_params_init(&dh_params); gnutls_dh_params_generate2(dh_params, 768); - + sock = socket(PF_INET, SOCK_STREAM, 0); - + if(sock < 0) throw ConnectionException("socket()", std::strerror(errno)); - + // Set non-blocking flag int flags = fcntl(sock, F_GETFL, 0); - + if(flags < 0) { close(sock); - + throw ConnectionException("fcntl()", std::strerror(errno)); } - + fcntl(sock, F_SETFL, flags | O_NONBLOCK); - - // Set reuse address - flags = 1; - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &flags, sizeof(flags)); - + + // Don't linger + struct linger linger = {1, 0}; + setsockopt(sock, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)); + if(bind(sock, address.getSockAddr(), address.getSockAddrLength()) < 0) { close(sock); - + throw ConnectionException("bind()", std::strerror(errno)); } - + if(listen(sock, 64) < 0) { close(sock); - + throw ConnectionException("listen()", std::strerror(errno)); } } @@ -71,66 +71,66 @@ Listener::~Listener() { (*con)->disconnect(); delete *con; } - + shutdown(sock, SHUT_RDWR); close(sock); - + gnutls_dh_params_deinit(dh_params); } std::vector<struct pollfd> Listener::getPollfds() const { std::vector<struct pollfd> pollfds; - + struct pollfd fd = {sock, POLLIN, 0}; pollfds.push_back(fd); - + for(std::list<ServerConnection*>::const_iterator con = connections.begin(); con != connections.end(); ++con) pollfds.push_back((*con)->getPollfd()); - + return pollfds; } ServerConnection* Listener::getConnection(const std::map<int,const short*> &pollfdMap) { // TODO: Logging - + int sd; struct sockaddr_in sa; socklen_t addrlen = sizeof(sa); - - + + while((sd = accept(sock, reinterpret_cast<struct sockaddr*>(&sa), &addrlen)) >= 0) { connections.push_back(new ServerConnection(sd, IPAddress(sa), dh_params)); - + addrlen = sizeof(sa); } - + for(std::list<ServerConnection*>::iterator con = connections.begin(); con != connections.end(); ++con) { std::map<int,const short*>::const_iterator events = pollfdMap.find((*con)->getSocket()); - + if(events != pollfdMap.end()) (*con)->sendReceive(*events->second); else (*con)->sendReceive(); } - + for(std::list<ServerConnection*>::iterator con = connections.begin(); con != connections.end();) { if(!(*con)->isConnected()) { delete *con; connections.erase(con++); // Erase unincremented iterator - + continue; } - + if(!(*con)->isConnecting()) { ServerConnection *connection = *con; connections.erase(con); - + return connection; } - + ++con; } - + return 0; } |