From 5562d7e7a6192872fc08dea311721c2b785e4e2a Mon Sep 17 00:00:00 2001 From: Stephen Birarda Date: Tue, 11 Oct 2016 17:36:29 -0700 Subject: [PATCH] add a filter operator to decide if connections are created --- domain-server/src/DomainServer.cpp | 2 +- libraries/networking/src/LimitedNodeList.cpp | 4 ++ libraries/networking/src/LimitedNodeList.h | 2 + libraries/networking/src/NodeList.cpp | 8 ++++ libraries/networking/src/NodeList.h | 2 + libraries/networking/src/udt/Socket.cpp | 45 ++++++++++++++------ libraries/networking/src/udt/Socket.h | 7 ++- 7 files changed, 54 insertions(+), 16 deletions(-) diff --git a/domain-server/src/DomainServer.cpp b/domain-server/src/DomainServer.cpp index ca428c88fe..8b73b851b2 100644 --- a/domain-server/src/DomainServer.cpp +++ b/domain-server/src/DomainServer.cpp @@ -519,7 +519,7 @@ void DomainServer::setupNodeListAndAssignments() { // add whatever static assignments that have been parsed to the queue addStaticAssignmentsToQueue(); - // set a custum packetVersionMatch as the verify packet operator for the udt::Socket + // set a custom packetVersionMatch as the verify packet operator for the udt::Socket nodeList->setPacketFilterOperator(&DomainServer::packetVersionMatch); } diff --git a/libraries/networking/src/LimitedNodeList.cpp b/libraries/networking/src/LimitedNodeList.cpp index ce555315e8..f2716047a5 100644 --- a/libraries/networking/src/LimitedNodeList.cpp +++ b/libraries/networking/src/LimitedNodeList.cpp @@ -113,6 +113,10 @@ LimitedNodeList::LimitedNodeList(int socketListenPort, int dtlsListenPort) : using std::placeholders::_1; _nodeSocket.setPacketFilterOperator(std::bind(&LimitedNodeList::isPacketVerified, this, _1)); + // set our socketBelongsToNode method as the connection creation filter operator for the udt::Socket + using std::placeholders::_1; + _nodeSocket.setConnectionCreationFilterOperator(std::bind(&LimitedNodeList::sockAddrBelongsToNode, this, _1)); + _packetStatTimer.start(); if (_stunSockAddr.getAddress().isNull()) { diff --git a/libraries/networking/src/LimitedNodeList.h b/libraries/networking/src/LimitedNodeList.h index e74a6c49f8..183bbccb42 100644 --- a/libraries/networking/src/LimitedNodeList.h +++ b/libraries/networking/src/LimitedNodeList.h @@ -299,6 +299,8 @@ protected: void sendPacketToIceServer(PacketType packetType, const HifiSockAddr& iceServerSockAddr, const QUuid& clientID, const QUuid& peerRequestID = QUuid()); + bool sockAddrBelongsToNode(const HifiSockAddr& sockAddr) { return findNodeWithAddr(sockAddr) != SharedNodePointer(); } + QUuid _sessionUUID; NodeHash _nodeHash; mutable QReadWriteLock _nodeMutex; diff --git a/libraries/networking/src/NodeList.cpp b/libraries/networking/src/NodeList.cpp index 27e6f17c33..f2b28a04a7 100644 --- a/libraries/networking/src/NodeList.cpp +++ b/libraries/networking/src/NodeList.cpp @@ -104,6 +104,10 @@ NodeList::NodeList(char newOwnerType, int socketListenPort, int dtlsListenPort) connect(&_domainHandler, SIGNAL(connectedToDomain(QString)), &_keepAlivePingTimer, SLOT(start())); connect(&_domainHandler, &DomainHandler::disconnectedFromDomain, &_keepAlivePingTimer, &QTimer::stop); + // set our sockAddrBelongsToDomainOrNode method as the connection creation filter for the udt::Socket + using std::placeholders::_1; + _nodeSocket.setConnectionCreationFilterOperator(std::bind(&NodeList::sockAddrBelongsToDomainOrNode, this, _1)); + // we definitely want STUN to update our public socket, so call the LNL to kick that off startSTUNPublicSocketUpdate(); @@ -703,6 +707,10 @@ void NodeList::sendKeepAlivePings() { }); } +bool NodeList::sockAddrBelongsToDomainOrNode(const HifiSockAddr& sockAddr) { + return _domainHandler.getSockAddr() == sockAddr || LimitedNodeList::sockAddrBelongsToNode(sockAddr); +} + void NodeList::ignoreNodeBySessionID(const QUuid& nodeID) { // enumerate the nodes to send a reliable ignore packet to each that can leverage it diff --git a/libraries/networking/src/NodeList.h b/libraries/networking/src/NodeList.h index f08c0dbe45..7f98b8c736 100644 --- a/libraries/networking/src/NodeList.h +++ b/libraries/networking/src/NodeList.h @@ -132,6 +132,8 @@ private: void pingPunchForInactiveNode(const SharedNodePointer& node); + bool sockAddrBelongsToDomainOrNode(const HifiSockAddr& sockAddr); + NodeType_t _ownerType; NodeSet _nodeTypesOfInterest; DomainHandler _domainHandler; diff --git a/libraries/networking/src/udt/Socket.cpp b/libraries/networking/src/udt/Socket.cpp index 37ededa55c..0fe78b7d63 100644 --- a/libraries/networking/src/udt/Socket.cpp +++ b/libraries/networking/src/udt/Socket.cpp @@ -171,11 +171,11 @@ qint64 Socket::writePacketList(std::unique_ptr packetList, const Hif } void Socket::writeReliablePacket(Packet* packet, const HifiSockAddr& sockAddr) { - findOrCreateConnection(sockAddr).sendReliablePacket(std::unique_ptr(packet)); + findOrCreateConnection(sockAddr, true)->sendReliablePacket(std::unique_ptr(packet)); } void Socket::writeReliablePacketList(PacketList* packetList, const HifiSockAddr& sockAddr) { - findOrCreateConnection(sockAddr).sendReliablePacketList(std::unique_ptr(packetList)); + findOrCreateConnection(sockAddr, true)->sendReliablePacketList(std::unique_ptr(packetList)); } qint64 Socket::writeDatagram(const char* data, qint64 size, const HifiSockAddr& sockAddr) { @@ -198,10 +198,22 @@ qint64 Socket::writeDatagram(const QByteArray& datagram, const HifiSockAddr& soc return bytesWritten; } -Connection& Socket::findOrCreateConnection(const HifiSockAddr& sockAddr) { +Connection* Socket::findOrCreateConnection(const HifiSockAddr& sockAddr, bool forceCreation) { auto it = _connectionsHash.find(sockAddr); if (it == _connectionsHash.end()) { + // we did not have a matching connection, time to see if we should make one + + if (!forceCreation && _connectionCreationFilterOperator && !_connectionCreationFilterOperator(sockAddr)) { + // we weren't asked to force the creation of a connection + // and the connection creation filter did not tell us to create one +#ifdef UDT_CONNECTION_DEBUG + qCDebug(networking) << "Socket::findOrCreateConnection refusing to create connection for" << sockAddr + << "due to connection creation filter"; +#endif + return nullptr; + } + auto congestionControl = _ccFactory->create(); congestionControl->setMaxBandwidth(_maxBandwidth); auto connection = std::unique_ptr(new Connection(this, sockAddr, std::move(congestionControl))); @@ -216,7 +228,7 @@ Connection& Socket::findOrCreateConnection(const HifiSockAddr& sockAddr) { it = _connectionsHash.insert(it, std::make_pair(sockAddr, std::move(connection))); } - return *it->second; + return it->second.get(); } void Socket::clearConnections() { @@ -292,9 +304,12 @@ void Socket::readPendingDatagrams() { // setup a control packet from the data we just read auto controlPacket = ControlPacket::fromReceivedPacket(std::move(buffer), packetSizeWithHeader, senderSockAddr); - // move this control packet to the matching connection - auto& connection = findOrCreateConnection(senderSockAddr); - connection.processControl(move(controlPacket)); + // move this control packet to the matching connection, if there is one + auto connection = findOrCreateConnection(senderSockAddr); + + if (connection) { + connection->processControl(move(controlPacket)); + } } else { // setup a Packet from the data we just read @@ -304,19 +319,21 @@ void Socket::readPendingDatagrams() { if (!_packetFilterOperator || _packetFilterOperator(*packet)) { if (packet->isReliable()) { // if this was a reliable packet then signal the matching connection with the sequence number - auto& connection = findOrCreateConnection(senderSockAddr); + auto connection = findOrCreateConnection(senderSockAddr); - if (!connection.processReceivedSequenceNumber(packet->getSequenceNumber(), - packet->getDataSize(), - packet->getPayloadSize())) { - // the connection indicated that we should not continue processing this packet + if (!connection || !connection->processReceivedSequenceNumber(packet->getSequenceNumber(), + packet->getDataSize(), + packet->getPayloadSize())) { + // the connection could not be created or indicated that we should not continue processing this packet continue; } } if (packet->isPartOfMessage()) { - auto& connection = findOrCreateConnection(senderSockAddr); - connection.queueReceivedMessagePacket(std::move(packet)); + auto connection = findOrCreateConnection(senderSockAddr); + if (connection) { + connection->queueReceivedMessagePacket(std::move(packet)); + } } else if (_packetHandler) { // call the verified packet callback to let it handle this packet _packetHandler(std::move(packet)); diff --git a/libraries/networking/src/udt/Socket.h b/libraries/networking/src/udt/Socket.h index bc4393d4bd..7c464f8b5e 100644 --- a/libraries/networking/src/udt/Socket.h +++ b/libraries/networking/src/udt/Socket.h @@ -37,6 +37,7 @@ class PacketList; class SequenceNumber; using PacketFilterOperator = std::function; +using ConnectionCreationFilterOperator = std::function; using BasePacketHandler = std::function)>; using PacketHandler = std::function)>; @@ -68,6 +69,8 @@ public: void setPacketHandler(PacketHandler handler) { _packetHandler = handler; } void setMessageHandler(MessageHandler handler) { _messageHandler = handler; } void setMessageFailureHandler(MessageFailureHandler handler) { _messageFailureHandler = handler; } + void setConnectionCreationFilterOperator(ConnectionCreationFilterOperator filterOperator) + { _connectionCreationFilterOperator = filterOperator; } void addUnfilteredHandler(const HifiSockAddr& senderSockAddr, BasePacketHandler handler) { _unfilteredHandlers[senderSockAddr] = handler; } @@ -93,7 +96,8 @@ private slots: private: void setSystemBufferSizes(); - Connection& findOrCreateConnection(const HifiSockAddr& sockAddr); + Connection* findOrCreateConnection(const HifiSockAddr& sockAddr, bool forceCreation = false); + bool socketMatchesNodeOrDomain(const HifiSockAddr& sockAddr); // privatized methods used by UDTTest - they are private since they must be called on the Socket thread ConnectionStats::Stats sampleStatsForConnection(const HifiSockAddr& destination); @@ -109,6 +113,7 @@ private: PacketHandler _packetHandler; MessageHandler _messageHandler; MessageFailureHandler _messageFailureHandler; + ConnectionCreationFilterOperator _connectionCreationFilterOperator; std::unordered_map _unfilteredHandlers; std::unordered_map _unreliableSequenceNumbers;