From 2e0cc158de4500a7b16b304d44ca242ef0e4d02c Mon Sep 17 00:00:00 2001
From: Seth Alves <seth.alves@gmail.com>
Date: Fri, 16 Sep 2016 17:21:28 -0700
Subject: [PATCH] ice test-client uses stun server to get public address

---
 libraries/networking/src/LimitedNodeList.cpp | 170 ++++++++++---------
 libraries/networking/src/LimitedNodeList.h   |   8 +-
 tools/ice-client/src/ICEClientApp.cpp        | 106 +++++++++---
 3 files changed, 177 insertions(+), 107 deletions(-)

diff --git a/libraries/networking/src/LimitedNodeList.cpp b/libraries/networking/src/LimitedNodeList.cpp
index 5aa31efea4..ec4b2c3573 100644
--- a/libraries/networking/src/LimitedNodeList.cpp
+++ b/libraries/networking/src/LimitedNodeList.cpp
@@ -745,8 +745,32 @@ void LimitedNodeList::removeSilentNodes() {
 const uint32_t RFC_5389_MAGIC_COOKIE = 0x2112A442;
 const int NUM_BYTES_STUN_HEADER = 20;
 
-void LimitedNodeList::sendSTUNRequest() {
 
+void LimitedNodeList::makeSTUNRequestPacket(char* stunRequestPacket) {
+    int packetIndex = 0;
+
+    const uint32_t RFC_5389_MAGIC_COOKIE_NETWORK_ORDER = htonl(RFC_5389_MAGIC_COOKIE);
+
+    // leading zeros + message type
+    const uint16_t REQUEST_MESSAGE_TYPE = htons(0x0001);
+    memcpy(stunRequestPacket + packetIndex, &REQUEST_MESSAGE_TYPE, sizeof(REQUEST_MESSAGE_TYPE));
+    packetIndex += sizeof(REQUEST_MESSAGE_TYPE);
+
+    // message length (no additional attributes are included)
+    uint16_t messageLength = 0;
+    memcpy(stunRequestPacket + packetIndex, &messageLength, sizeof(messageLength));
+    packetIndex += sizeof(messageLength);
+
+    memcpy(stunRequestPacket + packetIndex, &RFC_5389_MAGIC_COOKIE_NETWORK_ORDER, sizeof(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER));
+    packetIndex += sizeof(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER);
+
+    // transaction ID (random 12-byte unsigned integer)
+    const uint NUM_TRANSACTION_ID_BYTES = 12;
+    QUuid randomUUID = QUuid::createUuid();
+    memcpy(stunRequestPacket + packetIndex, randomUUID.toRfc4122().data(), NUM_TRANSACTION_ID_BYTES);
+}
+
+void LimitedNodeList::sendSTUNRequest() {
     if (!_stunSockAddr.getAddress().isNull()) {
         const int NUM_INITIAL_STUN_REQUESTS_BEFORE_FAIL = 10;
 
@@ -762,36 +786,14 @@ void LimitedNodeList::sendSTUNRequest() {
         }
 
         char stunRequestPacket[NUM_BYTES_STUN_HEADER];
-
-        int packetIndex = 0;
-
-        const uint32_t RFC_5389_MAGIC_COOKIE_NETWORK_ORDER = htonl(RFC_5389_MAGIC_COOKIE);
-
-        // leading zeros + message type
-        const uint16_t REQUEST_MESSAGE_TYPE = htons(0x0001);
-        memcpy(stunRequestPacket + packetIndex, &REQUEST_MESSAGE_TYPE, sizeof(REQUEST_MESSAGE_TYPE));
-        packetIndex += sizeof(REQUEST_MESSAGE_TYPE);
-
-        // message length (no additional attributes are included)
-        uint16_t messageLength = 0;
-        memcpy(stunRequestPacket + packetIndex, &messageLength, sizeof(messageLength));
-        packetIndex += sizeof(messageLength);
-
-        memcpy(stunRequestPacket + packetIndex, &RFC_5389_MAGIC_COOKIE_NETWORK_ORDER, sizeof(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER));
-        packetIndex += sizeof(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER);
-
-        // transaction ID (random 12-byte unsigned integer)
-        const uint NUM_TRANSACTION_ID_BYTES = 12;
-        QUuid randomUUID = QUuid::createUuid();
-        memcpy(stunRequestPacket + packetIndex, randomUUID.toRfc4122().data(), NUM_TRANSACTION_ID_BYTES);
-
+        makeSTUNRequestPacket(stunRequestPacket);
         flagTimeForConnectionStep(ConnectionStep::SendSTUNRequest);
-
         _nodeSocket.writeDatagram(stunRequestPacket, sizeof(stunRequestPacket), _stunSockAddr);
     }
 }
 
-void LimitedNodeList::processSTUNResponse(std::unique_ptr<udt::BasePacket> packet) {
+bool LimitedNodeList::parseSTUNResponse(udt::BasePacket* packet,
+                                        QHostAddress& newPublicAddress, uint16_t& newPublicPort) {
     // check the cookie to make sure this is actually a STUN response
     // and read the first attribute and make sure it is a XOR_MAPPED_ADDRESS
     const int NUM_BYTES_MESSAGE_TYPE_AND_LENGTH = 4;
@@ -803,71 +805,79 @@ void LimitedNodeList::processSTUNResponse(std::unique_ptr<udt::BasePacket> packe
 
     if (memcmp(packet->getData() + NUM_BYTES_MESSAGE_TYPE_AND_LENGTH,
                &RFC_5389_MAGIC_COOKIE_NETWORK_ORDER,
-               sizeof(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER)) == 0) {
+               sizeof(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER)) != 0) {
+        return false;
+    }
 
-        // enumerate the attributes to find XOR_MAPPED_ADDRESS_TYPE
-        while (attributeStartIndex < packet->getDataSize()) {
+    // enumerate the attributes to find XOR_MAPPED_ADDRESS_TYPE
+    while (attributeStartIndex < packet->getDataSize()) {
+        if (memcmp(packet->getData() + attributeStartIndex, &XOR_MAPPED_ADDRESS_TYPE, sizeof(XOR_MAPPED_ADDRESS_TYPE)) == 0) {
+            const int NUM_BYTES_STUN_ATTR_TYPE_AND_LENGTH = 4;
+            const int NUM_BYTES_FAMILY_ALIGN = 1;
+            const uint8_t IPV4_FAMILY_NETWORK_ORDER = htons(0x01) >> 8;
 
-            if (memcmp(packet->getData() + attributeStartIndex, &XOR_MAPPED_ADDRESS_TYPE, sizeof(XOR_MAPPED_ADDRESS_TYPE)) == 0) {
-                const int NUM_BYTES_STUN_ATTR_TYPE_AND_LENGTH = 4;
-                const int NUM_BYTES_FAMILY_ALIGN = 1;
-                const uint8_t IPV4_FAMILY_NETWORK_ORDER = htons(0x01) >> 8;
+            int byteIndex = attributeStartIndex + NUM_BYTES_STUN_ATTR_TYPE_AND_LENGTH + NUM_BYTES_FAMILY_ALIGN;
 
-                int byteIndex = attributeStartIndex + NUM_BYTES_STUN_ATTR_TYPE_AND_LENGTH + NUM_BYTES_FAMILY_ALIGN;
+            uint8_t addressFamily = 0;
+            memcpy(&addressFamily, packet->getData() + byteIndex, sizeof(addressFamily));
 
-                uint8_t addressFamily = 0;
-                memcpy(&addressFamily, packet->getData() + byteIndex, sizeof(addressFamily));
+            byteIndex += sizeof(addressFamily);
 
-                byteIndex += sizeof(addressFamily);
+            if (addressFamily == IPV4_FAMILY_NETWORK_ORDER) {
+                // grab the X-Port
+                uint16_t xorMappedPort = 0;
+                memcpy(&xorMappedPort, packet->getData() + byteIndex, sizeof(xorMappedPort));
 
-                if (addressFamily == IPV4_FAMILY_NETWORK_ORDER) {
-                    // grab the X-Port
-                    uint16_t xorMappedPort = 0;
-                    memcpy(&xorMappedPort, packet->getData() + byteIndex, sizeof(xorMappedPort));
+                newPublicPort = ntohs(xorMappedPort) ^ (ntohl(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER) >> 16);
 
-                    uint16_t newPublicPort = ntohs(xorMappedPort) ^ (ntohl(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER) >> 16);
+                byteIndex += sizeof(xorMappedPort);
 
-                    byteIndex += sizeof(xorMappedPort);
+                // grab the X-Address
+                uint32_t xorMappedAddress = 0;
+                memcpy(&xorMappedAddress, packet->getData() + byteIndex, sizeof(xorMappedAddress));
 
-                    // grab the X-Address
-                    uint32_t xorMappedAddress = 0;
-                    memcpy(&xorMappedAddress, packet->getData() + byteIndex, sizeof(xorMappedAddress));
+                uint32_t stunAddress = ntohl(xorMappedAddress) ^ ntohl(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER);
 
-                    uint32_t stunAddress = ntohl(xorMappedAddress) ^ ntohl(RFC_5389_MAGIC_COOKIE_NETWORK_ORDER);
-
-                    QHostAddress newPublicAddress(stunAddress);
-
-                    if (newPublicAddress != _publicSockAddr.getAddress() || newPublicPort != _publicSockAddr.getPort()) {
-                        _publicSockAddr = HifiSockAddr(newPublicAddress, newPublicPort);
-
-                        qCDebug(networking, "New public socket received from STUN server is %s:%hu",
-                               _publicSockAddr.getAddress().toString().toLocal8Bit().constData(),
-                               _publicSockAddr.getPort());
-
-                        if (!_hasCompletedInitialSTUN) {
-                            // if we're here we have definitely completed our initial STUN sequence
-                            stopInitialSTUNUpdate(true);
-                        }
-
-                        emit publicSockAddrChanged(_publicSockAddr);
-
-                        flagTimeForConnectionStep(ConnectionStep::SetPublicSocketFromSTUN);
-                    }
-
-                    // we're done reading the packet so we can return now
-                    return;
-                }
-            } else {
-                // push forward attributeStartIndex by the length of this attribute
-                const int NUM_BYTES_ATTRIBUTE_TYPE = 2;
-
-                uint16_t attributeLength = 0;
-                memcpy(&attributeLength, packet->getData() + attributeStartIndex + NUM_BYTES_ATTRIBUTE_TYPE,
-                       sizeof(attributeLength));
-                attributeLength = ntohs(attributeLength);
-
-                attributeStartIndex += NUM_BYTES_MESSAGE_TYPE_AND_LENGTH + attributeLength;
+                // QHostAddress newPublicAddress(stunAddress);
+                newPublicAddress = QHostAddress(stunAddress);
+                return true;
             }
+        } else {
+            // push forward attributeStartIndex by the length of this attribute
+            const int NUM_BYTES_ATTRIBUTE_TYPE = 2;
+
+            uint16_t attributeLength = 0;
+            memcpy(&attributeLength, packet->getData() + attributeStartIndex + NUM_BYTES_ATTRIBUTE_TYPE,
+                   sizeof(attributeLength));
+            attributeLength = ntohs(attributeLength);
+
+            attributeStartIndex += NUM_BYTES_MESSAGE_TYPE_AND_LENGTH + attributeLength;
+        }
+    }
+    return false;
+}
+
+
+void LimitedNodeList::processSTUNResponse(std::unique_ptr<udt::BasePacket> packet) {
+    uint16_t newPublicPort;
+    QHostAddress newPublicAddress;
+    if (parseSTUNResponse(packet.get(), newPublicAddress, newPublicPort)) {
+
+        if (newPublicAddress != _publicSockAddr.getAddress() || newPublicPort != _publicSockAddr.getPort()) {
+            _publicSockAddr = HifiSockAddr(newPublicAddress, newPublicPort);
+
+            qCDebug(networking, "New public socket received from STUN server is %s:%hu",
+                    _publicSockAddr.getAddress().toString().toLocal8Bit().constData(),
+                    _publicSockAddr.getPort());
+
+            if (!_hasCompletedInitialSTUN) {
+                // if we're here we have definitely completed our initial STUN sequence
+                stopInitialSTUNUpdate(true);
+            }
+
+            emit publicSockAddrChanged(_publicSockAddr);
+
+            flagTimeForConnectionStep(ConnectionStep::SetPublicSocketFromSTUN);
         }
     }
 }
diff --git a/libraries/networking/src/LimitedNodeList.h b/libraries/networking/src/LimitedNodeList.h
index 1a3599e226..e74a6c49f8 100644
--- a/libraries/networking/src/LimitedNodeList.h
+++ b/libraries/networking/src/LimitedNodeList.h
@@ -146,6 +146,7 @@ public:
                                       const NodePermissions& permissions = DEFAULT_AGENT_PERMISSIONS,
                                       const QUuid& connectionSecret = QUuid());
 
+    static bool parseSTUNResponse(udt::BasePacket* packet, QHostAddress& newPublicAddress, uint16_t& newPublicPort);
     bool hasCompletedInitialSTUN() const { return _hasCompletedInitialSTUN; }
 
     const HifiSockAddr& getLocalSockAddr() const { return _localSockAddr; }
@@ -232,6 +233,9 @@ public:
     bool packetVersionMatch(const udt::Packet& packet);
     bool isPacketVerified(const udt::Packet& packet);
 
+    static void makeSTUNRequestPacket(char* stunRequestPacket);
+
+
 public slots:
     void reset();
     void eraseAllNodes();
@@ -275,7 +279,7 @@ protected:
     LimitedNodeList(int socketListenPort = INVALID_PORT, int dtlsListenPort = INVALID_PORT);
     LimitedNodeList(LimitedNodeList const&) = delete; // Don't implement, needed to avoid copies of singleton
     void operator=(LimitedNodeList const&) = delete; // Don't implement, needed to avoid copies of singleton
-    
+
     qint64 sendPacket(std::unique_ptr<NLPacket> packet, const Node& destinationNode,
                       const HifiSockAddr& overridenSockAddr);
     qint64 writePacket(const NLPacket& packet, const HifiSockAddr& destinationSockAddr,
@@ -284,7 +288,7 @@ protected:
     void fillPacketHeader(const NLPacket& packet, const QUuid& connectionSecret = QUuid());
 
     void setLocalSocket(const HifiSockAddr& sockAddr);
-    
+
     bool packetSourceAndHashMatchAndTrackBandwidth(const udt::Packet& packet);
     void processSTUNResponse(std::unique_ptr<udt::BasePacket> packet);
 
diff --git a/tools/ice-client/src/ICEClientApp.cpp b/tools/ice-client/src/ICEClientApp.cpp
index f0000c78cf..8614fbf960 100644
--- a/tools/ice-client/src/ICEClientApp.cpp
+++ b/tools/ice-client/src/ICEClientApp.cpp
@@ -16,7 +16,10 @@
 
 #include "ICEClientApp.h"
 
-ICEClientApp::ICEClientApp(int argc, char* argv[]) : QCoreApplication(argc, argv) {
+ICEClientApp::ICEClientApp(int argc, char* argv[]) :
+    QCoreApplication(argc, argv),
+    _stunSockAddr(STUN_SERVER_HOSTNAME, STUN_SERVER_PORT)
+{
     // parse command-line
     QCommandLineParser parser;
     parser.setApplicationDescription("High Fidelity ICE client");
@@ -80,7 +83,7 @@ ICEClientApp::ICEClientApp(int argc, char* argv[]) : QCoreApplication(argc, argv
 
     qDebug() << "ICE-server address is" << _iceServerAddr;
 
-    _state = 0;
+    setState(lookUpStunServer);
 
     QTimer* doTimer = new QTimer(this);
     connect(doTimer, &QTimer::timeout, this, &ICEClientApp::doSomething);
@@ -91,43 +94,71 @@ ICEClientApp::~ICEClientApp() {
     delete _socket;
 }
 
+void ICEClientApp::setState(int newState) {
+    // qDebug() << "state: " << _state << " --> " << newState;
+    _state = newState;
+}
+
 void ICEClientApp::doSomething() {
     if (_actionMax > 0 && _actionCount >= _actionMax) {
+        // time to stop.
         QMetaObject::invokeMethod(this, "quit", Qt::QueuedConnection);
-    }
 
-    if (_state == 0) {
+    } else if (_state == lookUpStunServer) {
+        // lookup STUN server address
+        if (!_stunSockAddr.getAddress().isNull()) {
+            qDebug() << "stun server is" << _stunSockAddr;
+            setState(sendStunRequestPacket);
+        }
+
+    } else if (_state == sendStunRequestPacket) {
+        // send STUN request packet
+
         _domainServerPeerSet = false;
         unsigned int localPort = 0;
+        delete _socket;
         _socket = new udt::Socket();
         _socket->bind(QHostAddress::AnyIPv4, localPort);
-        _socket->setPacketHandler([this](std::unique_ptr<udt::Packet> packet) { processPacket(std::move(packet));  });
+        _socket->setPacketHandler([this](std::unique_ptr<udt::Packet> packet) { processPacket(std::move(packet)); });
+        _socket->addUnfilteredHandler(_stunSockAddr,
+                                      [this](std::unique_ptr<udt::BasePacket> packet) {
+                                          processSTUNResponse(std::move(packet));
+                                      });
 
         qDebug() << "local port is" << _socket->localPort();
         _localSockAddr = HifiSockAddr("127.0.0.1", _socket->localPort());
         _publicSockAddr = HifiSockAddr("127.0.0.1", _socket->localPort());
 
-        // QUuid peerID = QUuid("75cd162a-53dc-4292-aaa5-1304ab1bb0f2");
+        const int NUM_BYTES_STUN_HEADER = 20;
+        char stunRequestPacket[NUM_BYTES_STUN_HEADER];
+        LimitedNodeList::makeSTUNRequestPacket(stunRequestPacket);
+        qDebug() << "sending STUN request";
+        _socket->writeDatagram(stunRequestPacket, sizeof(stunRequestPacket), _stunSockAddr);
+
+        setState(waitForStunResponse);
+
+    } else if (_state == talkToIceServer) {
         QUuid peerID;
         if (_domainID == QUuid()) {
-            // pick a random domain-id
+            // pick a random domain-id which will fail
             peerID = QUuid::createUuid();
-            _state = 2;
+            setState(pause0);
         } else {
             // use the domain UUID given on the command-line
             peerID = _domainID;
-            _state = 1;
+            setState(waitForIceReply);
         }
         _sessionUUID = QUuid::createUuid();
+        qDebug() << "I am" << _sessionUUID;
 
         sendPacketToIceServer(PacketType::ICEServerQuery, _iceServerAddr, _sessionUUID, peerID);
 
         _actionCount++;
-    } else if (_state == 2) {
-        _state = 3;
-    } else if (_state == 3) {
+    } else if (_state == pause0) {
+        setState(pause1);
+    } else if (_state == pause1) {
         qDebug() << "";
-        _state = 0;
+        setState(sendStunRequestPacket);
         delete _socket;
         _socket = nullptr;
     }
@@ -158,7 +189,7 @@ void ICEClientApp::icePingDomainServer() {
         return;
     }
 
-    qDebug() << "ice-pinging domain-server";
+    qDebug() << "ice-pinging domain-server: " << _domainServerPeer;
 
     auto localPingPacket = LimitedNodeList::constructICEPingPacket(PingType::Local, _sessionUUID);
     _socket->writePacket(*localPingPacket, _domainServerPeer.getLocalSocket());
@@ -167,29 +198,52 @@ void ICEClientApp::icePingDomainServer() {
     _socket->writePacket(*publicPingPacket, _domainServerPeer.getPublicSocket());
 }
 
+void ICEClientApp::processSTUNResponse(std::unique_ptr<udt::BasePacket> packet) {
+    qDebug() << "got stun response";
+    if (_state != waitForStunResponse) {
+        qDebug() << "got unexpected stun response";
+        return;
+    }
+
+    uint16_t newPublicPort;
+    QHostAddress newPublicAddress;
+    if (LimitedNodeList::parseSTUNResponse(packet.get(), newPublicAddress, newPublicPort)) {
+        _publicSockAddr = HifiSockAddr(newPublicAddress, newPublicPort);
+        qDebug() << "My public address is" << _publicSockAddr;
+        setState(talkToIceServer);
+    }
+}
+
 
 void ICEClientApp::processPacket(std::unique_ptr<udt::Packet> packet) {
-    auto nlPacket = NLPacket::fromBase(std::move(packet));
+    std::unique_ptr<NLPacket> nlPacket = NLPacket::fromBase(std::move(packet));
 
     if (nlPacket->getPayloadSize() < NLPacket::localHeaderSize(PacketType::ICEServerHeartbeat)) {
         qDebug() << "got a short packet.";
         return;
     }
 
+    qDebug() << "here" << nlPacket->getType();
+
     QSharedPointer<ReceivedMessage> message = QSharedPointer<ReceivedMessage>::create(*nlPacket);
     const HifiSockAddr& senderAddr = message->getSenderSockAddr();
 
     if (nlPacket->getType() == PacketType::ICEServerPeerInformation) {
         QDataStream iceResponseStream(message->getMessage());
-        iceResponseStream >> _domainServerPeer;
-        _domainServerPeerSet = true;
+        if (!_domainServerPeerSet) {
+            iceResponseStream >> _domainServerPeer;
+            qDebug() << "got ICEServerPeerInformation from" << _domainServerPeer;
+            _domainServerPeerSet = true;
 
-        icePingDomainServer();
-        _pingDomainTimer = new QTimer(this);
-        connect(_pingDomainTimer, &QTimer::timeout, this, &ICEClientApp::icePingDomainServer);
-        _pingDomainTimer->start(1000);
-
-        qDebug() << "got ICEServerPeerInformation from" << _domainServerPeer.getUUID();
+            icePingDomainServer();
+            _pingDomainTimer = new QTimer(this);
+            connect(_pingDomainTimer, &QTimer::timeout, this, &ICEClientApp::icePingDomainServer);
+            _pingDomainTimer->start(500);
+        } else {
+            // NetworkPeer domainServerPeer;
+            // iceResponseStream >> domainServerPeer;
+            // qDebug() << "got repeat ICEServerPeerInformation from" << domainServerPeer;
+        }
 
     } else if (nlPacket->getType() == PacketType::ICEPing) {
         qDebug() << "got packet: " << nlPacket->getType();
@@ -198,14 +252,16 @@ void ICEClientApp::processPacket(std::unique_ptr<udt::Packet> packet) {
 
     } else if (nlPacket->getType() == PacketType::ICEPingReply) {
         qDebug() << "got packet: " << nlPacket->getType();
-        if (_domainServerPeerSet && _state == 1 &&
+        if (_domainServerPeerSet && _state == waitForIceReply &&
             (senderAddr == _domainServerPeer.getLocalSocket() ||
              senderAddr == _domainServerPeer.getPublicSocket())) {
 
             delete _pingDomainTimer;
             _pingDomainTimer = nullptr;
 
-            _state = 2;
+            setState(pause0);
+        } else {
+            qDebug() << "got unexpected ICEPingReply" << senderAddr;
         }
 
     } else {