allow optional sequence numbers in packets

This commit is contained in:
Stephen Birarda 2015-05-05 11:00:27 -07:00
parent fa9dfbe073
commit c76ae56d64
4 changed files with 135 additions and 55 deletions

View file

@ -232,19 +232,12 @@ qint64 LimitedNodeList::readDatagram(QByteArray& incomingPacket, QHostAddress* a
qint64 LimitedNodeList::writeDatagram(const QByteArray& datagram, const HifiSockAddr& destinationSockAddr, qint64 LimitedNodeList::writeDatagram(const QByteArray& datagram, const HifiSockAddr& destinationSockAddr,
const QUuid& connectionSecret) { const QUuid& connectionSecret) {
QByteArray datagramCopy = datagram;
if (!connectionSecret.isNull()) {
// setup the MD5 hash for source verification in the header
replaceHashInPacketGivenConnectionUUID(datagramCopy, connectionSecret);
}
// XXX can BandwidthRecorder be used for this? // XXX can BandwidthRecorder be used for this?
// stat collection for packets // stat collection for packets
++_numCollectedPackets; ++_numCollectedPackets;
_numCollectedBytes += datagram.size(); _numCollectedBytes += datagram.size();
qint64 bytesWritten = _nodeSocket.writeDatagram(datagramCopy, qint64 bytesWritten = _nodeSocket.writeDatagram(datagram,
destinationSockAddr.getAddress(), destinationSockAddr.getPort()); destinationSockAddr.getAddress(), destinationSockAddr.getPort());
if (bytesWritten < 0) { if (bytesWritten < 0) {
@ -270,6 +263,19 @@ qint64 LimitedNodeList::writeDatagram(const QByteArray& datagram,
} }
} }
QByteArray datagramCopy = datagram;
PacketType packetType = packetTypeForPacket(datagramCopy);
// perform replacement of hash and optionally also sequence number in the header
if (SEQUENCE_NUMBERED_PACKETS.contains(packetType)) {
PacketSequenceNumber sequenceNumber = getNextSequenceNumberForPacket(destinationNode->getUUID(), packetType);
replaceHashAndSequenceNumberInPacketGivenType(datagramCopy, packetType,
destinationNode->getConnectionSecret(),
sequenceNumber);
} else {
replaceHashInPacketGivenType(datagramCopy, packetType, destinationNode->getConnectionSecret());
}
emit dataSent(destinationNode->getType(), datagram.size()); emit dataSent(destinationNode->getType(), datagram.size());
auto bytesWritten = writeDatagram(datagram, *destinationSockAddr, destinationNode->getConnectionSecret()); auto bytesWritten = writeDatagram(datagram, *destinationSockAddr, destinationNode->getConnectionSecret());
// Keep track of per-destination-node bandwidth // Keep track of per-destination-node bandwidth
@ -318,6 +324,15 @@ qint64 LimitedNodeList::writeUnverifiedDatagram(const char* data, qint64 size, c
return writeUnverifiedDatagram(QByteArray(data, size), destinationNode, overridenSockAddr); return writeUnverifiedDatagram(QByteArray(data, size), destinationNode, overridenSockAddr);
} }
PacketSequenceNumber LimitedNodeList::getNextSequenceNumberForPacket(const QUuid& nodeUUID, PacketType packetType) {
// Thanks to std::map and std::unordered_map this line either default constructs the
// PacketTypeSequenceMap and the PacketSequenceNumber or returns the existing value.
// We use the postfix increment so that the stored value is incremented and the next
// return gives the correct value.
return _packetSequenceNumbers[nodeUUID][packetType]++;
}
void LimitedNodeList::processNodeData(const HifiSockAddr& senderSockAddr, const QByteArray& packet) { void LimitedNodeList::processNodeData(const HifiSockAddr& senderSockAddr, const QByteArray& packet) {
// the node decided not to do anything with this packet // the node decided not to do anything with this packet
// if it comes from a known source we should keep that node alive // if it comes from a known source we should keep that node alive

View file

@ -14,7 +14,9 @@
#include <stdint.h> #include <stdint.h>
#include <iterator> #include <iterator>
#include <map>
#include <memory> #include <memory>
#include <unordered_map>
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> // not on windows, not needed for mac or windows #include <unistd.h> // not on windows, not needed for mac or windows
@ -34,6 +36,7 @@
#include "DomainHandler.h" #include "DomainHandler.h"
#include "Node.h" #include "Node.h"
#include "PacketHeaders.h"
#include "UUIDHasher.h" #include "UUIDHasher.h"
const int MAX_PACKET_SIZE = 1450; const int MAX_PACKET_SIZE = 1450;
@ -76,6 +79,8 @@ namespace PingType {
const PingType_t Symmetric = 3; const PingType_t Symmetric = 3;
} }
typedef std::map<PacketType, PacketSequenceNumber> PacketTypeSequenceMap;
class LimitedNodeList : public QObject, public Dependency { class LimitedNodeList : public QObject, public Dependency {
Q_OBJECT Q_OBJECT
SINGLETON_DEPENDENCY SINGLETON_DEPENDENCY
@ -224,9 +229,12 @@ protected:
LimitedNodeList(LimitedNodeList const&); // Don't implement, needed to avoid copies of singleton LimitedNodeList(LimitedNodeList const&); // Don't implement, needed to avoid copies of singleton
void operator=(LimitedNodeList const&); // Don't implement, needed to avoid copies of singleton void operator=(LimitedNodeList const&); // Don't implement, needed to avoid copies of singleton
qint64 writeDatagram(const QByteArray& datagram, const HifiSockAddr& destinationSockAddr, qint64 writeDatagram(const QByteArray& datagram,
const HifiSockAddr& destinationSockAddr,
const QUuid& connectionSecret); const QUuid& connectionSecret);
PacketSequenceNumber getNextSequenceNumberForPacket(const QUuid& nodeUUID, PacketType packetType);
void changeSocketBufferSizes(int numBytes); void changeSocketBufferSizes(int numBytes);
void handleNodeKill(const SharedNodePointer& node); void handleNodeKill(const SharedNodePointer& node);
@ -248,6 +256,8 @@ protected:
bool _thisNodeCanAdjustLocks; bool _thisNodeCanAdjustLocks;
bool _thisNodeCanRez; bool _thisNodeCanRez;
std::unordered_map<QUuid, PacketTypeSequenceMap, UUIDHasher> _packetSequenceNumbers;
template<typename IteratorLambda> template<typename IteratorLambda>
void eachNodeHashIterator(IteratorLambda functor) { void eachNodeHashIterator(IteratorLambda functor) {
QWriteLocker writeLock(&_nodeMutex); QWriteLocker writeLock(&_nodeMutex);

View file

@ -45,8 +45,8 @@ int packArithmeticallyCodedValue(int value, char* destination) {
} }
} }
PacketVersion versionForPacketType(PacketType type) { PacketVersion versionForPacketType(PacketType packetType) {
switch (type) { switch (packetType) {
case PacketTypeMicrophoneAudioNoEcho: case PacketTypeMicrophoneAudioNoEcho:
case PacketTypeMicrophoneAudioWithEcho: case PacketTypeMicrophoneAudioWithEcho:
return 2; return 2;
@ -86,8 +86,8 @@ PacketVersion versionForPacketType(PacketType type) {
#define PACKET_TYPE_NAME_LOOKUP(x) case x: return QString(#x); #define PACKET_TYPE_NAME_LOOKUP(x) case x: return QString(#x);
QString nameForPacketType(PacketType type) { QString nameForPacketType(PacketType packetType) {
switch (type) { switch (packetType) {
PACKET_TYPE_NAME_LOOKUP(PacketTypeUnknown); PACKET_TYPE_NAME_LOOKUP(PacketTypeUnknown);
PACKET_TYPE_NAME_LOOKUP(PacketTypeStunResponse); PACKET_TYPE_NAME_LOOKUP(PacketTypeStunResponse);
PACKET_TYPE_NAME_LOOKUP(PacketTypeDomainList); PACKET_TYPE_NAME_LOOKUP(PacketTypeDomainList);
@ -132,30 +132,30 @@ QString nameForPacketType(PacketType type) {
PACKET_TYPE_NAME_LOOKUP(PacketTypeUnverifiedPing); PACKET_TYPE_NAME_LOOKUP(PacketTypeUnverifiedPing);
PACKET_TYPE_NAME_LOOKUP(PacketTypeUnverifiedPingReply); PACKET_TYPE_NAME_LOOKUP(PacketTypeUnverifiedPingReply);
default: default:
return QString("Type: ") + QString::number((int)type); return QString("Type: ") + QString::number((int)packetType);
} }
return QString("unexpected"); return QString("unexpected");
} }
QByteArray byteArrayWithPopulatedHeader(PacketType type, const QUuid& connectionUUID) { QByteArray byteArrayWithPopulatedHeader(PacketType packetType, const QUuid& connectionUUID) {
QByteArray freshByteArray(MAX_PACKET_HEADER_BYTES, 0); QByteArray freshByteArray(MAX_PACKET_HEADER_BYTES, 0);
freshByteArray.resize(populatePacketHeader(freshByteArray, type, connectionUUID)); freshByteArray.resize(populatePacketHeader(freshByteArray, packetType, connectionUUID));
return freshByteArray; return freshByteArray;
} }
int populatePacketHeader(QByteArray& packet, PacketType type, const QUuid& connectionUUID) { int populatePacketHeader(QByteArray& packet, PacketType packetType, const QUuid& connectionUUID) {
if (packet.size() < numBytesForPacketHeaderGivenPacketType(type)) { if (packet.size() < numBytesForPacketHeaderGivenPacketType(packetType)) {
packet.resize(numBytesForPacketHeaderGivenPacketType(type)); packet.resize(numBytesForPacketHeaderGivenPacketType(packetType));
} }
return populatePacketHeader(packet.data(), type, connectionUUID); return populatePacketHeader(packet.data(), packetType, connectionUUID);
} }
int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionUUID) { int populatePacketHeader(char* packet, PacketType packetType, const QUuid& connectionUUID) {
int numTypeBytes = packArithmeticallyCodedValue(type, packet); int numTypeBytes = packArithmeticallyCodedValue(packetType, packet);
packet[numTypeBytes] = versionForPacketType(type); packet[numTypeBytes] = versionForPacketType(packetType);
char* position = packet + numTypeBytes + sizeof(PacketVersion); char* position = packet + numTypeBytes + sizeof(PacketVersion);
@ -165,38 +165,46 @@ int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionU
memcpy(position, rfcUUID.constData(), NUM_BYTES_RFC4122_UUID); memcpy(position, rfcUUID.constData(), NUM_BYTES_RFC4122_UUID);
position += NUM_BYTES_RFC4122_UUID; position += NUM_BYTES_RFC4122_UUID;
if (!NON_VERIFIED_PACKETS.contains(type)) { if (!NON_VERIFIED_PACKETS.contains(packetType)) {
// pack 16 bytes of zeros where the md5 hash will be placed once data is packed // pack 16 bytes of zeros where the md5 hash will be placed once data is packed
memset(position, 0, NUM_BYTES_MD5_HASH); memset(position, 0, NUM_BYTES_MD5_HASH);
position += NUM_BYTES_MD5_HASH; position += NUM_BYTES_MD5_HASH;
} }
if (!SEQUENCE_NUMBERED_PACKETS.contains(packetType)) {
// Pack zeros for the number of bytes that the sequence number requires.
// The LimitedNodeList will handle packing in the sequence number when sending out the packet.
memset(position, 0, sizeof(PacketSequenceNumber));
position += sizeof(PacketSequenceNumber);
}
// return the number of bytes written for pointer pushing // return the number of bytes written for pointer pushing
return position - packet; return position - packet;
} }
int numBytesForPacketHeader(const QByteArray& packet) { int numBytesForPacketHeader(const QByteArray& packet) {
// returns the number of bytes used for the type, version, and UUID PacketType packetType = packetTypeForPacket(packet);
return numBytesArithmeticCodingFromBuffer(packet.data()) return numBytesForPacketHeaderGivenPacketType(packetType);
+ numHashBytesInPacketHeaderGivenPacketType(packetTypeForPacket(packet))
+ NUM_STATIC_HEADER_BYTES;
} }
int numBytesForPacketHeader(const char* packet) { int numBytesForPacketHeader(const char* packet) {
// returns the number of bytes used for the type, version, and UUID PacketType packetType = packetTypeForPacket(packet);
return numBytesArithmeticCodingFromBuffer(packet) return numBytesForPacketHeaderGivenPacketType(packetType);
+ numHashBytesInPacketHeaderGivenPacketType(packetTypeForPacket(packet)) }
int numBytesForPacketHeaderGivenPacketType(PacketType packetType) {
return (int) ceilf((float) packetType / 255)
+ numHashBytesForType(packetType)
+ numSequenceNumberBytesForType(packetType)
+ NUM_STATIC_HEADER_BYTES; + NUM_STATIC_HEADER_BYTES;
} }
int numBytesForPacketHeaderGivenPacketType(PacketType type) { int numHashBytesForType(PacketType packetType) {
return (int) ceilf((float)type / 255) return (NON_VERIFIED_PACKETS.contains(packetType) ? 0 : NUM_BYTES_MD5_HASH);
+ numHashBytesInPacketHeaderGivenPacketType(type)
+ NUM_STATIC_HEADER_BYTES;
} }
int numHashBytesInPacketHeaderGivenPacketType(PacketType type) { int numSequenceNumberBytesForType(PacketType packetType) {
return (NON_VERIFIED_PACKETS.contains(type) ? 0 : NUM_BYTES_MD5_HASH); return (SEQUENCE_NUMBERED_PACKETS.contains(packetType) ? sizeof(PacketSequenceNumber) : 0);
} }
QUuid uuidFromPacketHeader(const QByteArray& packet) { QUuid uuidFromPacketHeader(const QByteArray& packet) {
@ -204,8 +212,18 @@ QUuid uuidFromPacketHeader(const QByteArray& packet) {
NUM_BYTES_RFC4122_UUID)); NUM_BYTES_RFC4122_UUID));
} }
int hashOffsetForPacketType(PacketType packetType) {
return numBytesForPacketHeaderGivenPacketType(packetType)
- (SEQUENCE_NUMBERED_PACKETS.contains(packetType) ? sizeof(PacketSequenceNumber) : 0)
- NUM_BYTES_RFC4122_UUID;
}
int sequenceNumberOffsetForPacketType(PacketType packetType) {
return numBytesForPacketHeaderGivenPacketType(packetType) - sizeof(PacketSequenceNumber);
}
QByteArray hashFromPacketHeader(const QByteArray& packet) { QByteArray hashFromPacketHeader(const QByteArray& packet) {
return packet.mid(numBytesForPacketHeader(packet) - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH); return packet.mid(hashOffsetForPacketType(packetTypeForPacket(packet)), NUM_BYTES_MD5_HASH);
} }
QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID) { QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID) {
@ -213,9 +231,25 @@ QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid&
QCryptographicHash::Md5); QCryptographicHash::Md5);
} }
void replaceHashInPacketGivenConnectionUUID(QByteArray& packet, const QUuid& connectionUUID) { void replaceHashInPacketGivenType(QByteArray& packet, PacketType packetType, const QUuid& connectionUUID) {
packet.replace(numBytesForPacketHeader(packet) - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH, packet.replace(hashOffsetForPacketType(packetType), NUM_BYTES_MD5_HASH,
hashForPacketAndConnectionUUID(packet, connectionUUID)); hashForPacketAndConnectionUUID(packet, connectionUUID));
}
void replaceSequenceNumberInPacketGivenType(QByteArray& packet, PacketType packetType, PacketSequenceNumber sequenceNumber) {
packet.replace(sequenceNumberOffsetForPacketType(packetType),
sizeof(PacketTypeSequenceMap), reinterpret_cast<char*>(&sequenceNumber));
}
void replaceHashAndSequenceNumberInPacketGivenType(QByteArray& packet, PacketType packetType,
const QUuid& connectionUUID, PacketSequenceNumber sequenceNumber) {
replaceHashInPacketGivenType(packet, packetType, connectionUUID);
replaceSequenceNumberInPacketGivenType(packet, packetType, sequenceNumber);
}
void replaceHashAndSequenceNumberInPacket(QByteArray& packet, const QUuid& connectionUUID, PacketSequenceNumber sequenceNumber) {
replaceHashAndSequenceNumberInPacketGivenType(packet, packetTypeForPacket(packet), connectionUUID, sequenceNumber);
} }
PacketType packetTypeForPacket(const QByteArray& packet) { PacketType packetTypeForPacket(const QByteArray& packet) {

View file

@ -12,14 +12,16 @@
#ifndef hifi_PacketHeaders_h #ifndef hifi_PacketHeaders_h
#define hifi_PacketHeaders_h #define hifi_PacketHeaders_h
#include <cstdint>
#include <QtCore/QCryptographicHash> #include <QtCore/QCryptographicHash>
#include <QtCore/QSet> #include <QtCore/QSet>
#include <QtCore/QUuid> #include <QtCore/QUuid>
#include "UUID.h" #include "UUID.h"
// NOTE: if adding a new packet type, you can replace one marked usable or add at the end // NOTE: if adding a new packet packetType, you can replace one marked usable or add at the end
// NOTE: if you want the name of the packet type to be available for debugging or logging, update nameForPacketType() as well // NOTE: if you want the name of the packet packetType to be available for debugging or logging, update nameForPacketType() as well
enum PacketType { enum PacketType {
PacketTypeUnknown, // 0 PacketTypeUnknown, // 0
PacketTypeStunResponse, PacketTypeStunResponse,
@ -78,6 +80,7 @@ enum PacketType {
}; };
typedef char PacketVersion; typedef char PacketVersion;
typedef uint16_t PacketSequenceNumber;
const QSet<PacketType> NON_VERIFIED_PACKETS = QSet<PacketType>() const QSet<PacketType> NON_VERIFIED_PACKETS = QSet<PacketType>()
<< PacketTypeDomainServerRequireDTLS << PacketTypeDomainConnectRequest << PacketTypeDomainServerRequireDTLS << PacketTypeDomainConnectRequest
@ -88,33 +91,51 @@ const QSet<PacketType> NON_VERIFIED_PACKETS = QSet<PacketType>()
<< PacketTypeIceServerHeartbeat << PacketTypeIceServerHeartbeatResponse << PacketTypeIceServerHeartbeat << PacketTypeIceServerHeartbeatResponse
<< PacketTypeUnverifiedPing << PacketTypeUnverifiedPingReply << PacketTypeStopNode; << PacketTypeUnverifiedPing << PacketTypeUnverifiedPingReply << PacketTypeStopNode;
const QSet<PacketType> SEQUENCE_NUMBERED_PACKETS = QSet<PacketType>()
<< PacketTypeAvatarData;
const int NUM_BYTES_MD5_HASH = 16; const int NUM_BYTES_MD5_HASH = 16;
const int NUM_STATIC_HEADER_BYTES = sizeof(PacketVersion) + NUM_BYTES_RFC4122_UUID; const int NUM_STATIC_HEADER_BYTES = sizeof(PacketVersion) + NUM_BYTES_RFC4122_UUID;
const int MAX_PACKET_HEADER_BYTES = sizeof(PacketType) + NUM_BYTES_MD5_HASH + NUM_STATIC_HEADER_BYTES; const int MAX_PACKET_HEADER_BYTES = sizeof(PacketType) + NUM_BYTES_MD5_HASH + NUM_STATIC_HEADER_BYTES;
PacketVersion versionForPacketType(PacketType type); PacketType packetTypeForPacket(const QByteArray& packet);
QString nameForPacketType(PacketType type); PacketType packetTypeForPacket(const char* packet);
PacketVersion versionForPacketType(PacketType packetType);
QString nameForPacketType(PacketType packetType);
const QUuid nullUUID = QUuid(); const QUuid nullUUID = QUuid();
QByteArray byteArrayWithPopulatedHeader(PacketType type, const QUuid& connectionUUID = nullUUID); QByteArray byteArrayWithPopulatedHeader(PacketType packetType, const QUuid& connectionUUID = nullUUID);
int populatePacketHeader(QByteArray& packet, PacketType type, const QUuid& connectionUUID = nullUUID); int populatePacketHeader(QByteArray& packet, PacketType packetType, const QUuid& connectionUUID = nullUUID);
int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionUUID = nullUUID); int populatePacketHeader(char* packet, PacketType packetType, const QUuid& connectionUUID = nullUUID);
int numHashBytesInPacketHeaderGivenPacketType(PacketType type); int numHashBytesForType(PacketType packetType);
int numSequenceNumberBytesForType(PacketType packetType);
int numBytesForPacketHeader(const QByteArray& packet); int numBytesForPacketHeader(const QByteArray& packet);
int numBytesForPacketHeader(const char* packet); int numBytesForPacketHeader(const char* packet);
int numBytesForPacketHeaderGivenPacketType(PacketType type); int numBytesForPacketHeaderGivenPacketType(PacketType packetType);
QUuid uuidFromPacketHeader(const QByteArray& packet); QUuid uuidFromPacketHeader(const QByteArray& packet);
int hashOffsetForPacketType(PacketType packetType);
int sequenceNumberOffsetForPacketType(PacketType packetType);
QByteArray hashFromPacketHeader(const QByteArray& packet); QByteArray hashFromPacketHeader(const QByteArray& packet);
QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID); QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID);
void replaceHashInPacketGivenConnectionUUID(QByteArray& packet, const QUuid& connectionUUID);
PacketType packetTypeForPacket(const QByteArray& packet); void replaceHashInPacketGivenType(QByteArray& packet, PacketType packetType, const QUuid& connectionUUID);
PacketType packetTypeForPacket(const char* packet); void replaceHashInPacket(QByteArray& packet, const QUuid& connectionUUID)
{ replaceHashInPacketGivenType(packet, packetTypeForPacket(packet), connectionUUID); }
void replaceSequenceNumberInPacketGivenType(QByteArray& packet, PacketType packetType, PacketSequenceNumber sequenceNumber);
void replaceSequenceNumberInPacket(QByteArray& packet, PacketSequenceNumber sequenceNumber)
{ replaceSequenceNumberInPacketGivenType(packet, packetTypeForPacket(packet), sequenceNumber); }
void replaceHashAndSequenceNumberInPacketGivenType(QByteArray& packet, PacketType packetType,
const QUuid& connectionUUID, PacketSequenceNumber sequenceNumber);
void replaceHashAndSequenceNumberInPacket(QByteArray& packet, const QUuid& connectionUUID, PacketSequenceNumber sequenceNumber);
int arithmeticCodingValueFromBuffer(const char* checkValue); int arithmeticCodingValueFromBuffer(const char* checkValue);
int numBytesArithmeticCodingFromBuffer(const char* checkValue); int numBytesArithmeticCodingFromBuffer(const char* checkValue);