// Source: RosettaCodeData/Task/Closest-pair-problem/Java/closest-pair-problem.java
//
// 187 lines
// 5.2 KiB
// Java

import java.util.*;
/**
 * Finds the two closest points in a set of 2-D points.
 *
 * <p>Two searches are provided: {@link #bruteForce(List)}, which compares every
 * pair, and {@link #divideAndConquer(List)}, which splits the points by x,
 * solves each half recursively and then checks a strip around the dividing
 * line. {@link #main(String[])} runs both on random points and compares them.
 */
public class ClosestPair
{
    /** Orders points by ascending x using {@link Double#compare(double, double)}. */
    private static final Comparator<Point> BY_X = new Comparator<Point>()
    {
        @Override
        public int compare(Point point1, Point point2)
        {
            return Double.compare(point1.x, point2.x);
        }
    };

    /** Orders points by ascending y using {@link Double#compare(double, double)}. */
    private static final Comparator<Point> BY_Y = new Comparator<Point>()
    {
        @Override
        public int compare(Point point1, Point point2)
        {
            return Double.compare(point1.y, point2.y);
        }
    };

    /** An immutable point in the plane. */
    public static class Point
    {
        public final double x;
        public final double y;

        /**
         * @param x the x coordinate
         * @param y the y coordinate
         */
        public Point(double x, double y)
        {
            this.x = x;
            this.y = y;
        }

        /** @return the point formatted as {@code (x, y)} */
        @Override
        public String toString()
        {
            return "(" + x + ", " + y + ")";
        }
    }

    /** A mutable pair of points together with the distance between them. */
    public static class Pair
    {
        public Point point1 = null;
        public Point point2 = null;
        public double distance = 0.0;

        /** Creates an empty pair: both points {@code null}, distance {@code 0.0}. */
        public Pair()
        { }

        /**
         * Creates a pair and computes its distance.
         *
         * @param point1 the first point
         * @param point2 the second point
         */
        public Pair(Point point1, Point point2)
        {
            this.point1 = point1;
            this.point2 = point2;
            calcDistance();
        }

        /**
         * Replaces both points and the stored distance. The distance is taken
         * as given; it is not recomputed from the points.
         *
         * @param point1 the new first point
         * @param point2 the new second point
         * @param distance the distance to store
         */
        public void update(Point point1, Point point2, double distance)
        {
            this.point1 = point1;
            this.point2 = point2;
            this.distance = distance;
        }

        /** Recomputes {@link #distance} from the two current points. */
        public void calcDistance()
        {
            this.distance = distance(point1, point2);
        }

        /** @return the pair formatted as {@code point1-point2 : distance} */
        @Override
        public String toString()
        {
            return point1 + "-" + point2 + " : " + distance;
        }
    }

    /**
     * Computes the Euclidean distance between two points.
     *
     * @param p1 the first point
     * @param p2 the second point
     * @return the distance between {@code p1} and {@code p2}
     */
    public static double distance(Point p1, Point p2)
    {
        double xdist = p2.x - p1.x;
        double ydist = p2.y - p1.y;
        return Math.hypot(xdist, ydist);
    }

    /**
     * Finds the closest pair by comparing every pair of points.
     *
     * <p>When several pairs share the minimum distance, the first one
     * encountered is kept, because only a strictly smaller distance replaces
     * the current best.
     *
     * @param points the points to search; not modified
     * @return the closest pair, or {@code null} if there are fewer than two points
     */
    public static Pair bruteForce(List<? extends Point> points)
    {
        int numPoints = points.size();
        if (numPoints < 2)
        {
            return null;
        }
        // Seed with the first two points; the loops below revisit that pair,
        // which is harmless because an equal distance never triggers an update.
        Pair pair = new Pair(points.get(0), points.get(1));
        for (int i = 0; i < numPoints - 1; i++)
        {
            Point point1 = points.get(i);
            for (int j = i + 1; j < numPoints; j++)
            {
                Point point2 = points.get(j);
                double distance = distance(point1, point2);
                if (distance < pair.distance)
                {
                    pair.update(point1, point2, distance);
                }
            }
        }
        return pair;
    }

    /**
     * Sorts the list in place by ascending x coordinate.
     *
     * <p>Coordinates are ordered with {@link Double#compare(double, double)},
     * so NaN sorts after every other value and the ordering is a consistent
     * total order.
     *
     * @param points the list to sort
     */
    public static void sortByX(List<? extends Point> points)
    {
        Collections.sort(points, BY_X);
    }

    /**
     * Sorts the list in place by ascending y coordinate.
     *
     * <p>Coordinates are ordered with {@link Double#compare(double, double)},
     * so NaN sorts after every other value and the ordering is a consistent
     * total order.
     *
     * @param points the list to sort
     */
    public static void sortByY(List<? extends Point> points)
    {
        Collections.sort(points, BY_Y);
    }

    /**
     * Finds the closest pair using divide and conquer.
     *
     * <p>The input is copied before sorting, so the caller's list keeps its order.
     *
     * @param points the points to search; not modified
     * @return the closest pair, or {@code null} if there are fewer than two points
     */
    public static Pair divideAndConquer(List<? extends Point> points)
    {
        List<Point> pointsSortedByX = new ArrayList<Point>(points);
        sortByX(pointsSortedByX);
        List<Point> pointsSortedByY = new ArrayList<Point>(points);
        sortByY(pointsSortedByY);
        return divideAndConquer(pointsSortedByX, pointsSortedByY);
    }

    /**
     * Recursive step of the divide-and-conquer search.
     *
     * <p>Each half is copied and re-sorted by y before recursing; the y-sorted
     * list passed in is used only to build the strip around the dividing line.
     *
     * @param pointsSortedByX the points in ascending x order
     * @param pointsSortedByY the same points in ascending y order
     * @return the closest pair, or {@code null} if there are fewer than two points
     */
    private static Pair divideAndConquer(List<? extends Point> pointsSortedByX, List<? extends Point> pointsSortedByY)
    {
        int numPoints = pointsSortedByX.size();
        if (numPoints <= 3)
        {
            return bruteForce(pointsSortedByX);
        }

        // With at least four points each half holds at least two, so neither
        // recursive call below can return null.
        int dividingIndex = numPoints >>> 1;
        List<? extends Point> leftOfCenter = pointsSortedByX.subList(0, dividingIndex);
        List<? extends Point> rightOfCenter = pointsSortedByX.subList(dividingIndex, numPoints);

        List<Point> tempList = new ArrayList<Point>(leftOfCenter);
        sortByY(tempList);
        Pair closestPair = divideAndConquer(leftOfCenter, tempList);

        tempList.clear();
        tempList.addAll(rightOfCenter);
        sortByY(tempList);
        Pair closestPairRight = divideAndConquer(rightOfCenter, tempList);

        if (closestPairRight.distance < closestPair.distance)
        {
            closestPair = closestPairRight;
        }

        // Collect, in y order, the points whose x lies within the current best
        // distance of the dividing line; only those can form a closer pair
        // that spans the two halves.
        tempList.clear();
        double shortestDistance = closestPair.distance;
        double centerX = rightOfCenter.get(0).x;
        for (Point point : pointsSortedByY)
        {
            if (Math.abs(centerX - point.x) < shortestDistance)
            {
                tempList.add(point);
            }
        }

        for (int i = 0; i < tempList.size() - 1; i++)
        {
            Point point1 = tempList.get(i);
            for (int j = i + 1; j < tempList.size(); j++)
            {
                Point point2 = tempList.get(j);
                // The strip is y-sorted, so once the vertical gap reaches the
                // best distance no later point can be closer to point1.
                if ((point2.y - point1.y) >= shortestDistance)
                {
                    break;
                }
                double distance = distance(point1, point2);
                if (distance < closestPair.distance)
                {
                    closestPair.update(point1, point2, distance);
                    shortestDistance = distance;
                }
            }
        }
        return closestPair;
    }

    /**
     * Generates random points in the unit square, runs both searches, prints
     * each result with its elapsed time, and prints {@code MISMATCH} if the two
     * results disagree.
     *
     * @param args optional; {@code args[0]} is the number of points (default 1000)
     * @throws NumberFormatException if {@code args[0]} is not a valid integer
     */
    public static void main(String[] args)
    {
        int numPoints = (args.length == 0) ? 1000 : Integer.parseInt(args[0]);
        List<Point> points = new ArrayList<Point>();
        Random r = new Random();
        for (int i = 0; i < numPoints; i++)
        {
            points.add(new Point(r.nextDouble(), r.nextDouble()));
        }
        System.out.println("Generated " + numPoints + " random points");

        // nanoTime is meant for measuring elapsed time; it is unaffected by
        // wall-clock adjustments.
        long startTime = System.nanoTime();
        Pair bruteForceClosestPair = bruteForce(points);
        long elapsedTime = (System.nanoTime() - startTime) / 1000000L;
        System.out.println("Brute force (" + elapsedTime + " ms): " + bruteForceClosestPair);

        startTime = System.nanoTime();
        Pair dqClosestPair = divideAndConquer(points);
        elapsedTime = (System.nanoTime() - startTime) / 1000000L;
        System.out.println("Divide and conquer (" + elapsedTime + " ms): " + dqClosestPair);

        if (bruteForceClosestPair == null || dqClosestPair == null)
        {
            // Fewer than two points: both searches return null, and they agree
            // only if both found no pair.
            if (bruteForceClosestPair != dqClosestPair)
            {
                System.out.println("MISMATCH");
            }
        }
        else if (bruteForceClosestPair.distance != dqClosestPair.distance)
        {
            System.out.println("MISMATCH");
        }
    }
}