diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 12:24:33 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 12:24:33 -0500 |
| commit | 61423e7d4b98eeb8bb73ebb5786bd1477d99ee23 (patch) | |
| tree | 0ab7f1fa1c05b7451d7b7cd3cd867b93f0b988f2 /gensvm/util.py | |
| parent | Extract durations array (diff) | |
| download | pygensvm-61423e7d4b98eeb8bb73ebb5786bd1477d99ee23.tar.gz pygensvm-61423e7d4b98eeb8bb73ebb5786bd1477d99ee23.zip | |
Add support for interrupted grid search
Diffstat (limited to 'gensvm/util.py')
| -rw-r--r-- | gensvm/util.py | 31 |
1 files changed, 18 insertions, 13 deletions
diff --git a/gensvm/util.py b/gensvm/util.py index 0b7cd1d..046f3be 100644 --- a/gensvm/util.py +++ b/gensvm/util.py @@ -9,10 +9,11 @@ Utility functions for GenSVM import numpy as np -def get_ranks(x): +def get_ranks(a): """ Rank data in an array. Low values get a small rank number. Ties are broken - by assigning the lowest value. + by assigning the lowest value (this corresponds to ``rankdata(a, + method='min')`` in SciPy. Examples -------- @@ -21,14 +22,18 @@ def get_ranks(x): [4, 1, 3, 1, 5, 6, 7] """ - x = np.ravel(np.asarray(x)) - l = len(x) - r = 1 - ranks = np.zeros((l,)) - while not all([k is None for k in x]): - m = min([k for k in x if not k is None]) - idx = [1 if k == m else 0 for k in x] - ranks = [r if idx[k] else ranks[k] for k in range(l)] - r += sum(idx) - x = [None if idx[k] else x[k] for k in range(l)] - return ranks + orig = np.ravel(np.asarray(a)) + arr = orig[~np.isnan(orig)] + sorter = np.argsort(arr, kind="quicksort") + inv = np.empty(sorter.size, dtype=np.intp) + inv[sorter] = np.arange(sorter.size, dtype=np.intp) + + arr = arr[sorter] + obs = np.r_[True, arr[1:] != arr[:-1]] + dense = obs.cumsum()[inv] + + count = np.r_[np.nonzero(obs)[0], len(obs)] + ranks = np.zeros_like(orig) + ranks[~np.isnan(orig)] = count[dense - 1] + 1 + ranks[np.isnan(orig)] = np.max(ranks) + 1 + return list(ranks) |
