enforce both a version and hash match for packets

This commit is contained in:
Stephen Birarda 2014-02-07 11:10:38 -08:00
parent 997bea708d
commit 66d4eeb805
10 changed files with 86 additions and 50 deletions

View file

@ -830,7 +830,7 @@ void AnimationServer::readPendingDatagrams() {
receivedPacket.resize(nodeList->getNodeSocket().pendingDatagramSize()); receivedPacket.resize(nodeList->getNodeSocket().pendingDatagramSize());
nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(), nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(),
nodeSockAddr.getAddressPointer(), nodeSockAddr.getPortPointer()); nodeSockAddr.getAddressPointer(), nodeSockAddr.getPortPointer());
if (packetVersionMatch(receivedPacket)) { if (nodeList->packetVersionAndHashMatch(receivedPacket)) {
if (packetTypeForPacket(receivedPacket) == PacketTypeJurisdiction) { if (packetTypeForPacket(receivedPacket) == PacketTypeJurisdiction) {
int headerBytes = numBytesForPacketHeader(receivedPacket); int headerBytes = numBytesForPacketHeader(receivedPacket);
// PacketType_JURISDICTION, first byte is the node type... // PacketType_JURISDICTION, first byte is the node type...

View file

@ -111,7 +111,7 @@ void AssignmentClient::readPendingDatagrams() {
nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(), nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(),
senderSockAddr.getAddressPointer(), senderSockAddr.getPortPointer()); senderSockAddr.getAddressPointer(), senderSockAddr.getPortPointer());
if (packetVersionMatch(receivedPacket)) { if (nodeList->packetVersionAndHashMatch(receivedPacket)) {
if (_currentAssignment) { if (_currentAssignment) {
// have the threaded current assignment handle this datagram // have the threaded current assignment handle this datagram
QMetaObject::invokeMethod(_currentAssignment, "processDatagram", Qt::QueuedConnection, QMetaObject::invokeMethod(_currentAssignment, "processDatagram", Qt::QueuedConnection,

View file

@ -66,7 +66,7 @@ void DataServer::readPendingDatagrams() {
PacketType requestType = packetTypeForPacket(receivedPacket); PacketType requestType = packetTypeForPacket(receivedPacket);
if ((requestType == PacketTypeDataServerPut || requestType == PacketTypeDataServerGet) && if ((requestType == PacketTypeDataServerPut || requestType == PacketTypeDataServerGet) &&
packetVersionMatch(receivedPacket)) { receivedPacket[numBytesArithmeticCodingFromBuffer(receivedPacket.data())] == versionForPacketType(requestType)) {
QDataStream packetStream(receivedPacket); QDataStream packetStream(receivedPacket);
int numReceivedHeaderBytes = numBytesForPacketHeader(receivedPacket); int numReceivedHeaderBytes = numBytesForPacketHeader(receivedPacket);

View file

@ -236,7 +236,7 @@ void DomainServer::readAvailableDatagrams() {
nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(), nodeList->getNodeSocket().readDatagram(receivedPacket.data(), receivedPacket.size(),
senderSockAddr.getAddressPointer(), senderSockAddr.getPortPointer()); senderSockAddr.getAddressPointer(), senderSockAddr.getPortPointer());
if (packetVersionMatch(receivedPacket)) { if (nodeList->packetVersionAndHashMatch(receivedPacket)) {
PacketType requestType = packetTypeForPacket(receivedPacket); PacketType requestType = packetTypeForPacket(receivedPacket);
if (requestType == PacketTypeDomainListRequest) { if (requestType == PacketTypeDomainListRequest) {

View file

@ -40,7 +40,7 @@ void DatagramProcessor::processDatagrams() {
_packetCount++; _packetCount++;
_byteCount += incomingPacket.size(); _byteCount += incomingPacket.size();
if (packetVersionMatch(incomingPacket)) { if (nodeList->packetVersionAndHashMatch(incomingPacket)) {
// only process this packet if we have a match on the packet version // only process this packet if we have a match on the packet version
switch (packetTypeForPacket(incomingPacket)) { switch (packetTypeForPacket(incomingPacket)) {
case PacketTypeTransmitterData: case PacketTypeTransmitterData:

View file

@ -48,7 +48,7 @@ void VoxelPacketProcessor::processPacket(const SharedNodePointer& sendingNode, c
wasStatsPacket = true; wasStatsPacket = true;
if (messageLength > statsMessageLength) { if (messageLength > statsMessageLength) {
mutablePacket = mutablePacket.mid(statsMessageLength); mutablePacket = mutablePacket.mid(statsMessageLength);
if (!packetVersionMatch(packet)) { if (!NodeList::getInstance()->packetVersionAndHashMatch(packet)) {
return; // bail since piggyback data doesn't match our versioning return; // bail since piggyback data doesn't match our versioning
} }
} else { } else {

View file

@ -80,17 +80,43 @@ NodeList::~NodeList() {
clear(); clear();
} }
bool NodeList::packetVersionAndHashMatch(const QByteArray& packet) {
// currently this just checks if the version in the packet matches our return from versionForPacketType
// may need to be expanded in the future for types and versions that take > than 1 byte
if (packet[1] != versionForPacketType(packetTypeForPacket(packet))
&& packetTypeForPacket(packet) != PacketTypeStunResponse) {
PacketType mismatchType = packetTypeForPacket(packet);
int numPacketTypeBytes = arithmeticCodingValueFromBuffer(packet.data());
qDebug() << "Packet version mismatch on" << packetTypeForPacket(packet) << "- Sender"
<< uuidFromPacketHeader(packet) << "sent" << qPrintable(QString::number(packet[numPacketTypeBytes])) << "but"
<< qPrintable(QString::number(versionForPacketType(mismatchType))) << "expected.";
}
if (packetTypeForPacket(packet) != PacketTypeDomainList && packetTypeForPacket(packet) != PacketTypeDomainListRequest) {
// figure out which node this is from
SharedNodePointer sendingNode = sendingNodeForPacket(packet);
if (sendingNode) {
// check if the md5 hash in the header matches the hash we would expect
if (hashFromPacketHeader(packet) == hashForPacketAndConnectionUUID(packet, sendingNode->getConnectionSecret())) {
return true;
} else {
qDebug() << "Packet hash mismatch" << packetTypeForPacket(packet) << "received from known node with UUID"
<< uuidFromPacketHeader(packet);
}
} else {
qDebug() << "Packet of type" << packetTypeForPacket(packet) << "received from unknown node with UUID"
<< uuidFromPacketHeader(packet);
}
}
return false;
}
qint64 NodeList::writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode, qint64 NodeList::writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode,
const HifiSockAddr& overridenSockAddr) { const HifiSockAddr& overridenSockAddr) {
if (destinationNode) {
// setup the MD5 hash for source verification in the header
int numBytesPacketHeader = numBytesForPacketHeader(datagram);
QByteArray dataSecretHash = QCryptographicHash::hash(datagram.mid(numBytesPacketHeader)
+ destinationNode->getConnectionSecret().toRfc4122(),
QCryptographicHash::Md5);
QByteArray datagramWithHash = datagram;
datagramWithHash.replace(numBytesPacketHeader - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH, dataSecretHash);
// if we don't have an ovveriden address, assume they want to send to the node's active socket // if we don't have an ovveriden address, assume they want to send to the node's active socket
const HifiSockAddr* destinationSockAddr = &overridenSockAddr; const HifiSockAddr* destinationSockAddr = &overridenSockAddr;
if (overridenSockAddr.isNull()) { if (overridenSockAddr.isNull()) {
@ -103,7 +129,15 @@ qint64 NodeList::writeDatagram(const QByteArray& datagram, const SharedNodePoint
} }
} }
return _nodeSocket.writeDatagram(datagramWithHash, destinationSockAddr->getAddress(), destinationSockAddr->getPort()); QByteArray datagramCopy = datagram;
// setup the MD5 hash for source verification in the header
replaceHashInPacketGivenConnectionUUID(datagramCopy, destinationNode->getConnectionSecret());
return _nodeSocket.writeDatagram(datagramCopy, destinationSockAddr->getAddress(), destinationSockAddr->getPort());
}
// didn't have a destinationNode to send to, return 0
return 0;
} }
qint64 NodeList::writeDatagram(const char* data, qint64 size, const SharedNodePointer& destinationNode, qint64 NodeList::writeDatagram(const char* data, qint64 size, const SharedNodePointer& destinationNode,
@ -191,8 +225,11 @@ void NodeList::processNodeData(const HifiSockAddr& senderSockAddr, const QByteAr
} }
case PacketTypePing: { case PacketTypePing: {
// send back a reply // send back a reply
if (sendingNodeForPacket(packet)) {
QByteArray replyPacket = constructPingReplyPacket(packet); QByteArray replyPacket = constructPingReplyPacket(packet);
writeDatagram(replyPacket, sendingNodeForPacket(packet), senderSockAddr); writeDatagram(replyPacket, sendingNodeForPacket(packet), senderSockAddr);
}
break; break;
} }
case PacketTypePingReply: { case PacketTypePingReply: {
@ -716,9 +753,9 @@ void NodeList::activateSocketFromNodeCommunication(const QByteArray& packet, con
// if this is a local or public ping then we can activate a socket // if this is a local or public ping then we can activate a socket
// we do nothing with agnostic pings, those are simply for timing // we do nothing with agnostic pings, those are simply for timing
if (pingType == PingType::Local) { if (pingType == PingType::Local && sendingNode->getActiveSocket() != &sendingNode->getLocalSocket()) {
sendingNode->activateLocalSocket(); sendingNode->activateLocalSocket();
} else if (pingType == PingType::Public) { } else if (pingType == PingType::Public && !sendingNode->getActiveSocket()) {
sendingNode->activatePublicSocket(); sendingNode->activatePublicSocket();
} }
} }

View file

@ -84,6 +84,8 @@ public:
QUdpSocket& getNodeSocket() { return _nodeSocket; } QUdpSocket& getNodeSocket() { return _nodeSocket; }
bool packetVersionAndHashMatch(const QByteArray& packet);
qint64 writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode, qint64 writeDatagram(const QByteArray& datagram, const SharedNodePointer& destinationNode,
const HifiSockAddr& overridenSockAddr = HifiSockAddr()); const HifiSockAddr& overridenSockAddr = HifiSockAddr());
qint64 writeDatagram(const char* data, qint64 size, const SharedNodePointer& destinationNode, qint64 writeDatagram(const char* data, qint64 size, const SharedNodePointer& destinationNode,

View file

@ -85,26 +85,6 @@ int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionU
return position - packet; return position - packet;
} }
bool packetVersionMatch(const QByteArray& packet) {
// currently this just checks if the version in the packet matches our return from versionForPacketType
// may need to be expanded in the future for types and versions that take > than 1 byte
if (packet[1] == versionForPacketType(packetTypeForPacket(packet)) || packetTypeForPacket(packet) == PacketTypeStunResponse) {
return true;
} else {
PacketType mismatchType = packetTypeForPacket(packet);
int numPacketTypeBytes = arithmeticCodingValueFromBuffer(packet.data());
QUuid nodeUUID = uuidFromPacketHeader(packet);
qDebug() << "Packet mismatch on" << packetTypeForPacket(packet) << "- Sender"
<< nodeUUID << "sent" << qPrintable(QString::number(packet[numPacketTypeBytes])) << "but"
<< qPrintable(QString::number(versionForPacketType(mismatchType))) << "expected.";
return false;
}
}
int numBytesForPacketHeader(const QByteArray& packet) { int numBytesForPacketHeader(const QByteArray& packet) {
// returns the number of bytes used for the type, version, and UUID // returns the number of bytes used for the type, version, and UUID
return numBytesArithmeticCodingFromBuffer(packet.data()) + NUM_STATIC_HEADER_BYTES; return numBytesArithmeticCodingFromBuffer(packet.data()) + NUM_STATIC_HEADER_BYTES;
@ -124,6 +104,20 @@ QUuid uuidFromPacketHeader(const QByteArray& packet) {
NUM_BYTES_RFC4122_UUID)); NUM_BYTES_RFC4122_UUID));
} }
QByteArray hashFromPacketHeader(const QByteArray& packet) {
return packet.mid(NUM_STATIC_HEADER_BYTES - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH);
}
QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID) {
return QCryptographicHash::hash(packet.mid(numBytesForPacketHeader(packet))
+ connectionUUID.toRfc4122(), QCryptographicHash::Md5);
}
void replaceHashInPacketGivenConnectionUUID(QByteArray& packet, const QUuid& connectionUUID) {
packet.replace(numBytesForPacketHeader(packet) - NUM_BYTES_MD5_HASH, NUM_BYTES_MD5_HASH,
hashForPacketAndConnectionUUID(packet, connectionUUID));
}
PacketType packetTypeForPacket(const QByteArray& packet) { PacketType packetTypeForPacket(const QByteArray& packet) {
return (PacketType) arithmeticCodingValueFromBuffer(packet.data()); return (PacketType) arithmeticCodingValueFromBuffer(packet.data());
} }

View file

@ -70,17 +70,20 @@ QByteArray byteArrayWithPopluatedHeader(PacketType type, const QUuid& connection
int populatePacketHeader(QByteArray& packet, PacketType type, const QUuid& connectionUUID = nullUUID); int populatePacketHeader(QByteArray& packet, PacketType type, const QUuid& connectionUUID = nullUUID);
int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionUUID = nullUUID); int populatePacketHeader(char* packet, PacketType type, const QUuid& connectionUUID = nullUUID);
bool packetVersionMatch(const QByteArray& packet);
int numBytesForPacketHeader(const QByteArray& packet); int numBytesForPacketHeader(const QByteArray& packet);
int numBytesForPacketHeader(const char* packet); int numBytesForPacketHeader(const char* packet);
int numBytesForPacketHeaderGivenPacketType(PacketType type); int numBytesForPacketHeaderGivenPacketType(PacketType type);
QUuid uuidFromPacketHeader(const QByteArray& packet); QUuid uuidFromPacketHeader(const QByteArray& packet);
QByteArray hashFromPacketHeader(const QByteArray& packet);
QByteArray hashForPacketAndConnectionUUID(const QByteArray& packet, const QUuid& connectionUUID);
void replaceHashInPacketGivenConnectionUUID(QByteArray& packet, const QUuid& connectionUUID);
PacketType packetTypeForPacket(const QByteArray& packet); PacketType packetTypeForPacket(const QByteArray& packet);
PacketType packetTypeForPacket(const char* packet); PacketType packetTypeForPacket(const char* packet);
int arithmeticCodingValueFromBuffer(const char* checkValue); int arithmeticCodingValueFromBuffer(const char* checkValue);
int numBytesArithmeticCodingFromBuffer(const char* checkValue);
#endif #endif