diff --git a/animation-server/src/AnimationServer.cpp b/animation-server/src/AnimationServer.cpp index 78966da2b5..3151445794 100644 --- a/animation-server/src/AnimationServer.cpp +++ b/animation-server/src/AnimationServer.cpp @@ -830,7 +830,7 @@ void AnimationServer::readPendingDatagrams() { receivedPacket.resize(nodeList->getNodeSocket().pendingDatagramSize()); nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(), nodeSockAddr.getAddressPointer(), nodeSockAddr.getPortPointer()); - if (packetVersionMatch(receivedPacket)) { + if (nodeList->packetVersionAndHashMatch(receivedPacket)) { if (packetTypeForPacket(receivedPacket) == PacketTypeJurisdiction) { int headerBytes = numBytesForPacketHeader(receivedPacket); // PacketType_JURISDICTION, first byte is the node type... diff --git a/assignment-client/src/AssignmentClient.cpp b/assignment-client/src/AssignmentClient.cpp index b294f48fac..8add7d90f8 100644 --- a/assignment-client/src/AssignmentClient.cpp +++ b/assignment-client/src/AssignmentClient.cpp @@ -111,7 +111,7 @@ void AssignmentClient::readPendingDatagrams() { nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(), senderSockAddr.getAddressPointer(), senderSockAddr.getPortPointer()); - if (packetVersionMatch(receivedPacket)) { + if (nodeList->packetVersionAndHashMatch(receivedPacket)) { if (_currentAssignment) { // have the threaded current assignment handle this datagram QMetaObject::invokeMethod(_currentAssignment, "processDatagram", Qt::QueuedConnection, diff --git a/data-server/src/DataServer.cpp b/data-server/src/DataServer.cpp index 43fc52fb06..97228b67a3 100644 --- a/data-server/src/DataServer.cpp +++ b/data-server/src/DataServer.cpp @@ -66,7 +66,7 @@ void DataServer::readPendingDatagrams() { PacketType requestType = packetTypeForPacket(receivedPacket); if ((requestType == PacketTypeDataServerPut || requestType == PacketTypeDataServerGet) && - packetVersionMatch(receivedPacket)) { + receivedPacket[numBytesArithmeticCodingFromBuffer(receivedPacket.data())] == versionForPacketType(requestType)) { QDataStream packetStream(receivedPacket); int numReceivedHeaderBytes = numBytesForPacketHeader(receivedPacket); diff --git a/domain-server/src/DomainServer.cpp b/domain-server/src/DomainServer.cpp index 6ce71cae01..2a98c9b8b0 100644 --- a/domain-server/src/DomainServer.cpp +++ b/domain-server/src/DomainServer.cpp @@ -236,7 +236,7 @@ void DomainServer::readAvailableDatagrams() { nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(), senderSockAddr.getAddressPointer(), senderSockAddr.getPortPointer()); - if (packetVersionMatch(receivedPacket)) { + if (nodeList->packetVersionAndHashMatch(receivedPacket)) { PacketType requestType = packetTypeForPacket(receivedPacket); if (requestType == PacketTypeDomainListRequest) { diff --git a/interface/src/DatagramProcessor.cpp b/interface/src/DatagramProcessor.cpp index 36f03c39f0..6271ef5fde 100644 --- a/interface/src/DatagramProcessor.cpp +++ b/interface/src/DatagramProcessor.cpp @@ -40,7 +40,7 @@ void DatagramProcessor::processDatagrams() { _packetCount++; _byteCount += incomingPacket.size(); - if (packetVersionMatch(incomingPacket)) { + if (nodeList->packetVersionAndHashMatch(incomingPacket)) { // only process this packet if we have a match on the packet version switch (packetTypeForPacket(incomingPacket)) { case PacketTypeTransmitterData: diff --git a/interface/src/VoxelPacketProcessor.cpp b/interface/src/VoxelPacketProcessor.cpp index b5ec4247f2..dce391e587 100644 --- a/interface/src/VoxelPacketProcessor.cpp +++ b/interface/src/VoxelPacketProcessor.cpp @@ -48,7 +48,7 @@ void VoxelPacketProcessor::processPacket(const SharedNodePointer& sendingNode, c wasStatsPacket = true; if (messageLength > statsMessageLength) { mutablePacket = mutablePacket.mid(statsMessageLength); - if (!packetVersionMatch(packet)) { + if (!NodeList::getInstance()->packetVersionAndHashMatch(packet)) { return; // bail since piggyback data doesn't match our versioning } } else { diff --git a/libraries/shared/src/NodeList.cpp b/libraries/shared/src/NodeList.cpp index 46234c510f..ba890564ce 100644 --- a/libraries/shared/src/NodeList.cpp +++ b/libraries/shared/src/NodeList.cpp @@ -80,30 +80,64 @@ NodeList::~NodeList() { clear(); } -qint64 NodeList::writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode, - const HifiSockAddr& overridenSockAddr) { +bool NodeList::packetVersionAndHashMatch(const QByteArray& packet) { + // currently this just checks if the version in the packet matches our return from versionForPacketType + // may need to be expanded in the future for types and versions that take > than 1 byte - // setup the MD5 hash for source verification in the header - int numBytesPacketHeader = numBytesForPacketHeader(datagram); - QByteArray dataSecretHash = QCryptographicHash::hash(datagram.mid(numBytesPacketHeader) - + destinationNode->getConnectionSecret().toRfc4122(), - QCryptographicHash::Md5); - QByteArray datagramWithHash = datagram; - datagramWithHash.replace(numBytesPacketHeader - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH, dataSecretHash); + if (packet[1] != versionForPacketType(packetTypeForPacket(packet)) + && packetTypeForPacket(packet) != PacketTypeStunResponse) { + PacketType mismatchType = packetTypeForPacket(packet); + int numPacketTypeBytes = arithmeticCodingValueFromBuffer(packet.data()); + + qDebug() << "Packet version mismatch on" << packetTypeForPacket(packet) << "- Sender" + << uuidFromPacketHeader(packet) << "sent" << qPrintable(QString::number(packet[numPacketTypeBytes])) << "but" + << qPrintable(QString::number(versionForPacketType(mismatchType))) << "expected."; + } - // if we don't have an ovveriden address, assume they want to send to the node's active socket - const HifiSockAddr* destinationSockAddr = &overridenSockAddr; - if (overridenSockAddr.isNull()) { - if (getNodeActiveSocketOrPing(destinationNode)) { - // use the node's active socket as the destination socket - destinationSockAddr = destinationNode->getActiveSocket(); + if (packetTypeForPacket(packet) != PacketTypeDomainList && packetTypeForPacket(packet) != PacketTypeDomainListRequest) { + // figure out which node this is from + SharedNodePointer sendingNode = sendingNodeForPacket(packet); + if (sendingNode) { + // check if the md5 hash in the header matches the hash we would expect + if (hashFromPacketHeader(packet) == hashForPacketAndConnectionUUID(packet, sendingNode->getConnectionSecret())) { + return true; + } else { + qDebug() << "Packet hash mismatch" << packetTypeForPacket(packet) << "received from known node with UUID" + << uuidFromPacketHeader(packet); + } } else { - // we don't have a socket to send to, return 0 - return 0; + qDebug() << "Packet of type" << packetTypeForPacket(packet) << "received from unknown node with UUID" + << uuidFromPacketHeader(packet); } } - return _nodeSocket.writeDatagram(datagramWithHash, destinationSockAddr->getAddress(), destinationSockAddr->getPort()); + return false; +} + +qint64 NodeList::writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode, + const HifiSockAddr& overridenSockAddr) { + if (destinationNode) { + // if we don't have an ovveriden address, assume they want to send to the node's active socket + const HifiSockAddr* destinationSockAddr = &overridenSockAddr; + if (overridenSockAddr.isNull()) { + if (getNodeActiveSocketOrPing(destinationNode)) { + // use the node's active socket as the destination socket + destinationSockAddr = destinationNode->getActiveSocket(); + } else { + // we don't have a socket to send to, return 0 + return 0; + } + } + + QByteArray datagramCopy = datagram; + // setup the MD5 hash for source verification in the header + replaceHashInPacketGivenConnectionUUID(datagramCopy, destinationNode->getConnectionSecret()); + + return _nodeSocket.writeDatagram(datagramCopy, destinationSockAddr->getAddress(), destinationSockAddr->getPort()); + } + + // didn't have a destinationNode to send to, return 0 + return 0; } qint64 NodeList::writeDatagram(const char* data, qint64 size, const SharedNodePointer& destinationNode, @@ -191,8 +225,11 @@ void NodeList::processNodeData(const HifiSockAddr& senderSockAddr, const QByteAr } case PacketTypePing: { // send back a reply - QByteArray replyPacket = constructPingReplyPacket(packet); - writeDatagram(replyPacket, sendingNodeForPacket(packet), senderSockAddr); + if (sendingNodeForPacket(packet)) { + QByteArray replyPacket = constructPingReplyPacket(packet); + writeDatagram(replyPacket, sendingNodeForPacket(packet), senderSockAddr); + } + break; } case PacketTypePingReply: { @@ -716,9 +753,9 @@ void NodeList::activateSocketFromNodeCommunication(const QByteArray& packet, con // if this is a local or public ping then we can activate a socket // we do nothing with agnostic pings, those are simply for timing - if (pingType == PingType::Local) { + if (pingType == PingType::Local && sendingNode->getActiveSocket() != &sendingNode->getLocalSocket()) { sendingNode->activateLocalSocket(); - } else if (pingType == PingType::Public) { + } else if (pingType == PingType::Public && !sendingNode->getActiveSocket()) { sendingNode->activatePublicSocket(); } } diff --git a/libraries/shared/src/NodeList.h b/libraries/shared/src/NodeList.h index 2e89395d3e..408459e51f 100644 --- a/libraries/shared/src/NodeList.h +++ b/libraries/shared/src/NodeList.h @@ -84,6 +84,8 @@ public: QUdpSocket& getNodeSocket() { return _nodeSocket; } + bool packetVersionAndHashMatch(const QByteArray& packet); + qint64 writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode, const HifiSockAddr& overridenSockAddr = HifiSockAddr()); qint64 writeDatagram(const char* data, qint64 size, const SharedNodePointer& destinationNode, diff --git a/libraries/shared/src/PacketHeaders.cpp b/libraries/shared/src/PacketHeaders.cpp index 16e1b3ce78..7193b6f648 100644 --- a/libraries/shared/src/PacketHeaders.cpp +++ b/libraries/shared/src/PacketHeaders.cpp @@ -85,26 +85,6 @@ int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionU return position - packet; } -bool packetVersionMatch(const QByteArray& packet) { - // currently this just checks if the version in the packet matches our return from versionForPacketType - // may need to be expanded in the future for types and versions that take > than 1 byte - - if (packet[1] == versionForPacketType(packetTypeForPacket(packet)) || packetTypeForPacket(packet) == PacketTypeStunResponse) { - return true; - } else { - PacketType mismatchType = packetTypeForPacket(packet); - int numPacketTypeBytes = arithmeticCodingValueFromBuffer(packet.data()); - - QUuid nodeUUID = uuidFromPacketHeader(packet); - - qDebug() << "Packet mismatch on" << packetTypeForPacket(packet) << "- Sender" - << nodeUUID << "sent" << qPrintable(QString::number(packet[numPacketTypeBytes])) << "but" - << qPrintable(QString::number(versionForPacketType(mismatchType))) << "expected."; - - return false; - } -} - int numBytesForPacketHeader(const QByteArray& packet) { // returns the number of bytes used for the type, version, and UUID return numBytesArithmeticCodingFromBuffer(packet.data()) + NUM_STATIC_HEADER_BYTES; @@ -124,6 +104,20 @@ QUuid uuidFromPacketHeader(const QByteArray& packet) { NUM_BYTES_RFC4122_UUID)); } +QByteArray hashFromPacketHeader(const QByteArray& packet) { + return packet.mid(NUM_STATIC_HEADER_BYTES - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH); +} + +QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID) { + return QCryptographicHash::hash(packet.mid(numBytesForPacketHeader(packet)) + + connectionUUID.toRfc4122(), QCryptographicHash::Md5); +} + +void replaceHashInPacketGivenConnectionUUID(QByteArray& packet, const QUuid& connectionUUID) { + packet.replace(numBytesForPacketHeader(packet) - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH, + hashForPacketAndConnectionUUID(packet, connectionUUID)); +} + PacketType packetTypeForPacket(const QByteArray& packet) { return (PacketType) arithmeticCodingValueFromBuffer(packet.data()); } diff --git a/libraries/shared/src/PacketHeaders.h b/libraries/shared/src/PacketHeaders.h index ad669daece..c1a5a34114 100644 --- a/libraries/shared/src/PacketHeaders.h +++ b/libraries/shared/src/PacketHeaders.h @@ -70,17 +70,20 @@ QByteArray byteArrayWithPopluatedHeader(PacketType type, const QUuid& connection int populatePacketHeader(QByteArray& packet, PacketType type, const QUuid& connectionUUID = nullUUID); int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionUUID = nullUUID); -bool packetVersionMatch(const QByteArray& packet); - int numBytesForPacketHeader(const QByteArray& packet); int numBytesForPacketHeader(const char* packet); int numBytesForPacketHeaderGivenPacketType(PacketType type); QUuid uuidFromPacketHeader(const QByteArray& packet); +QByteArray hashFromPacketHeader(const QByteArray& packet); +QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID); +void replaceHashInPacketGivenConnectionUUID(QByteArray& packet, const QUuid& connectionUUID); + PacketType packetTypeForPacket(const QByteArray& packet); PacketType packetTypeForPacket(const char* packet); int arithmeticCodingValueFromBuffer(const char* checkValue); +int numBytesArithmeticCodingFromBuffer(const char* checkValue); #endif