enforce both a version and hash match for packets

This commit is contained in:
Stephen Birarda 2014-02-07 11:10:38 -08:00
parent 997bea708d
commit 66d4eeb805
10 changed files with 86 additions and 50 deletions

View file

@ -830,7 +830,7 @@ void AnimationServer::readPendingDatagrams() {
receivedPacket.resize(nodeList->getNodeSocket().pendingDatagramSize());
nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(),
nodeSockAddr.getAddressPointer(), nodeSockAddr.getPortPointer());
if (packetVersionMatch(receivedPacket)) {
if (nodeList->packetVersionAndHashMatch(receivedPacket)) {
if (packetTypeForPacket(receivedPacket) == PacketTypeJurisdiction) {
int headerBytes = numBytesForPacketHeader(receivedPacket);
// PacketType_JURISDICTION, first byte is the node type...

View file

@ -111,7 +111,7 @@ void AssignmentClient::readPendingDatagrams() {
nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(),
senderSockAddr.getAddressPointer(), senderSockAddr.getPortPointer());
if (packetVersionMatch(receivedPacket)) {
if (nodeList->packetVersionAndHashMatch(receivedPacket)) {
if (_currentAssignment) {
// have the threaded current assignment handle this datagram
QMetaObject::invokeMethod(_currentAssignment, "processDatagram", Qt::QueuedConnection,

View file

@ -66,7 +66,7 @@ void DataServer::readPendingDatagrams() {
PacketType requestType = packetTypeForPacket(receivedPacket);
if ((requestType == PacketTypeDataServerPut || requestType == PacketTypeDataServerGet) &&
packetVersionMatch(receivedPacket)) {
receivedPacket[numBytesArithmeticCodingFromBuffer(receivedPacket.data())] == versionForPacketType(requestType)) {
QDataStream packetStream(receivedPacket);
int numReceivedHeaderBytes = numBytesForPacketHeader(receivedPacket);

View file

@ -236,7 +236,7 @@ void DomainServer::readAvailableDatagrams() {
nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(),
senderSockAddr.getAddressPointer(), senderSockAddr.getPortPointer());
if (packetVersionMatch(receivedPacket)) {
if (nodeList->packetVersionAndHashMatch(receivedPacket)) {
PacketType requestType = packetTypeForPacket(receivedPacket);
if (requestType == PacketTypeDomainListRequest) {

View file

@ -40,7 +40,7 @@ void DatagramProcessor::processDatagrams() {
_packetCount++;
_byteCount += incomingPacket.size();
if (packetVersionMatch(incomingPacket)) {
if (nodeList->packetVersionAndHashMatch(incomingPacket)) {
// only process this packet if we have a match on the packet version
switch (packetTypeForPacket(incomingPacket)) {
case PacketTypeTransmitterData:

View file

@ -48,7 +48,7 @@ void VoxelPacketProcessor::processPacket(const SharedNodePointer& sendingNode, c
wasStatsPacket = true;
if (messageLength > statsMessageLength) {
mutablePacket = mutablePacket.mid(statsMessageLength);
if (!packetVersionMatch(packet)) {
if (!NodeList::getInstance()->packetVersionAndHashMatch(packet)) {
return; // bail since piggyback data doesn't match our versioning
}
} else {

View file

@ -80,30 +80,64 @@ NodeList::~NodeList() {
clear();
}
qint64 NodeList::writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode,
const HifiSockAddr& overridenSockAddr) {
bool NodeList::packetVersionAndHashMatch(const QByteArray& packet) {
// currently this just checks if the version in the packet matches our return from versionForPacketType
// may need to be expanded in the future for types and versions that take > than 1 byte
// setup the MD5 hash for source verification in the header
int numBytesPacketHeader = numBytesForPacketHeader(datagram);
QByteArray dataSecretHash = QCryptographicHash::hash(datagram.mid(numBytesPacketHeader)
+ destinationNode->getConnectionSecret().toRfc4122(),
QCryptographicHash::Md5);
QByteArray datagramWithHash = datagram;
datagramWithHash.replace(numBytesPacketHeader - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH, dataSecretHash);
if (packet[1] != versionForPacketType(packetTypeForPacket(packet))
&& packetTypeForPacket(packet) != PacketTypeStunResponse) {
PacketType mismatchType = packetTypeForPacket(packet);
int numPacketTypeBytes = arithmeticCodingValueFromBuffer(packet.data());
qDebug() << "Packet version mismatch on" << packetTypeForPacket(packet) << "- Sender"
<< uuidFromPacketHeader(packet) << "sent" << qPrintable(QString::number(packet[numPacketTypeBytes])) << "but"
<< qPrintable(QString::number(versionForPacketType(mismatchType))) << "expected.";
}
// if we don't have an ovveriden address, assume they want to send to the node's active socket
const HifiSockAddr* destinationSockAddr = &overridenSockAddr;
if (overridenSockAddr.isNull()) {
if (getNodeActiveSocketOrPing(destinationNode)) {
// use the node's active socket as the destination socket
destinationSockAddr = destinationNode->getActiveSocket();
if (packetTypeForPacket(packet) != PacketTypeDomainList && packetTypeForPacket(packet) != PacketTypeDomainListRequest) {
// figure out which node this is from
SharedNodePointer sendingNode = sendingNodeForPacket(packet);
if (sendingNode) {
// check if the md5 hash in the header matches the hash we would expect
if (hashFromPacketHeader(packet) == hashForPacketAndConnectionUUID(packet, sendingNode->getConnectionSecret())) {
return true;
} else {
qDebug() << "Packet hash mismatch" << packetTypeForPacket(packet) << "received from known node with UUID"
<< uuidFromPacketHeader(packet);
}
} else {
// we don't have a socket to send to, return 0
return 0;
qDebug() << "Packet of type" << packetTypeForPacket(packet) << "received from unknown node with UUID"
<< uuidFromPacketHeader(packet);
}
}
return _nodeSocket.writeDatagram(datagramWithHash, destinationSockAddr->getAddress(), destinationSockAddr->getPort());
return false;
}
qint64 NodeList::writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode,
const HifiSockAddr& overridenSockAddr) {
if (destinationNode) {
// if we don't have an ovveriden address, assume they want to send to the node's active socket
const HifiSockAddr* destinationSockAddr = &overridenSockAddr;
if (overridenSockAddr.isNull()) {
if (getNodeActiveSocketOrPing(destinationNode)) {
// use the node's active socket as the destination socket
destinationSockAddr = destinationNode->getActiveSocket();
} else {
// we don't have a socket to send to, return 0
return 0;
}
}
QByteArray datagramCopy = datagram;
// setup the MD5 hash for source verification in the header
replaceHashInPacketGivenConnectionUUID(datagramCopy, destinationNode->getConnectionSecret());
return _nodeSocket.writeDatagram(datagramCopy, destinationSockAddr->getAddress(), destinationSockAddr->getPort());
}
// didn't have a destinationNode to send to, return 0
return 0;
}
qint64 NodeList::writeDatagram(const char* data, qint64 size, const SharedNodePointer& destinationNode,
@ -191,8 +225,11 @@ void NodeList::processNodeData(const HifiSockAddr& senderSockAddr, const QByteAr
}
case PacketTypePing: {
// send back a reply
QByteArray replyPacket = constructPingReplyPacket(packet);
writeDatagram(replyPacket, sendingNodeForPacket(packet), senderSockAddr);
if (sendingNodeForPacket(packet)) {
QByteArray replyPacket = constructPingReplyPacket(packet);
writeDatagram(replyPacket, sendingNodeForPacket(packet), senderSockAddr);
}
break;
}
case PacketTypePingReply: {
@ -716,9 +753,9 @@ void NodeList::activateSocketFromNodeCommunication(const QByteArray& packet, con
// if this is a local or public ping then we can activate a socket
// we do nothing with agnostic pings, those are simply for timing
if (pingType == PingType::Local) {
if (pingType == PingType::Local && sendingNode->getActiveSocket() != &sendingNode->getLocalSocket()) {
sendingNode->activateLocalSocket();
} else if (pingType == PingType::Public) {
} else if (pingType == PingType::Public && !sendingNode->getActiveSocket()) {
sendingNode->activatePublicSocket();
}
}

View file

@ -84,6 +84,8 @@ public:
QUdpSocket& getNodeSocket() { return _nodeSocket; }
bool packetVersionAndHashMatch(const QByteArray& packet);
qint64 writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode,
const HifiSockAddr& overridenSockAddr = HifiSockAddr());
qint64 writeDatagram(const char* data, qint64 size, const SharedNodePointer& destinationNode,

View file

@ -85,26 +85,6 @@ int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionU
return position - packet;
}
bool packetVersionMatch(const QByteArray& packet) {
// currently this just checks if the version in the packet matches our return from versionForPacketType
// may need to be expanded in the future for types and versions that take > than 1 byte
if (packet[1] == versionForPacketType(packetTypeForPacket(packet)) || packetTypeForPacket(packet) == PacketTypeStunResponse) {
return true;
} else {
PacketType mismatchType = packetTypeForPacket(packet);
int numPacketTypeBytes = arithmeticCodingValueFromBuffer(packet.data());
QUuid nodeUUID = uuidFromPacketHeader(packet);
qDebug() << "Packet mismatch on" << packetTypeForPacket(packet) << "- Sender"
<< nodeUUID << "sent" << qPrintable(QString::number(packet[numPacketTypeBytes])) << "but"
<< qPrintable(QString::number(versionForPacketType(mismatchType))) << "expected.";
return false;
}
}
int numBytesForPacketHeader(const QByteArray& packet) {
// returns the number of bytes used for the type, version, and UUID
return numBytesArithmeticCodingFromBuffer(packet.data()) + NUM_STATIC_HEADER_BYTES;
@ -124,6 +104,20 @@ QUuid uuidFromPacketHeader(const QByteArray& packet) {
NUM_BYTES_RFC4122_UUID));
}
QByteArray hashFromPacketHeader(const QByteArray& packet) {
return packet.mid(NUM_STATIC_HEADER_BYTES - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH);
}
QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID) {
return QCryptographicHash::hash(packet.mid(numBytesForPacketHeader(packet))
+ connectionUUID.toRfc4122(), QCryptographicHash::Md5);
}
void replaceHashInPacketGivenConnectionUUID(QByteArray& packet, const QUuid& connectionUUID) {
packet.replace(numBytesForPacketHeader(packet) - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH,
hashForPacketAndConnectionUUID(packet, connectionUUID));
}
PacketType packetTypeForPacket(const QByteArray& packet) {
return (PacketType) arithmeticCodingValueFromBuffer(packet.data());
}

View file

@ -70,17 +70,20 @@ QByteArray byteArrayWithPopluatedHeader(PacketType type, const QUuid& connection
int populatePacketHeader(QByteArray& packet, PacketType type, const QUuid& connectionUUID = nullUUID);
int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionUUID = nullUUID);
bool packetVersionMatch(const QByteArray& packet);
int numBytesForPacketHeader(const QByteArray& packet);
int numBytesForPacketHeader(const char* packet);
int numBytesForPacketHeaderGivenPacketType(PacketType type);
QUuid uuidFromPacketHeader(const QByteArray& packet);
QByteArray hashFromPacketHeader(const QByteArray& packet);
QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID);
void replaceHashInPacketGivenConnectionUUID(QByteArray& packet, const QUuid& connectionUUID);
PacketType packetTypeForPacket(const QByteArray& packet);
PacketType packetTypeForPacket(const char* packet);
int arithmeticCodingValueFromBuffer(const char* checkValue);
int numBytesArithmeticCodingFromBuffer(const char* checkValue);
#endif