mirror of
https://github.com/qgis/QGIS.git
synced 2025-10-04 00:04:03 -04:00
add kmeans++
This commit is contained in:
parent
5500f43ebc
commit
9a4b4d388a
@ -17,6 +17,7 @@
|
||||
|
||||
#include "qgsalgorithmkmeansclustering.h"
|
||||
#include <unordered_map>
|
||||
#include <random>
|
||||
|
||||
///@cond PRIVATE
|
||||
|
||||
@ -52,6 +53,11 @@ void QgsKMeansClusteringAlgorithm::initAlgorithm( const QVariantMap & )
|
||||
addParameter( new QgsProcessingParameterFeatureSource( QStringLiteral( "INPUT" ), QObject::tr( "Input layer" ), QList<int>() << static_cast<int>( Qgis::ProcessingSourceType::VectorAnyGeometry ) ) );
|
||||
addParameter( new QgsProcessingParameterNumber( QStringLiteral( "CLUSTERS" ), QObject::tr( "Number of clusters" ), Qgis::ProcessingNumberParameterType::Integer, 5, false, 1 ) );
|
||||
|
||||
QStringList initializationMethods;
|
||||
initializationMethods << QObject::tr( "Farthest points" )
|
||||
<< QObject::tr( "K-means++" );
|
||||
addParameter( new QgsProcessingParameterEnum( QStringLiteral( "METHOD" ), QObject::tr( "Method" ), initializationMethods, false, 0, false ) );
|
||||
|
||||
auto fieldNameParam = std::make_unique<QgsProcessingParameterString>( QStringLiteral( "FIELD_NAME" ), QObject::tr( "Cluster field name" ), QStringLiteral( "CLUSTER_ID" ) );
|
||||
fieldNameParam->setFlags( fieldNameParam->flags() | Qgis::ProcessingParameterFlag::Advanced );
|
||||
addParameter( fieldNameParam.release() );
|
||||
@ -65,7 +71,10 @@ void QgsKMeansClusteringAlgorithm::initAlgorithm( const QVariantMap & )
|
||||
QString QgsKMeansClusteringAlgorithm::shortHelpString() const
|
||||
{
|
||||
return QObject::tr( "This algorithm calculates the 2D distance based k-means cluster number for each input feature.\n\n"
|
||||
"If input geometries are lines or polygons, the clustering is based on the centroid of the feature." );
|
||||
"If input geometries are lines or polygons, the clustering is based on the centroid of the feature.\n\n"
|
||||
"References:\n"
|
||||
"Arthur, David & Vassilvitskii, Sergei. (2007). K-Means++: The Advantages of Careful Seeding. Proc. of the Annu. ACM-SIAM Symp. on Discrete Algorithms. 8.\n\n"
|
||||
"Bhattacharya, Anup & Eube, Jan & Röglin, Heiko & Schmidt, Melanie. (2019). Noisy, Greedy and Not So Greedy k-means++");
|
||||
}
|
||||
|
||||
QString QgsKMeansClusteringAlgorithm::shortDescription() const
|
||||
@ -85,6 +94,7 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p
|
||||
throw QgsProcessingException( invalidSourceError( parameters, QStringLiteral( "INPUT" ) ) );
|
||||
|
||||
int k = parameterAsInt( parameters, QStringLiteral( "CLUSTERS" ), context );
|
||||
int initializationMethod = parameterAsInt( parameters, QStringLiteral( "METHOD" ), context );
|
||||
|
||||
QgsFields outputFields = source->fields();
|
||||
QgsFields newFields;
|
||||
@ -153,8 +163,17 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p
|
||||
|
||||
// cluster centers
|
||||
std::vector<QgsPointXY> centers( k );
|
||||
|
||||
initClusters( clusterFeatures, centers, k, feedback );
|
||||
switch ( initializationMethod )
|
||||
{
|
||||
case 0: // farthest points
|
||||
initClustersFarthestPoints( clusterFeatures, centers, k, feedback );
|
||||
break;
|
||||
case 1: // k-means++
|
||||
initClustersPlusPlus( clusterFeatures, centers, k, feedback );
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
calculateKMeans( clusterFeatures, centers, k, feedback );
|
||||
}
|
||||
|
||||
@ -203,7 +222,7 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p
|
||||
|
||||
// ported from https://github.com/postgis/postgis/blob/svn-trunk/liblwgeom/lwkmeans.c
|
||||
|
||||
void QgsKMeansClusteringAlgorithm::initClusters( std::vector<Feature> &points, std::vector<QgsPointXY> ¢ers, const int k, QgsProcessingFeedback *feedback )
|
||||
void QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( std::vector<Feature> &points, std::vector<QgsPointXY> ¢ers, const int k, QgsProcessingFeedback *feedback )
|
||||
{
|
||||
const std::size_t n = points.size();
|
||||
if ( n == 0 )
|
||||
@ -303,6 +322,140 @@ void QgsKMeansClusteringAlgorithm::initClusters( std::vector<Feature> &points, s
|
||||
}
|
||||
}
|
||||
|
||||
void QgsKMeansClusteringAlgorithm::initClustersPlusPlus( std::vector<Feature> &points, std::vector<QgsPointXY> ¢ers, const int k, QgsProcessingFeedback *feedback )
|
||||
{
|
||||
const std::size_t n = points.size();
|
||||
if ( n == 0 )
|
||||
return;
|
||||
|
||||
if ( n == 1 )
|
||||
{
|
||||
for ( int i = 0; i < k; i++ )
|
||||
centers[i] = points[0].point;
|
||||
return;
|
||||
}
|
||||
|
||||
// randomly select the first point
|
||||
std::random_device rd;
|
||||
std::mt19937 gen( rd() );
|
||||
std::uniform_int_distribution<size_t> distrib( 0, n - 1 );
|
||||
|
||||
int p1 = distrib( gen );
|
||||
centers[0] = points[p1].point;
|
||||
|
||||
// calculate distances and total error (sum of distances of points to center)
|
||||
std::vector<double> distances( n );
|
||||
double totalError = 0;
|
||||
long duplicateCount = 1;
|
||||
for ( size_t i = 0; i < n; i++ )
|
||||
{
|
||||
double distance = points[i].point.sqrDist( centers[0] );
|
||||
distances[i] = distance;
|
||||
totalError += distance;
|
||||
if ( qgsDoubleNear( distance, 0 ) )
|
||||
{
|
||||
duplicateCount++;
|
||||
}
|
||||
}
|
||||
if ( feedback && duplicateCount > 1 )
|
||||
{
|
||||
feedback->pushInfo( QObject::tr( "There are at least %n duplicate input(s), the number of output clusters may be less than was requested", nullptr, duplicateCount ) );
|
||||
}
|
||||
|
||||
// greedy kmeans++
|
||||
// test not only one center but L possible centers
|
||||
// chosen independently according to the same probability distribution), and then among these L
|
||||
// centers, the one that decreases the k-means cost the most is chosen
|
||||
// Bhattacharya, Anup & Eube, Jan & Röglin, Heiko & Schmidt, Melanie. (2019). Noisy, greedy and Not So greedy k-means++
|
||||
unsigned int numCandidateCenters = 2 + std::floor( std::log( k ) );
|
||||
std::vector<double> randomNumbers( numCandidateCenters );
|
||||
std::vector<size_t> candidateCenters( numCandidateCenters );
|
||||
|
||||
std::uniform_real_distribution<double> dis( 0.0, 1.0 );
|
||||
for ( int i = 1; i < k; i++ )
|
||||
{
|
||||
// sampling with probability proportional to the squared distance to the closest existing center
|
||||
for ( unsigned int j = 0; j < numCandidateCenters; j++ )
|
||||
{
|
||||
randomNumbers[j] = dis( gen ) * totalError;
|
||||
}
|
||||
|
||||
// cumulative sum, keep distances for later use
|
||||
std::vector<double> cumSum = distances;
|
||||
for ( size_t j = 1; j < n; j++ )
|
||||
{
|
||||
cumSum[j] += cumSum[j - 1];
|
||||
}
|
||||
|
||||
// binary search for the index of the first element greater than or equal to random numbers
|
||||
for ( unsigned int j = 0; j < numCandidateCenters; j++ )
|
||||
{
|
||||
size_t low = 0;
|
||||
size_t high = n - 1;
|
||||
|
||||
while ( low <= high )
|
||||
{
|
||||
size_t mid = low + ( high - low ) / 2;
|
||||
if ( cumSum[mid] < randomNumbers[j] )
|
||||
{
|
||||
low = mid + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// size_t cannot be negative
|
||||
if ( mid == 0 )
|
||||
break;
|
||||
|
||||
high = mid - 1;
|
||||
}
|
||||
}
|
||||
// clip candidate center to the number of points
|
||||
if ( low >= n )
|
||||
{
|
||||
low = n - 1;
|
||||
}
|
||||
candidateCenters[j] = low;
|
||||
}
|
||||
|
||||
std::vector<std::vector<double>> distancesCandidateCenters( numCandidateCenters, std::vector<double>( n ) );;
|
||||
|
||||
// store distances between previous and new candidate center, error and best candidate index
|
||||
double currentError = 0;
|
||||
double lowestError = std::numeric_limits<double>::max();
|
||||
size_t bestCandidateIndex;
|
||||
for ( unsigned int j = 0; j < numCandidateCenters; j++ )
|
||||
{
|
||||
for ( size_t z = 0; z < n; z++ )
|
||||
{
|
||||
// distance to candidate center
|
||||
double distance = points[candidateCenters[j]].point.sqrDist( points[z].point );
|
||||
// if distance to previous center is less than the current distance, use that
|
||||
if ( distance > distances[z] )
|
||||
{
|
||||
distance = distances[z];
|
||||
}
|
||||
distancesCandidateCenters[j][z] = distance;
|
||||
currentError += distance;
|
||||
}
|
||||
if ( lowestError > currentError )
|
||||
{
|
||||
lowestError = currentError;
|
||||
bestCandidateIndex = j;
|
||||
}
|
||||
}
|
||||
|
||||
// update distances with the best candidate center values
|
||||
for ( size_t j = 0; j < n; j++ )
|
||||
{
|
||||
distances[j] = distancesCandidateCenters[bestCandidateIndex][j];
|
||||
}
|
||||
// store the best candidate center
|
||||
centers[i] = points[candidateCenters[bestCandidateIndex]].point;
|
||||
// update error
|
||||
totalError = lowestError;
|
||||
}
|
||||
}
|
||||
|
||||
// ported from https://github.com/postgis/postgis/blob/svn-trunk/liblwgeom/lwkmeans.c
|
||||
|
||||
void QgsKMeansClusteringAlgorithm::calculateKMeans( std::vector<QgsKMeansClusteringAlgorithm::Feature> &objs, std::vector<QgsPointXY> ¢ers, int k, QgsProcessingFeedback *feedback )
|
||||
|
@ -58,7 +58,8 @@ class ANALYSIS_EXPORT QgsKMeansClusteringAlgorithm : public QgsProcessingAlgorit
|
||||
int cluster = -1;
|
||||
};
|
||||
|
||||
static void initClusters( std::vector<Feature> &points, std::vector<QgsPointXY> ¢ers, int k, QgsProcessingFeedback *feedback );
|
||||
static void initClustersFarthestPoints( std::vector<Feature> &points, std::vector<QgsPointXY> ¢ers, int k, QgsProcessingFeedback *feedback );
|
||||
static void initClustersPlusPlus( std::vector<Feature> &points, std::vector<QgsPointXY> ¢ers, int k, QgsProcessingFeedback *feedback );
|
||||
static void calculateKMeans( std::vector<Feature> &points, std::vector<QgsPointXY> ¢ers, int k, QgsProcessingFeedback *feedback );
|
||||
static void findNearest( std::vector<Feature> &points, const std::vector<QgsPointXY> ¢ers, int k, bool &changed );
|
||||
static void updateMeans( const std::vector<Feature> &points, std::vector<QgsPointXY> ¢ers, std::vector<uint> &weights, int k );
|
||||
|
@ -1016,44 +1016,65 @@ void TestQgsProcessingAlgsPt1::kmeansCluster()
|
||||
|
||||
// no features, no crash
|
||||
int k = 2;
|
||||
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
|
||||
// farthest points
|
||||
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
// kmeans++
|
||||
QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
|
||||
// features < clusters
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 5 ) ) );
|
||||
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
|
||||
// farthest points
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 1 ) ) );
|
||||
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
QCOMPARE( features[0].cluster, 0 );
|
||||
// kmeans++
|
||||
QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
QCOMPARE( features[0].cluster, 0 );
|
||||
|
||||
// features == clusters
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 11, 5 ) ) );
|
||||
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 3, 1 ) ) );
|
||||
// farthest points
|
||||
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
QCOMPARE( features[0].cluster, 1 );
|
||||
QCOMPARE( features[1].cluster, 0 );
|
||||
// kmeans++
|
||||
QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
QVERIFY( features[0].cluster != features[1].cluster );
|
||||
|
||||
// features > clusters
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 13, 3 ) ) );
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 13, 13 ) ) );
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 23, 6 ) ) );
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 2, 8 ) ) );
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 10 ) ) );
|
||||
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 3, 10 ) ) );
|
||||
k = 2;
|
||||
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
|
||||
// farthest points
|
||||
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
QCOMPARE( features[0].cluster, 1 );
|
||||
QCOMPARE( features[1].cluster, 1 );
|
||||
QCOMPARE( features[2].cluster, 0 );
|
||||
QCOMPARE( features[3].cluster, 0 );
|
||||
QCOMPARE( features[4].cluster, 0 );
|
||||
// kmeans++
|
||||
QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
QCOMPARE( features[0].cluster, features[1].cluster );
|
||||
QCOMPARE( features[2].cluster, features[3].cluster );
|
||||
QCOMPARE( features[4].cluster, features[3].cluster );
|
||||
|
||||
// repeat above, with 3 clusters
|
||||
k = 3;
|
||||
centers.resize( 3 );
|
||||
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
|
||||
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
|
||||
QCOMPARE( features[0].cluster, 1 );
|
||||
QCOMPARE( features[1].cluster, 2 );
|
||||
QCOMPARE( features[1].cluster, 1 );
|
||||
QCOMPARE( features[2].cluster, 2 );
|
||||
QCOMPARE( features[3].cluster, 2 );
|
||||
QCOMPARE( features[3].cluster, 0 );
|
||||
QCOMPARE( features[4].cluster, 0 );
|
||||
|
||||
// with identical points
|
||||
|
Loading…
x
Reference in New Issue
Block a user