Point nearest-neighbour search class

A common task in game code is finding which of a stored set of points is nearest to another arbitrary point. For example, you might want to find which waypoint or monster generator is nearest to the player’s current position. However, the search is time-consuming when there are a lot of stored points and the search is just looking through all of them to find the nearest.

I’ve had a go at implementing a kd-tree, a data structure that speeds up the nearest-neighbour search dramatically (see the attached script file). It is currently a very basic implementation, but I’d be happy to hear suggestions for improvements of functionality or efficiency.

194882–6950–$kdtree_166.cs (5.34 KB)

6 Likes

Hey Andeeee,

Thanks for posting this code. I’d like to use it to learn more about KDTrees as well as Unity.

Could you give me some basic instructions on how to use your code? I tried dropping the script on an empty GO, but I just get the error message: “Can’t add script behaviour KDTree. The script needs to derive from MonoBehaviour!”

Any advice?

Thanks.

Mitch

Hi Andeeee,

I was wondering if your script could be used to find the nearest player to the current position of the ball as it moves around the field? I thought it might be used to auto-select the closest player.

iByte

I made some progress. I don’t know C# at all, and I barely know UnityScript, but I did manage to convert the KDTree from C# to UnityScript. It runs without errors and does seem to compute a KDTree from an array of points.

But when I perform a FindNearest call on a point, it always returns an index of -1. It never finds the nearest point.

I’m sure I screwed up the code somewhat when I took a sledgehammer to it and converted to unityscript. Here’s the code - if anyone has any advice for me I’d appreciate it.

I put both of these scripts on an empty GO:

KDTree.js

//	Child subtrees. KDTree() allocates two slots; the build stores the subtree for the
//	indices before the pivot in slot 0 and the subtree for the indices after it in slot 1.
var lr : KDTree[];
//	Position of the point stored at this node.
var pivot : Vector3;
//	Index of the pivot point in the array passed to MakeFromPoints.
var pivotIndex : int;
//	Splitting axis of this node; the build sets it to depth % numDims.
var axis : int;
	
//	Change this value to 2 if you only need two-dimensional X,Y points. The search will
//	be quicker in two dimensions.
var numDims : int = 3;

//	Allocate the two-slot child array. MakeFromPointsInner calls this explicitly
//	before it fills in a node.
function KDTree()
{
	lr = new KDTree[2];
}

//	Build a tree over the supplied points. The build works on a separate array of
//	indices into points, which starts out as 0 .. points.length - 1.
function MakeFromPoints(points : Vector3[]) : KDTree 
{
	var count : int = points.length;
	var order : int[] = Iota(count);
	var lastIndex : int = count - 1;

	return MakeFromPointsInner(0, 0, lastIndex, points, order);
}

//	Recursively build a tree by separating points at plane boundaries.
//
//	depth   - recursion depth; the splitting axis is depth % numDims.
//	stIndex - first position in inds covered by this call (inclusive).
//	enIndex - last position in inds covered by this call (inclusive).
//	points  - the full point array; it is only read here.
//	inds    - indices into points; FindPivotIndex reorders the stIndex..enIndex range.
//
//	Returns the node built for the range.
//
//	NOTE(review): root is an alias for this, so every recursive call re-initialises and
//	returns the same object, and the child slots end up referring back to this node
//	instead of to separate subtree nodes. A distinct node object per call looks
//	necessary here - confirm against the original C# version.
function MakeFromPointsInner(depth : int, stIndex : int, enIndex : int, points : Vector3[], inds : int[]) : KDTree
{
	var root : KDTree = this;
	//	Reset the child array before filling in this node.
	root.KDTree();
	root.axis = depth % numDims;
	//	Partition the range around a pivot; splitPoint is the pivot's final position in inds.
	var splitPoint : int = FindPivotIndex(points, inds, stIndex, enIndex, root.axis);

	root.pivotIndex = inds[splitPoint];
	root.pivot = points[root.pivotIndex];
		
	var leftEndIndex : int = splitPoint - 1;
			
	//	Only recurse when at least one index lies before the pivot.
	if (leftEndIndex >= stIndex) 
	{
		root.lr[0] = MakeFromPointsInner(depth + 1, stIndex, leftEndIndex, points, inds);
	}
		
	var rightStartIndex : int = splitPoint + 1;
		
	//	Only recurse when at least one index lies after the pivot.
	if (rightStartIndex <= enIndex) 
	{
		root.lr[1] = MakeFromPointsInner(depth + 1, rightStartIndex, enIndex, points, inds);
	}
		
	return root;
}

//	Find a new pivot index from the range by splitting the points that fall either side
//	of its plane.
//
//	Reorders inds[stIndex..enIndex] in place so that indices of points whose coordinate
//	on the given axis is less than or equal to the pivot's come before the pivot, and
//	indices of points with a greater coordinate come after it.
//
//	Returns the position in inds where the pivot index ends up.
function FindPivotIndex(points : Vector3[], inds : int[], stIndex : int, enIndex : int, axis : int) : int
{
	var splitPoint : int = FindSplitPoint(points, inds, stIndex, enIndex, axis);
	// int splitPoint = Random.Range(stIndex, enIndex);

	//	This local shadows the pivot member variable for the rest of the function.
	var pivot : Vector3 = points[inds[splitPoint]];
	//	Park the pivot index at the start of the range.
	SwapElements(inds, stIndex, splitPoint);

	var currPt : int = stIndex + 1;
	var endPt : int = enIndex;
		
	//	Throughout the loop the pivot index sits at position currPt - 1.
	while (currPt <= endPt) 
	{
		var curr : Vector3 = points[inds[currPt]];
			
		if ((curr[axis] > pivot[axis])) 
		{
			//	Greater than the pivot: move it to the shrinking tail of the range.
			SwapElements(inds, currPt, endPt);
			endPt--;
		} 
		else 
		{
			//	Not greater: swap it with the pivot, which shifts the pivot one slot right.
			SwapElements(inds, currPt - 1, currPt);
			currPt++;
		}
	}
		
	return currPt - 1;
}

//	Exchange the values stored at positions a and b of arr, in place.
function SwapElements(arr : int[], a : int, b : int) : void
{
	var held : int = arr[b];
	arr[b] = arr[a];
	arr[a] = held;
}
	

//	"Median of three" heuristic for choosing a splitting plane. Compares the coordinate
//	on the given axis of the points referenced by the first, middle and last positions
//	of the range and returns the position (stIndex, enIndex or the midpoint) that holds
//	the middle of those three values.
function FindSplitPoint(points : Vector3[], inds : int[], stIndex : int, enIndex : int, axis : int) : int 
{
	var midIndex : int = (stIndex + enIndex) / 2;

	var first : float = points[inds[stIndex]][axis];
	var last : float = points[inds[enIndex]][axis];
	var middle : float = points[inds[midIndex]][axis];

	//	Which end holds the larger value decides the direction of the remaining tests.
	var firstAboveLast : boolean = first > last;

	//	The first value is the median when it lies strictly between the other two.
	var pickFirst : boolean = firstAboveLast ? (middle > first) : (first > middle);
	if (pickFirst) 
	{
		return stIndex;
	}

	//	Likewise for the last value.
	var pickLast : boolean = firstAboveLast ? (last > middle) : (middle > last);
	if (pickLast) 
	{
		return enIndex;
	}

	return midIndex;
}

//	Produce the sequence 0, 1, ..., num - 1 as an int array of length num.
function Iota(num : int) : int[] 
{
	var sequence : int[] = new int[num];
	var next : int = 0;

	while (next < num) 
	{
		sequence[next] = next;
		next++;
	}

	return sequence;
}

//	Find the nearest point in the set to the supplied point.
//
//	Returns the index of that point in the array the tree was built from, or -1 if no
//	stored point improved on the initial infinite distance.
function FindNearest(pt : Vector3) : int 
{
	//	float and int arguments are passed by value, so a callee cannot update the
	//	caller's locals. The running best distance and index therefore live in
	//	one-element arrays that the recursive search writes into.
	var bestSqDist : float[] = new float[1];
	var bestIndex : int[] = new int[1];
	bestSqDist[0] = Mathf.Infinity;
	bestIndex[0] = -1;
		
	SearchInner(pt, bestSqDist, bestIndex);
		
	return bestIndex[0];
}
	

//	Legacy entry point, kept so that existing callers still compile. Its float and int
//	parameters are copies, so it cannot hand a result back; use FindNearest instead.
function Search(pt : Vector3, bestSqSoFar : float, bestIndex : int) : void
{
	var sqHolder : float[] = new float[1];
	var indexHolder : int[] = new int[1];
	sqHolder[0] = bestSqSoFar;
	indexHolder[0] = bestIndex;

	SearchInner(pt, sqHolder, indexHolder);
}


//	Recursively search the tree.
//
//	pt          - the query point.
//	bestSqSoFar - one-element array holding the smallest squared distance found so far;
//	              updated in place.
//	bestIndex   - one-element array holding the index of the point at that distance;
//	              updated in place.
function SearchInner(pt : Vector3, bestSqSoFar : float[], bestIndex : int[]) : void
{
	var mySqDist : float = (pivot - pt).sqrMagnitude;
		
	if (mySqDist < bestSqSoFar[0]) 
	{
		bestSqSoFar[0] = mySqDist;
		bestIndex[0] = pivotIndex;
	}

	//	Signed distance from the query point to this node's splitting plane.
	var planeDist : float = pt[axis] - pivot[axis];
		
	//	Descend first into the side of the plane that contains the query point.
	var selector : int = planeDist <= 0 ? 0 : 1;
		
	if (lr[selector] != null) 
	{
		lr[selector].SearchInner(pt, bestSqSoFar, bestIndex);
	}
		
	selector = (selector + 1) % 2;
		
	var sqPlaneDist : float = planeDist * planeDist;

	//	The far side can only hold a closer point if the plane itself is nearer than
	//	the best match found so far.
	if ((lr[selector] != null) && (bestSqSoFar[0] > sqPlaneDist)) 
	{
		lr[selector].SearchInner(pt, bestSqSoFar, bestIndex);
	}
}
	

//	Signed distance from pt to the axis-aligned plane through planePt: the difference
//	of the two points' coordinates on the given axis.
function DistFromSplitPlane(pt : Vector3, planePt : Vector3, axis : int) : float
{
	var pointCoord : float = pt[axis];
	var planeCoord : float = planePt[axis];

	return pointCoord - planeCoord;
}

and WayPoints.js

//	Number of waypoints that Start() allocates.
var index = 10;
//	Each random coordinate is drawn from the range -scale .. scale.
var scale = 50.0;
//	The waypoint positions; allocated in Start() and drawn in OnDrawGizmos().
var wayPoints : Vector3[];

//	The KDTree component to build and query; it is not assigned anywhere in this script.
var kd : KDTree;

//	Fill wayPoints with random positions whose coordinates each lie in -scale .. scale,
//	build the kd-tree from them, and log the index of the waypoint nearest the origin.
function Start () 
{	
	wayPoints = new Vector3[index];
	
	//	Vector3 is a value type, so a for-in loop variable is only a copy of the array
	//	element and writes to it never reach the array. Assign each element by index.
	for (var i : int = 0; i < wayPoints.length; i++)
	{	
		var x : float = Random.Range(-scale, scale);
		var y : float = Random.Range(-scale, scale);
		var z : float = Random.Range(-scale, scale);
		wayPoints[i] = new Vector3(x, y, z);
	}
	
	kd.MakeFromPoints(wayPoints);
	Debug.Log(kd.FindNearest(Vector3.zero));
}

//	Draw a white sphere of radius 0.5 at every waypoint.
function OnDrawGizmos()
{
	Gizmos.color = Color.white;

	for (var i : int = 0; i < wayPoints.length; i++)
	{
		Gizmos.DrawSphere(wayPoints[i], .5);
	}
}

In the Inspector I connect the KDTree Component to the kd variable in Waypoints.js.

Thanks in advance for any bits of wisdom.

Mitch

iByte,

Not too long ago I was playing around with players chasing a soccer ball. You probably already know this, but the closest player to the ball is not always the best player to go get the ball.

It’s the player that can intercept the ball in the fastest time. And that’s a tricky task to compute. Not only do you have to factor in the velocity of the ball and the players, you also need to take into account player acceleration.

And then there’s the issue of the ball’s trajectory when it takes the ball over the players’ heads.

I had quite a fun time playing with all that last year. I had some moderate success but quite a few times the wrong player would go get the ball so my math was not quite right.

Mitch

Sorry it’s taken me a bit of time to get onto this…

@iByte: The kd-tree isn’t suitable for finding the nearest player to the ball in a sports game. The reason is that there is an overhead in preprocessing the data to build the tree at first - this takes longer than a simple search through the points to check which is closest. The kd-tree is only suitable when the set of points is fixed.

@MitchStan: You don’t actually need to reimplement the code in JS unless you really want to. You can just use the class from JS - place the kdtree.cs file in the project’s standard assets folder and it will be accessible to a JS script.

You can’t attach the class to an object directly because, as noted, it doesn’t derive from MonoBehaviour. You should declare a variable of type KDTree in your code and initialise it with the MakeFromPoints static function:-

//	The kd-tree built from pointsArray.
var tree: KDTree;
//	The set of points to search.
var pointsArray: Vector3[];

function Start() {
    //	Build the tree once, up front, through the static MakeFromPoints function.
    tree = KDTree.MakeFromPoints(pointsArray);
}

Having done that, you can search for the point nearest any arbitrary target point (a player’s position, say) using the FindNearest function:-

//	nearest is an index into the points array the tree was built from.
var nearest: int = tree.FindNearest(targetPoint);

This function returns an integer, which is an index into the original points array.

Thanks Andeeee. Works like a charm.

I was having fun trying to convert it to JS but I couldn’t get it to work. I think I’ll try for a few more days - just to learn.

Again, thanks!

Mitch

Hi Andeee,

When I type in the script you posted it comes up with the error: The name KDTree does not denote a valid type.

I’ve placed the kdtree.cs in the assets folder so I dunno what’s wrong.

Try placing the KDTree script file in the Standard Assets folder. You will need to do this if you are accessing the class from JavaScript.

Hi, andeeeee, very impressive work. By the way, what would be the easiest way to add or delete individual elements in this tree? I am interested in applying this algorithm to my RTS game. However, as the game is dynamic (some warriors die, some are newly created), the Vector3 array size would always be changing, and I would like to know whether it would still work correctly.

Hi all. I have spent a lot of time searching for a way to find the nearest object among dynamically moving objects. I tried representing my map as zones, but I don’t like that idea, because it uses planes with triggers.
I have a lot of objects. They move all the time in different directions, and every frame each one should find the object nearest to it. My algorithm works well, but it has some bugs. I have learned a little about kd-trees, but how do they work with dynamic objects? Is it fast? Thanks, and sorry for my English.

I managed to apply this algorithm for now, but I am facing some difficulties with finding not the first nearest neighbour, but k-th nearest neighbour. I modified the code like that:

    // Search for the nearest stored point whose squared distance from pt is strictly
    // greater than sqrRmin.
    //
    // suitIndex - NOTE(review): passed by value, so the assignment at the end of this
    //             method is not visible to the caller and the found index is lost.
    //             Presumably this was meant to be an out/ref parameter or the return
    //             value - confirm the intent.
    // pt        - the query point.
    // sqrRmin   - in: squared distance that a candidate must exceed;
    //             out: squared distance of the point that was found, or 1000000000f if
    //             the search found nothing.
    public void FindNearestR(int suitIndex, Vector3 pt, ref float sqrRmin) {
        // Large sentinel standing in for "no candidate yet".
        float bestSqDist = 1000000000f;
        int bestIndex = -1;
       
       
       
        SearchR(pt, ref bestSqDist, ref bestIndex, ref sqrRmin);
       
    //    int i = bestIndex;
    //    int ii = i-1;
       
       
        sqrRmin = bestSqDist;
        // NOTE(review): bestIndex is already an index into the points array; it is
        // unclear why 1 is subtracted here - confirm.
        suitIndex = bestIndex-1;
    }
   

//    Recursively search the tree.
    // Recursively search for the nearest point whose squared distance from pt is
    // strictly greater than sqrRmin.
    //
    // pt          - the query point.
    // bestSqSoFar - smallest qualifying squared distance found so far; updated in place.
    // bestIndex   - index of the point at that distance; updated in place.
    // sqrRmin     - squared distance that a candidate must exceed. Points at exactly
    //               this distance are skipped, so ties with an earlier result are
    //               not returned.
    void SearchR(Vector3 pt, ref float bestSqSoFar, ref int bestIndex, ref float sqrRmin) {
        float mySqDist = (pivot - pt).sqrMagnitude;

        // The exclusion test only decides whether this node's own point may become the
        // best candidate. It must not stop the descent: a node inside the excluded
        // radius can still have qualifying points in its subtrees.
        if (mySqDist > sqrRmin && mySqDist < bestSqSoFar) {
            bestSqSoFar = mySqDist;
            bestIndex = pivotIndex;
        }

        // Signed distance from the query point to this node's splitting plane.
        float planeDist = pt[axis] - pivot[axis];

        // Descend first into the side of the plane that contains the query point.
        int selector = planeDist <= 0 ? 0 : 1;

        if (lr[selector] != null) {
            lr[selector].SearchR(pt, ref bestSqSoFar, ref bestIndex, ref sqrRmin);
        }

        selector = (selector + 1) % 2;

        float sqPlaneDist = planeDist * planeDist;

        // The far side can only hold a better candidate if the plane itself is nearer
        // than the best qualifying match found so far.
        if ((lr[selector] != null) && (bestSqSoFar > sqPlaneDist)) {
            lr[selector].SearchR(pt, ref bestSqSoFar, ref bestIndex, ref sqrRmin);
        }
    }

Here I used sqrRmin as minimum distance to allow NN to be used. For first NN sqrRmin=0, for 2nd NN it is becoming equal first NN sqrRmin and when function is called it should avoid counting first NN due to this condition. Unfortunately something wrong is going on with mySqDist > sqrRmin and I am always ending up with infinite loop. Does anyone see how it would be possible to fix this?

Great resource thank you. I am using it to place a collider on nearest tree as I have anywhere from 200k trees and up, however I am running into this issue when I put all of the trees into one array: maxVertices < 65536 && maxIndices < 65536*3

I checked to ensure my array could handle the amount of entities I was feeding it and that didn’t break until I put in about 10Mil Vector3’s, all I need is about 200-500k Vector3’s. If you can respond to the reasoning for this error and whether or not there may be a fix for it or if thats just the max it can do that would be great thank you.

Just want to point out that the code snippet quoted above is JavaScript (UnityScript) and not C#, for any newbie (like me) who comes by this thread. This bit below.

//	The kd-tree built from pointsArray.
var tree: KDTree;
//	The set of points to search.
var pointsArray: Vector3[];
function Start() {
    //	Build the tree once, up front, through the static MakeFromPoints function.
    tree = KDTree.MakeFromPoints(pointsArray);
}

Here I updated this class with additional functions, which allows to search for k-th nearest neighbour(s) or return distances to them instead of indices.

P.S. I checked upon Unity 5 and it allows to build very large trees (I tried up to 1.5 million Vector3’s and it’s working fine).

2042507–132561–KDTree.cs (8.28 KB)

1 Like

Hi,
Is it possible to find all points in a radius around the test point?

You can keep calling:

FindNearestK(Vector3 pt, int k)

function with larger k orders until the distance to the found point becomes larger than your critical radius. This should run with O(mn log n) performance, where m would be the number of points inside your radius. It goes towards O(N^2) if your cut-off radius gets close to the point where all dataset points are being added (so keep your cut-off radius reasonable).

I should add soon that function into the class when I will be looking there next time.

This is exactly a functionality I need at the moment. I guess the common usecase here is to find neighbors to certain points and establish a connection to each point in a radius.
A proper method for that would be great, but right now i am going to do it the way you described

I think it depends what exactly you are doing. If you need to find just the nearest one (i.e. find the closest tree for worker to be chopped in the forest to collect logs) or to get the environment properties (i.e. find how big is the sphere for 6 neighbours, calculate densities, etc.). It is good tip to use fixed number of points, as the loop will take always the same time to calculate rather than the fixed radius (i.e. in radius = 5 you can find sometimes 3 neighbours, while other time you can find 3000 neighbours, and if you get very large number of neighbours, the loops will become very heavy and your game could freeze).

I just tested it for a grid that is a 20x20 units area with one point at each point (x, y). so (0,0),(1,0),(2,0),…,(0,1),(0,2),…,(i,j),…(19,19).

When I now try to find the neighbors within a 1 unit radius i get this as a result. Am i doing something wrong here? I use node.position (e.g. (14,6)) as an input position for the nearest k search

UPDATE: I think i found the reason for this to happen. Basically the nearest k-th algorithm is not safe for an environment where two points have the exact same distance to a certain position.

        if ( mySqDist < bestSqSoFar )
        {
            if ( mySqDist > minSqDist )
            {
                bestSqSoFar = mySqDist;
                bestIndex = pivotIndex;
            }
        }

This part ignores the second point with the same distance to the point as “minSqDist” and therefore will not be noted. This means that you might miss proper points here and the result of the k search will be false, since there is one (or more) point(s) that are nearer to the given location than the returned point

UPDATE 2:
Fixed the issue with the following changes:

    void SearchK( Vector3 pt, HashSet<int> p_pivotIndexSet, ref float bestSqSoFar, ref float minSqDist, ref int bestIndex )
    {
        float mySqDist = ( pivot - pt ).sqrMagnitude;

        if ( mySqDist < bestSqSoFar )
        {
            if ( mySqDist >= minSqDist && !p_pivotIndexSet.Contains( pivotIndex ) )
            {
                bestSqSoFar = mySqDist;
                bestIndex = pivotIndex;
            }
        }
    // Find and return the index of the k-th nearest neighbour of pt.
    //
    // Runs one tree search per neighbour. After each of the first k - 1 searches the
    // index that was found is added to an exclusion set and its squared distance
    // becomes the lower bound for the next search. Returns -1 if the last search
    // finds no point.
    public int FindNearestK( Vector3 pt, int k )
    {
        // Large sentinel standing in for "no candidate yet".
        const float unbounded = 1000000000f;

        HashSet<int> excluded = new HashSet<int>();
        float floorSqDist = -1.0f;
        float foundSqDist = unbounded;
        int foundIndex = -1;

        int passesLeft = k - 1;
        while ( passesLeft > 0 )
        {
            SearchK( pt, excluded, ref foundSqDist, ref floorSqDist, ref foundIndex );
            excluded.Add( foundIndex );

            // The next search must find something no closer than this result, so
            // reset the running best before it starts.
            floorSqDist = foundSqDist;
            foundSqDist = unbounded;
            foundIndex = -1;

            passesLeft--;
        }

        // The final search yields the k-th neighbour itself.
        SearchK( pt, excluded, ref foundSqDist, ref floorSqDist, ref foundIndex );

        return foundIndex;
    }