diff --git a/libraries/networking/src/LimitedNodeList.h b/libraries/networking/src/LimitedNodeList.h index 43393ef69c..0cbe9668b3 100644 --- a/libraries/networking/src/LimitedNodeList.h +++ b/libraries/networking/src/LimitedNodeList.h @@ -219,6 +219,8 @@ public: udt::Socket::StatsVector sampleStatsForAllConnections() { return _nodeSocket.sampleStatsForAllConnections(); } + void setConnectionMaxBandwidth(int maxBandwidth) { _nodeSocket.setConnectionMaxBandwidth(maxBandwidth); } + public slots: void reset(); void eraseAllNodes(); diff --git a/libraries/networking/src/udt/CongestionControl.cpp b/libraries/networking/src/udt/CongestionControl.cpp index 8eff5e3a01..1d1a6628fe 100644 --- a/libraries/networking/src/udt/CongestionControl.cpp +++ b/libraries/networking/src/udt/CongestionControl.cpp @@ -20,13 +20,19 @@ using namespace std::chrono; static const double USECS_PER_SECOND = 1000000.0; +void CongestionControl::setMaxBandwidth(int maxBandwidth) { + _maxBandwidth = maxBandwidth; + setPacketSendPeriod(_packetSendPeriod); +} + void CongestionControl::setPacketSendPeriod(double newSendPeriod) { Q_ASSERT_X(newSendPeriod >= 0, "CongestionControl::setPacketPeriod", "Can not set a negative packet send period"); - - if (_maxBandwidth > 0) { + + auto maxBandwidth = _maxBandwidth.load(); + if (maxBandwidth > 0) { // anytime the packet send period is about to be increased, make sure it stays below the minimum period, // calculated based on the maximum desired bandwidth - double minPacketSendPeriod = USECS_PER_SECOND / (((double) _maxBandwidth) / _mss); + double minPacketSendPeriod = USECS_PER_SECOND / (((double) maxBandwidth) / _mss); _packetSendPeriod = std::max(newSendPeriod, minPacketSendPeriod); } else { _packetSendPeriod = newSendPeriod; diff --git a/libraries/networking/src/udt/CongestionControl.h b/libraries/networking/src/udt/CongestionControl.h index 3a5c8d0d00..8297b5f6bd 100644 --- a/libraries/networking/src/udt/CongestionControl.h +++ b/libraries/networking/src/udt/CongestionControl.h @@ -12,6 +12,7 @@ #ifndef hifi_CongestionControl_h #define hifi_CongestionControl_h +#include #include #include #include @@ -37,6 +38,7 @@ public: virtual ~CongestionControl() {} int synInterval() const { return _synInterval; } + void setMaxBandwidth(int maxBandwidth); virtual void init() {} virtual void onACK(SequenceNumber ackNum) {} @@ -49,7 +51,6 @@ protected: void setMSS(int mss) { _mss = mss; } void setMaxCongestionWindowSize(int window) { _maxCongestionWindowSize = window; } void setBandwidth(int bandwidth) { _bandwidth = bandwidth; } - void setMaxBandwidth(int maxBandwidth) { _maxBandwidth = maxBandwidth; } virtual void setInitialSendSequenceNumber(SequenceNumber seqNum) = 0; void setSendCurrentSequenceNumber(SequenceNumber seqNum) { _sendCurrSeqNum = seqNum; } void setReceiveRate(int rate) { _receiveRate = rate; } @@ -60,7 +61,7 @@ protected: double _congestionWindowSize { 16.0 }; // Congestion window size, in packets int _bandwidth { 0 }; // estimated bandwidth, packets per second - int _maxBandwidth { -1 }; // Maximum desired bandwidth, packets per second + std::atomic _maxBandwidth { -1 }; // Maximum desired bandwidth, bytes per second double _maxCongestionWindowSize { 0.0 }; // maximum cwnd size, in packets int _mss { 0 }; // Maximum Packet Size, including all packet headers diff --git a/libraries/networking/src/udt/Connection.cpp b/libraries/networking/src/udt/Connection.cpp index f75a9535f5..e5f3508b81 100644 --- a/libraries/networking/src/udt/Connection.cpp +++ b/libraries/networking/src/udt/Connection.cpp @@ -80,6 +80,10 @@ void Connection::resetRTT() { _rttVariance = _rtt / 2; } +void Connection::setMaxBandwidth(int maxBandwidth) { + _congestionControl->setMaxBandwidth(maxBandwidth); +} + SendQueue& Connection::getSendQueue() { if (!_sendQueue) { diff --git a/libraries/networking/src/udt/Connection.h b/libraries/networking/src/udt/Connection.h index 8d80e736af..4f5a8793e7 100644 --- a/libraries/networking/src/udt/Connection.h +++ b/libraries/networking/src/udt/Connection.h @@ -76,6 +76,8 @@ public: HifiSockAddr getDestination() const { return _destination; } + void setMaxBandwidth(int maxBandwidth); + signals: void packetSent(); void connectionInactive(const HifiSockAddr& sockAddr); diff --git a/libraries/networking/src/udt/Socket.cpp b/libraries/networking/src/udt/Socket.cpp index 1eb7c04331..e9af1577fb 100644 --- a/libraries/networking/src/udt/Socket.cpp +++ b/libraries/networking/src/udt/Socket.cpp @@ -176,7 +176,9 @@ Connection& Socket::findOrCreateConnection(const HifiSockAddr& sockAddr) { auto it = _connectionsHash.find(sockAddr); if (it == _connectionsHash.end()) { - auto connection = std::unique_ptr(new Connection(this, sockAddr, _ccFactory->create())); + auto congestionControl = _ccFactory->create(); + congestionControl->setMaxBandwidth(_maxBandwidth); + auto connection = std::unique_ptr(new Connection(this, sockAddr, std::move(congestionControl))); // we queue the connection to cleanup connection in case it asks for it during its own rate control sync QObject::connect(connection.get(), &Connection::connectionInactive, this, &Socket::cleanupConnection); @@ -350,6 +352,17 @@ void Socket::setCongestionControlFactory(std::unique_ptrsynInterval(); } + +void Socket::setConnectionMaxBandwidth(int maxBandwidth) { + qInfo() << "Setting socket's maximum bandwith to" << maxBandwidth << ". (" + << _connectionsHash.size() << "live connections)"; + _maxBandwidth = maxBandwidth; + for (auto& pair : _connectionsHash) { + auto& connection = pair.second; + connection->setMaxBandwidth(_maxBandwidth); + } +} + ConnectionStats::Stats Socket::sampleStatsForConnection(const HifiSockAddr& destination) { auto it = _connectionsHash.find(destination); if (it != _connectionsHash.end()) { diff --git a/libraries/networking/src/udt/Socket.h b/libraries/networking/src/udt/Socket.h index 88db8e3d86..424158045f 100644 --- a/libraries/networking/src/udt/Socket.h +++ b/libraries/networking/src/udt/Socket.h @@ -72,6 +72,7 @@ public: { _unfilteredHandlers[senderSockAddr] = handler; } void setCongestionControlFactory(std::unique_ptr ccFactory); + void setConnectionMaxBandwidth(int maxBandwidth); void messageReceived(std::unique_ptr packet); void messageFailed(Connection* connection, Packet::MessageNumber messageNumber); @@ -109,8 +110,10 @@ private: std::unordered_map _unreliableSequenceNumbers; std::unordered_map> _connectionsHash; - int _synInterval = 10; // 10ms - QTimer* _synTimer; + int _synInterval { 10 }; // 10ms + QTimer* _synTimer { nullptr }; + + int _maxBandwidth { -1 }; std::unique_ptr _ccFactory { new CongestionControlFactory() };