Skip to content

Commit

Permalink
add docstring to PointSet.find_clusters()
Browse files Browse the repository at this point in the history
  • Loading branch information
AhmetNSimsek committed Oct 23, 2024
1 parent 599ee62 commit 73edcc6
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions siibra/locations/pointset.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,34 @@ def homogeneous(self):
"""Access the list of 3D point as an Nx4 array of homogeneous coordinates."""
return np.c_[self.coordinates, np.ones(len(self))]

def find_clusters(self, min_fraction=1 / 200, max_fraction=1 / 8):
def find_clusters(
self,
min_fraction: float = 1 / 200,
max_fraction: float = 1 / 8
) -> List[int]:
"""
Find clusters using the HDBSCAN (https://dl.acm.org/doi/10.1145/2733381)
implementation in scikit-learn (https://dl.acm.org/doi/10.5555/1953048.2078195).
Parameters
----------
min_fraction: min cluster size as a fraction of total points in the PointSet
max_fraction: max cluster size as a fraction of total points in the PointSet
Returns
-------
List[int]
Returns the cluster labels found by sklearn.cluster.HDBSCAN.
Note
----
Replaces the labels of the PointSet instance with these labels.
Raises
------
RuntimeError
If a sklearn version without HDBSCAN is installed.
"""
if not _HAS_HDBSCAN:
raise RuntimeError(
f"HDBSCAN is not available with your version {sklearn.__version__} "
Expand All @@ -289,7 +316,9 @@ def find_clusters(self, min_fraction=1 / 200, max_fraction=1 / 8):
max_cluster_size=int(N * max_fraction),
)
if self.labels is not None:
logger.warn("Existing labels of PointSet will be overwritten with cluster labels.")
logger.warning(
"Existing labels of PointSet will be overwritten with cluster labels."
)
self.labels = clustering.fit_predict(points)
return self.labels

Expand Down

0 comments on commit 73edcc6

Please sign in to comment.