add kmeans++

This commit is contained in:
viperminiq 2025-05-08 10:50:29 +02:00 committed by Nyall Dawson
parent 5500f43ebc
commit 9a4b4d388a
3 changed files with 192 additions and 17 deletions

View File

@ -17,6 +17,7 @@
#include "qgsalgorithmkmeansclustering.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
#include <random>
#include <unordered_map>
///@cond PRIVATE
@ -52,6 +53,11 @@ void QgsKMeansClusteringAlgorithm::initAlgorithm( const QVariantMap & )
addParameter( new QgsProcessingParameterFeatureSource( QStringLiteral( "INPUT" ), QObject::tr( "Input layer" ), QList<int>() << static_cast<int>( Qgis::ProcessingSourceType::VectorAnyGeometry ) ) );
addParameter( new QgsProcessingParameterNumber( QStringLiteral( "CLUSTERS" ), QObject::tr( "Number of clusters" ), Qgis::ProcessingNumberParameterType::Integer, 5, false, 1 ) );
QStringList initializationMethods;
initializationMethods << QObject::tr( "Farthest points" )
<< QObject::tr( "K-means++" );
addParameter( new QgsProcessingParameterEnum( QStringLiteral( "METHOD" ), QObject::tr( "Method" ), initializationMethods, false, 0, false ) );
auto fieldNameParam = std::make_unique<QgsProcessingParameterString>( QStringLiteral( "FIELD_NAME" ), QObject::tr( "Cluster field name" ), QStringLiteral( "CLUSTER_ID" ) );
fieldNameParam->setFlags( fieldNameParam->flags() | Qgis::ProcessingParameterFlag::Advanced );
addParameter( fieldNameParam.release() );
@ -65,7 +71,10 @@ void QgsKMeansClusteringAlgorithm::initAlgorithm( const QVariantMap & )
QString QgsKMeansClusteringAlgorithm::shortHelpString() const
{
return QObject::tr( "This algorithm calculates the 2D distance based k-means cluster number for each input feature.\n\n"
"If input geometries are lines or polygons, the clustering is based on the centroid of the feature." );
"If input geometries are lines or polygons, the clustering is based on the centroid of the feature.\n\n"
"References:\n"
"Arthur, David & Vassilvitskii, Sergei. (2007). K-Means++: The Advantages of Careful Seeding. Proc. of the Annu. ACM-SIAM Symp. on Discrete Algorithms. 8.\n\n"
"Bhattacharya, Anup & Eube, Jan & Röglin, Heiko & Schmidt, Melanie. (2019). Noisy, Greedy and Not So Greedy k-means++");
}
QString QgsKMeansClusteringAlgorithm::shortDescription() const
@ -85,6 +94,7 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p
throw QgsProcessingException( invalidSourceError( parameters, QStringLiteral( "INPUT" ) ) );
int k = parameterAsInt( parameters, QStringLiteral( "CLUSTERS" ), context );
int initializationMethod = parameterAsInt( parameters, QStringLiteral( "METHOD" ), context );
QgsFields outputFields = source->fields();
QgsFields newFields;
@ -153,8 +163,17 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p
// cluster centers
std::vector<QgsPointXY> centers( k );
initClusters( clusterFeatures, centers, k, feedback );
switch ( initializationMethod )
{
case 0: // farthest points
initClustersFarthestPoints( clusterFeatures, centers, k, feedback );
break;
case 1: // k-means++
initClustersPlusPlus( clusterFeatures, centers, k, feedback );
break;
default:
break;
}
calculateKMeans( clusterFeatures, centers, k, feedback );
}
@ -203,7 +222,7 @@ QVariantMap QgsKMeansClusteringAlgorithm::processAlgorithm( const QVariantMap &p
// ported from https://github.com/postgis/postgis/blob/svn-trunk/liblwgeom/lwkmeans.c
void QgsKMeansClusteringAlgorithm::initClusters( std::vector<Feature> &points, std::vector<QgsPointXY> &centers, const int k, QgsProcessingFeedback *feedback )
void QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( std::vector<Feature> &points, std::vector<QgsPointXY> &centers, const int k, QgsProcessingFeedback *feedback )
{
const std::size_t n = points.size();
if ( n == 0 )
@ -303,6 +322,140 @@ void QgsKMeansClusteringAlgorithm::initClusters( std::vector<Feature> &points, s
}
}
void QgsKMeansClusteringAlgorithm::initClustersPlusPlus( std::vector<Feature> &points, std::vector<QgsPointXY> &centers, const int k, QgsProcessingFeedback *feedback )
{
const std::size_t n = points.size();
if ( n == 0 )
return;
if ( n == 1 )
{
for ( int i = 0; i < k; i++ )
centers[i] = points[0].point;
return;
}
// randomly select the first point
std::random_device rd;
std::mt19937 gen( rd() );
std::uniform_int_distribution<size_t> distrib( 0, n - 1 );
int p1 = distrib( gen );
centers[0] = points[p1].point;
// calculate distances and total error (sum of distances of points to center)
std::vector<double> distances( n );
double totalError = 0;
long duplicateCount = 1;
for ( size_t i = 0; i < n; i++ )
{
double distance = points[i].point.sqrDist( centers[0] );
distances[i] = distance;
totalError += distance;
if ( qgsDoubleNear( distance, 0 ) )
{
duplicateCount++;
}
}
if ( feedback && duplicateCount > 1 )
{
feedback->pushInfo( QObject::tr( "There are at least %n duplicate input(s), the number of output clusters may be less than was requested", nullptr, duplicateCount ) );
}
// greedy kmeans++
// test not only one center but L possible centers
// chosen independently according to the same probability distribution), and then among these L
// centers, the one that decreases the k-means cost the most is chosen
// Bhattacharya, Anup & Eube, Jan & Röglin, Heiko & Schmidt, Melanie. (2019). Noisy, greedy and Not So greedy k-means++
unsigned int numCandidateCenters = 2 + std::floor( std::log( k ) );
std::vector<double> randomNumbers( numCandidateCenters );
std::vector<size_t> candidateCenters( numCandidateCenters );
std::uniform_real_distribution<double> dis( 0.0, 1.0 );
for ( int i = 1; i < k; i++ )
{
// sampling with probability proportional to the squared distance to the closest existing center
for ( unsigned int j = 0; j < numCandidateCenters; j++ )
{
randomNumbers[j] = dis( gen ) * totalError;
}
// cumulative sum, keep distances for later use
std::vector<double> cumSum = distances;
for ( size_t j = 1; j < n; j++ )
{
cumSum[j] += cumSum[j - 1];
}
// binary search for the index of the first element greater than or equal to random numbers
for ( unsigned int j = 0; j < numCandidateCenters; j++ )
{
size_t low = 0;
size_t high = n - 1;
while ( low <= high )
{
size_t mid = low + ( high - low ) / 2;
if ( cumSum[mid] < randomNumbers[j] )
{
low = mid + 1;
}
else
{
// size_t cannot be negative
if ( mid == 0 )
break;
high = mid - 1;
}
}
// clip candidate center to the number of points
if ( low >= n )
{
low = n - 1;
}
candidateCenters[j] = low;
}
std::vector<std::vector<double>> distancesCandidateCenters( numCandidateCenters, std::vector<double>( n ) );;
// store distances between previous and new candidate center, error and best candidate index
double currentError = 0;
double lowestError = std::numeric_limits<double>::max();
size_t bestCandidateIndex;
for ( unsigned int j = 0; j < numCandidateCenters; j++ )
{
for ( size_t z = 0; z < n; z++ )
{
// distance to candidate center
double distance = points[candidateCenters[j]].point.sqrDist( points[z].point );
// if distance to previous center is less than the current distance, use that
if ( distance > distances[z] )
{
distance = distances[z];
}
distancesCandidateCenters[j][z] = distance;
currentError += distance;
}
if ( lowestError > currentError )
{
lowestError = currentError;
bestCandidateIndex = j;
}
}
// update distances with the best candidate center values
for ( size_t j = 0; j < n; j++ )
{
distances[j] = distancesCandidateCenters[bestCandidateIndex][j];
}
// store the best candidate center
centers[i] = points[candidateCenters[bestCandidateIndex]].point;
// update error
totalError = lowestError;
}
}
// ported from https://github.com/postgis/postgis/blob/svn-trunk/liblwgeom/lwkmeans.c
void QgsKMeansClusteringAlgorithm::calculateKMeans( std::vector<QgsKMeansClusteringAlgorithm::Feature> &objs, std::vector<QgsPointXY> &centers, int k, QgsProcessingFeedback *feedback )

View File

@ -58,7 +58,8 @@ class ANALYSIS_EXPORT QgsKMeansClusteringAlgorithm : public QgsProcessingAlgorit
int cluster = -1;
};
static void initClusters( std::vector<Feature> &points, std::vector<QgsPointXY> &centers, int k, QgsProcessingFeedback *feedback );
static void initClustersFarthestPoints( std::vector<Feature> &points, std::vector<QgsPointXY> &centers, int k, QgsProcessingFeedback *feedback );
static void initClustersPlusPlus( std::vector<Feature> &points, std::vector<QgsPointXY> &centers, int k, QgsProcessingFeedback *feedback );
static void calculateKMeans( std::vector<Feature> &points, std::vector<QgsPointXY> &centers, int k, QgsProcessingFeedback *feedback );
static void findNearest( std::vector<Feature> &points, const std::vector<QgsPointXY> &centers, int k, bool &changed );
static void updateMeans( const std::vector<Feature> &points, std::vector<QgsPointXY> &centers, std::vector<uint> &weights, int k );

View File

@ -1016,44 +1016,65 @@ void TestQgsProcessingAlgsPt1::kmeansCluster()
// no features, no crash
int k = 2;
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
// farthest points
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
// kmeans++
QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
// features < clusters
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 5 ) ) );
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
// farthest points
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 1 ) ) );
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
QCOMPARE( features[0].cluster, 0 );
// kmeans++
QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
QCOMPARE( features[0].cluster, 0 );
// features == clusters
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 11, 5 ) ) );
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 3, 1 ) ) );
// farthest points
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
QCOMPARE( features[0].cluster, 1 );
QCOMPARE( features[1].cluster, 0 );
// kmeans++
QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
QVERIFY( features[0].cluster != features[1].cluster );
// features > clusters
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 13, 3 ) ) );
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 13, 13 ) ) );
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 23, 6 ) ) );
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 2, 8 ) ) );
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 1, 10 ) ) );
features.emplace_back( QgsKMeansClusteringAlgorithm::Feature( QgsPointXY( 3, 10 ) ) );
k = 2;
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
// farthest points
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
QCOMPARE( features[0].cluster, 1 );
QCOMPARE( features[1].cluster, 1 );
QCOMPARE( features[2].cluster, 0 );
QCOMPARE( features[3].cluster, 0 );
QCOMPARE( features[4].cluster, 0 );
// kmeans++
QgsKMeansClusteringAlgorithm::initClustersPlusPlus( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
QCOMPARE( features[0].cluster, features[1].cluster );
QCOMPARE( features[2].cluster, features[3].cluster );
QCOMPARE( features[4].cluster, features[3].cluster );
// repeat above, with 3 clusters
k = 3;
centers.resize( 3 );
QgsKMeansClusteringAlgorithm::initClusters( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::initClustersFarthestPoints( features, centers, k, nullptr );
QgsKMeansClusteringAlgorithm::calculateKMeans( features, centers, k, nullptr );
QCOMPARE( features[0].cluster, 1 );
QCOMPARE( features[1].cluster, 2 );
QCOMPARE( features[1].cluster, 1 );
QCOMPARE( features[2].cluster, 2 );
QCOMPARE( features[3].cluster, 2 );
QCOMPARE( features[3].cluster, 0 );
QCOMPARE( features[4].cluster, 0 );
// with identical points