aboutsummaryrefslogtreecommitdiff
path: root/gensvm/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'gensvm/util.py')
-rw-r--r--gensvm/util.py31
1 files changed, 18 insertions, 13 deletions
diff --git a/gensvm/util.py b/gensvm/util.py
index 0b7cd1d..046f3be 100644
--- a/gensvm/util.py
+++ b/gensvm/util.py
@@ -9,10 +9,11 @@ Utility functions for GenSVM
import numpy as np
-def get_ranks(x):
+def get_ranks(a):
"""
Rank data in an array. Low values get a small rank number. Ties are broken
- by assigning the lowest value.
+ by assigning the lowest value (this corresponds to ``rankdata(a,
+ method='min')`` in SciPy.
Examples
--------
@@ -21,14 +22,18 @@ def get_ranks(x):
[4, 1, 3, 1, 5, 6, 7]
"""
- x = np.ravel(np.asarray(x))
- l = len(x)
- r = 1
- ranks = np.zeros((l,))
- while not all([k is None for k in x]):
- m = min([k for k in x if not k is None])
- idx = [1 if k == m else 0 for k in x]
- ranks = [r if idx[k] else ranks[k] for k in range(l)]
- r += sum(idx)
- x = [None if idx[k] else x[k] for k in range(l)]
- return ranks
+ orig = np.ravel(np.asarray(a))
+ arr = orig[~np.isnan(orig)]
+ sorter = np.argsort(arr, kind="quicksort")
+ inv = np.empty(sorter.size, dtype=np.intp)
+ inv[sorter] = np.arange(sorter.size, dtype=np.intp)
+
+ arr = arr[sorter]
+ obs = np.r_[True, arr[1:] != arr[:-1]]
+ dense = obs.cumsum()[inv]
+
+ count = np.r_[np.nonzero(obs)[0], len(obs)]
+ ranks = np.zeros_like(orig)
+ ranks[~np.isnan(orig)] = count[dense - 1] + 1
+ ranks[np.isnan(orig)] = np.max(ranks) + 1
+ return list(ranks)