aboutsummaryrefslogtreecommitdiff
path: root/gensvm/util.py
blob: 40d0eb118097883adc100ef99dcb27734b6680ae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -*- coding: utf-8 -*-

"""
Utility functions for GenSVM

"""


import numpy as np

from sklearn.exceptions import NotFittedError

def get_ranks(a):
    """
    Rank data in an array. Low values get a small rank number. Tied values
    all receive the smallest rank of their group (this corresponds to
    ``rankdata(a, method='min')`` in SciPy). NaN values are ranked last: they
    all receive one more than the largest rank among the non-NaN values.

    Parameters
    ----------
    a : array_like
        Input data. Multi-dimensional input is flattened before ranking.

    Returns
    -------
    ranks : list of int
        The 1-based rank of each element of the flattened input. An empty
        input gives an empty list.

    Examples
    --------
    >>> x = [7, 0.1, 0.5, 0.1, 10, 100, 200]
    >>> get_ranks(x)
    [4, 1, 3, 1, 5, 6, 7]

    """
    orig = np.ravel(np.asarray(a))
    nan_mask = np.isnan(orig)
    arr = orig[~nan_mask]

    # inv[i] is the position of arr[i] in sorted order (inverse permutation).
    sorter = np.argsort(arr, kind="quicksort")
    inv = np.empty(sorter.size, dtype=np.intp)
    inv[sorter] = np.arange(sorter.size, dtype=np.intp)

    # obs flags the first element of each run of equal values in the sorted
    # array; its cumulative sum is therefore the 1-based "dense" rank.
    arr = arr[sorter]
    obs = np.r_[True, arr[1:] != arr[:-1]]
    dense = obs.cumsum()[inv]

    # first[k] is the sorted index where the (k+1)-th distinct value first
    # appears; adding 1 turns it into the 1-based "min" rank.
    first = np.nonzero(obs)[0]
    ranks = np.zeros_like(orig, dtype=np.intp)
    ranks[~nan_mask] = first[dense - 1] + 1
    # Only touch NaN positions when there are any: this also keeps empty
    # input away from np.max, which fails on a zero-size array.
    if np.any(nan_mask):
        ranks[nan_mask] = np.max(ranks) + 1
    # tolist() yields plain Python ints rather than NumPy scalars.
    return ranks.tolist()


def check_is_fitted(estimator, attribute):
    """
    Check that an estimator has been fitted.

    Parameters
    ----------
    estimator : object
        The object to check. It must have a ``fit`` attribute to be accepted
        as an estimator.
    attribute : str
        Name of the attribute whose presence indicates that the estimator has
        been fitted.

    Raises
    ------
    TypeError
        If ``estimator`` has no ``fit`` attribute.
    NotFittedError
        If ``estimator`` does not have the attribute named ``attribute``.

    """
    msg = (
        "This %(name)s instance is not fitted yet. Call 'fit' "
        "with appropriate arguments before using this estimator."
    )
    if not hasattr(estimator, "fit"):
        # Use a 1-tuple so an object that is itself a tuple is formatted as a
        # single value instead of being unpacked by the ``%`` operator.
        raise TypeError("%s is not an estimator instance" % (estimator,))

    if not hasattr(estimator, attribute):
        raise NotFittedError(msg % {"name": type(estimator).__name__})