/*
 * Connection.cpp
 *
 * Copyright (C) 2008 Matthias Schiffer
 *
 * This program is free software: you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program. If not, see <http://www.gnu.org/licenses/>.
 */

#include "Connection.h"
#include "FdManager.h"
#include "IPAddress.h"
#include "ThreadManager.h"

#include <cstring>      // std::memcpy
#include <arpa/inet.h>  // ntohs
#include <poll.h>       // POLLIN, POLLOUT, POLLHUP, POLLERR
#include <sys/socket.h> // shutdown, SHUT_RDWR
#include <unistd.h>     // close

namespace Mad {
namespace Net {

Connection::StaticInit Connection::staticInit;

/*
 * Disconnects (if still connected) and frees everything the object owns:
 * the pending receive buffer, all queued send buffers, the GnuTLS
 * certificate credentials, the three locks and the peer address.
 */
Connection::~Connection() {
  if(_isConnected())
    doDisconnect();

  // delete/delete[] on a null pointer is a no-op; no checks needed.
  delete [] transR.data;

  while(!_sendQueueEmpty()) {
    delete [] transS.front().data;
    transS.pop();
  }

  gnutls_certificate_free_credentials(x509_cred);

  gl_rwlock_destroy(stateLock);
  gl_lock_destroy(sendLock);
  gl_lock_destroy(receiveLock);

  delete peer;
}

/*
 * Starts the TLS handshake. Does nothing unless the state is CONNECT;
 * otherwise switches to HANDSHAKE and performs the first handshake step.
 */
void Connection::handshake() {
  gl_rwlock_wrlock(stateLock);

  if(state != CONNECT) {
    gl_rwlock_unlock(stateLock);
    return;
  }

  state = HANDSHAKE;
  gl_rwlock_unlock(stateLock);

  doHandshake();
}

/*
 * Starts the TLS shutdown. Does nothing unless the state is DISCONNECT;
 * otherwise switches to BYE and performs the first shutdown step.
 */
void Connection::bye() {
  gl_rwlock_wrlock(stateLock);

  if(state != DISCONNECT) {
    gl_rwlock_unlock(stateLock);
    return;
  }

  state = BYE;
  gl_rwlock_unlock(stateLock);

  doBye();
}

/*
 * Performs one step of the GnuTLS handshake (only in state HANDSHAKE).
 *
 * The write lock is held across gnutls_handshake(). This keeps doDisconnect()
 * from deinitialising the session mid-call, serialises handshake steps, and
 * makes the later `state` assignment legal. (The previous version assigned
 * `state` while holding only the read lock, which is a data race.)
 *
 * GNUTLS_E_AGAIN / GNUTLS_E_INTERRUPTED: poll events are refreshed and the
 * step is retried on the next socket event. Any other error drops the
 * connection. Success: the state becomes CONNECTION_HEADER and
 * connectionHeader() is called.
 */
void Connection::doHandshake() {
  gl_rwlock_wrlock(stateLock);

  if(state != HANDSHAKE) {
    gl_rwlock_unlock(stateLock);
    return;
  }

  int ret = gnutls_handshake(session);

  if(ret < 0) {
    // updateEvents()/doDisconnect() take stateLock themselves.
    gl_rwlock_unlock(stateLock);

    if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
      updateEvents();
      return;
    }

    // TODO: Error
    doDisconnect();
    return;
  }

  state = CONNECTION_HEADER;
  gl_rwlock_unlock(stateLock);

  connectionHeader();
}

/*
 * Performs one step of the GnuTLS shutdown (only in state BYE).
 *
 * stateLock is held across gnutls_bye() for the same reason as in
 * doHandshake(): doDisconnect() must not deinitialise the session while it is
 * in use. A would-block result refreshes the poll events and retries later.
 * Any other outcome, success or error, ends in doDisconnect().
 */
void Connection::doBye() {
  gl_rwlock_wrlock(stateLock);

  if(state != BYE) {
    gl_rwlock_unlock(stateLock);
    return;
  }

  int ret = gnutls_bye(session, GNUTLS_SHUT_RDWR);
  gl_rwlock_unlock(stateLock);

  if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
    updateEvents();
    return;
  }

  // TODO: Error (ret < 0)
  doDisconnect();
}

/*
 * (Re)arms reception of the next packet header.
 *
 * Returns without doing anything if the connection is gone or shutting down.
 * On the first call while still connecting, queues emission of
 * connectedSignal on the ThreadManager. Then switches to PACKET_HEADER and
 * requests sizeof(Packet::Data) bytes for packetHeaderReceiveHandler().
 */
void Connection::enterReceiveLoop() {
  gl_rwlock_wrlock(stateLock);

  if(!_isConnected() || _isDisconnecting()) {
    gl_rwlock_unlock(stateLock);
    return;
  }

  if(_isConnecting())
    ThreadManager::get()->pushWork(connectedSignal.make_slot());

  state = PACKET_HEADER;
  gl_rwlock_unlock(stateLock);

  rawReceive(sizeof(Packet::Data), sigc::mem_fun(this, &Connection::packetHeaderReceiveHandler));
}

/*
 * Completion callback for a packet header.
 *
 * A header of the wrong size drops the connection. A header announcing zero
 * payload bytes is delivered immediately through receiveSignal (via the
 * ThreadManager) and the next header is requested. Otherwise the state
 * becomes PACKET_DATA and ntohs(header.length) payload bytes are requested.
 */
void Connection::packetHeaderReceiveHandler(const void *data, unsigned long length) {
  // `state` is shared with other threads (doDisconnect() etc.), so both the
  // check and the PACKET_DATA transition happen under the write lock.
  gl_rwlock_wrlock(stateLock);

  if(state != PACKET_HEADER) {
    gl_rwlock_unlock(stateLock);
    return;
  }

  if(length != sizeof(Packet::Data)) {
    gl_rwlock_unlock(stateLock);

    // TODO: Error
    doDisconnect();
    return;
  }

  // memcpy makes no alignment or aliasing assumptions about the raw buffer.
  std::memcpy(&header, data, sizeof(header));

  if(header.length == 0) {
    gl_rwlock_unlock(stateLock);

    ThreadManager::get()->pushWork(sigc::bind(receiveSignal.make_slot(), Packet(ntohs(header.requestId))));
    enterReceiveLoop();
    return;
  }

  state = PACKET_DATA;
  gl_rwlock_unlock(stateLock);

  rawReceive(ntohs(header.length), sigc::mem_fun(this, &Connection::packetDataReceiveHandler));
}

/*
 * Completion callback for a packet payload.
 *
 * A payload whose size differs from the header's length field drops the
 * connection. Otherwise the packet is delivered through receiveSignal (via
 * the ThreadManager) and the next header is requested.
 */
void Connection::packetDataReceiveHandler(const void *data, unsigned long length) {
  gl_rwlock_rdlock(stateLock);
  bool expected = (state == PACKET_DATA);
  gl_rwlock_unlock(stateLock);

  if(!expected)
    return;

  if(length != ntohs(header.length)) {
    // TODO: Error
    doDisconnect();
    return;
  }

  ThreadManager::get()->pushWork(sigc::bind(receiveSignal.make_slot(), Packet(ntohs(header.requestId), data, length)));
  enterReceiveLoop();
}

/*
 * Reads as much of the pending reception (set up by rawReceive()) as GnuTLS
 * currently has available.
 *
 * When the reception is complete, its notify slot is called with the buffer,
 * and the buffer is freed after the slot returns. A would-block result just
 * returns. End of stream (0) and any other error drop the connection.
 */
void Connection::doReceive() {
  if(!isConnected())
    return;

  gl_lock_lock(receiveLock);

  if(_receiveComplete()) {
    gl_lock_unlock(receiveLock);
    return;
  }

  ssize_t ret = gnutls_record_recv(session, transR.data+transR.transmitted, transR.length-transR.transmitted);

  if(ret <= 0) {
    gl_lock_unlock(receiveLock);

    if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
      return;

    // ret == 0 means the peer closed the stream. The socket would stay
    // readable forever, so treat it like an error instead of spinning.
    // TODO: Error
    doDisconnect();
    return;
  }

  transR.transmitted += ret;

  if(_receiveComplete()) {
    // Take everything the callback needs while still holding the lock: the
    // callback may start a new reception via rawReceive(), which overwrites
    // transR.data, transR.length and transR.notify.
    uint8_t *data = transR.data;
    unsigned long length = transR.length;
    sigc::slot<void, const void*, unsigned long> notify = transR.notify;
    transR.data = 0;
    gl_lock_unlock(receiveLock);

    notify(data, length);
    delete [] data;
  }
  else {
    gl_lock_unlock(receiveLock);
  }

  updateEvents();
}

/*
 * Requests reception of exactly `length` bytes; `notify` is called with the
 * data once they have all arrived.
 *
 * Returns false if not connected or if a reception is already pending.
 * NOTE(review): the slot type was reconstructed from the handler signatures
 * (const void*, unsigned long); confirm it matches Connection.h.
 */
bool Connection::rawReceive(unsigned long length, const sigc::slot<void, const void*, unsigned long> &notify) {
  if(!isConnected())
    return false;

  gl_lock_lock(receiveLock);

  if(!_receiveComplete()) {
    gl_lock_unlock(receiveLock);
    return false;
  }

  transR.data = new uint8_t[length];
  transR.length = length;
  transR.transmitted = 0;
  transR.notify = notify;

  gl_lock_unlock(receiveLock);

  updateEvents();

  return true;
}

/*
 * Writes queued send buffers until the queue is empty or GnuTLS would block.
 * Fully sent buffers are freed and dequeued. Errors other than would-block
 * drop the connection.
 */
void Connection::doSend() {
  if(!isConnected())
    return;

  gl_lock_lock(sendLock);

  while(!_sendQueueEmpty()) {
    ssize_t ret = gnutls_record_send(session, transS.front().data+transS.front().transmitted, transS.front().length-transS.front().transmitted);

    if(ret < 0) {
      gl_lock_unlock(sendLock);

      if(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
        return;

      // TODO: Error
      doDisconnect();
      return;
    }

    transS.front().transmitted += ret;

    if(transS.front().transmitted == transS.front().length) {
      delete [] transS.front().data;
      transS.pop();
    }
  }

  gl_lock_unlock(sendLock);

  updateEvents();
}

/*
 * Copies `length` bytes from `data` onto the send queue.
 * Returns false if not connected.
 */
bool Connection::rawSend(const uint8_t *data, unsigned long length) {
  if(!isConnected())
    return false;

  // {length, transmitted, data}; the trailing notify slot is value-initialised
  // (left empty), as sends have no completion callback.
  Transmission trans = {length, 0, new uint8_t[length]};
  std::memcpy(trans.data, data, length);

  gl_lock_lock(sendLock);
  transS.push(trans);
  gl_lock_unlock(sendLock);

  updateEvents();

  return true;
}

/*
 * Socket event dispatcher.
 *
 * POLLHUP/POLLERR drop the connection. The handshake/shutdown states advance
 * their state machine. A DISCONNECT request waits until the send queue has
 * drained before starting the TLS shutdown. In the other states, readable and
 * writable events drive doReceive()/doSend().
 * NOTE(review): `state` and the send queue are read here without their locks.
 */
void Connection::sendReceive(short events) {
  if(events & (POLLHUP | POLLERR)) {
    doDisconnect();
    return;
  }

  switch(state) {
    case CONNECT:
      handshake();
      return;
    case HANDSHAKE:
      doHandshake();
      return;
    case DISCONNECT:
      if(!_sendQueueEmpty())
        break;

      bye();
      return;
    case BYE:
      doBye();
      return;
    default:
      break;
  }

  if(events & POLLIN)
    doReceive();

  if(events & POLLOUT)
    doSend();
}

/*
 * Queues a packet for sending. Returns false unless the connection is fully
 * established (not connecting, not disconnecting, not disconnected).
 */
bool Connection::send(const Packet &packet) {
  gl_rwlock_rdlock(stateLock);
  bool err = (!_isConnected() || _isConnecting() || _isDisconnecting());
  gl_rwlock_unlock(stateLock);

  if(err)
    return false;

  return rawSend((const uint8_t*)packet.getRawData(), packet.getRawDataLength());
}

/*
 * Requests a graceful disconnect: the state becomes DISCONNECT. The TLS
 * shutdown starts from sendReceive() once the send queue is empty.
 */
void Connection::disconnect() {
  gl_rwlock_wrlock(stateLock);

  if(!_isConnected() || _isDisconnecting()) {
    gl_rwlock_unlock(stateLock);
    return;
  }

  state = DISCONNECT;
  gl_rwlock_unlock(stateLock);

  updateEvents();
}

/*
 * Immediately tears the connection down, if it is still connected:
 * - unregisters and closes the socket;
 * - deinitialises the GnuTLS session;
 * - queues emission of disconnectedSignal;
 * - sets the state to DISCONNECTED.
 */
void Connection::doDisconnect() {
  gl_rwlock_wrlock(stateLock);

  if(_isConnected()) {
    FdManager::get()->unregisterFd(sock);
    shutdown(sock, SHUT_RDWR);
    close(sock);
    gnutls_deinit(session);

    ThreadManager::get()->pushWork(disconnectedSignal.make_slot());

    state = DISCONNECTED;
  }

  gl_rwlock_unlock(stateLock);
}

/*
 * Recomputes the poll events for the socket and hands them to FdManager.
 *
 * - Normal operation: POLLIN while a reception is pending, POLLOUT while the
 *   send queue is non-empty.
 * - HANDSHAKE/BYE: whichever direction GnuTLS is currently blocked on.
 * - CONNECT/DISCONNECT: always POLLOUT, so sendReceive() gets to advance the
 *   state machine.
 */
void Connection::updateEvents() {
  gl_lock_lock(receiveLock);
  short events = (_receiveComplete() ? 0 : POLLIN);
  gl_lock_unlock(receiveLock);

  gl_lock_lock(sendLock);
  events |= (_sendQueueEmpty() ? 0 : POLLOUT);
  gl_lock_unlock(sendLock);

  gl_rwlock_rdlock(stateLock);

  // After doDisconnect() the socket is unregistered and closed, and its fd
  // number may already have been reused by another connection: leave it alone.
  if(_isConnected()) {
    if(state == HANDSHAKE || state == BYE)
      events = ((gnutls_record_get_direction(session) == 0) ? POLLIN : POLLOUT);
    else if(state == CONNECT || state == DISCONNECT)
      events |= POLLOUT;

    FdManager::get()->setFdEvents(sock, events);
  }

  gl_rwlock_unlock(stateLock);
}

} // namespace Net
} // namespace Mad