aboutsummaryrefslogtreecommitdiff
path: root/src/wrapper.pyx
blob: 1d84b59f8b761019c6438b9f1ef69115f203e3d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Wrapper for GenSVM

Not implemented yet:
    - vector of instance weights
    - class weights
    - seed model
    - max_iter = -1 for unlimited

"""

from __future__ import print_function

import numpy as np
cimport numpy as np

cimport wrapper

# Initialise the NumPy C API at import time. Cython requires this before any
# NumPy C-level functionality is used by the wrappers in this module.
np.import_array()

# Names of the supported kernels. train_wrap passes the *position* of the
# requested kernel in this list to set_model, so the order is significant.
# NOTE(review): presumably mirrors the kernel enum of the C library -- verify
# against the library before reordering or extending this list.
GENSVM_KERNEL_TYPES = ["linear", "poly", "rbf", "sigmoid"]

def train_wrap(
        np.ndarray[np.float64_t, ndim=2, mode='c'] X,
        np.ndarray[np.int_t, ndim=1, mode='c'] y,
        double p=1.0,
        double lmd=pow(2, -8),
        double kappa=0.0,
        double epsilon=1e-6,
        int weight_idx=1,
        str kernel='linear',
        double gamma=1.0,
        double coef=0.0,
        double degree=2.0,
        double kernel_eigen_cutoff=1e-8,
        int max_iter=100000000,
        int random_seed=-1):
    """
    Train a GenSVM model with the C library and return the fitted results.

    Parameters
    ----------
    X : C-contiguous float64 array of shape (n_obs, n_var)
        Training data.
    y : C-contiguous integer array of shape (n_obs, )
        Class labels. The number of classes is the number of unique values
        in ``y``. NOTE(review): the C library presumably expects labels in
        1..n_class -- confirm against the caller.
    p, lmd, kappa, epsilon, weight_idx :
        Model hyperparameters, passed unchanged to ``set_model``.
    kernel : str
        Kernel name; must be one of ``GENSVM_KERNEL_TYPES``.
    gamma, coef, degree, kernel_eigen_cutoff :
        Kernel parameters, passed unchanged to ``set_model``.
    max_iter, random_seed :
        Passed unchanged to ``set_model``.

    Returns
    -------
    tuple
        ``(V, n_SV, iter_count, training_error, fit_status)``. ``V`` is a
        float64 array of shape (n_var + 1, n_class - 1) copied out of the
        trained model. The remaining values are read from the model after
        training.

    Raises
    ------
    ValueError
        If ``kernel`` is unknown, if ``X`` and ``y`` have a different number
        of rows, or if ``check_model`` rejects the parameters.
    """
    cdef GenModel *model = NULL
    cdef GenData *data = NULL
    cdef long n_obs
    cdef long n_var
    cdef long n_class
    # Declared at function scope because Cython does not allow cdef
    # statements inside the try block below.
    cdef np.ndarray[np.float64_t, ndim=2, mode='c'] V

    # Validate all Python-level inputs *before* allocating any C memory, so
    # that a bad argument can never leak the model or data structs.
    if kernel not in GENSVM_KERNEL_TYPES:
        raise ValueError("Unknown kernel '{}'; expected one of: {}".format(
            kernel, ", ".join(GENSVM_KERNEL_TYPES)))
    kernel_index = GENSVM_KERNEL_TYPES.index(kernel)

    n_obs = X.shape[0]
    n_var = X.shape[1]
    # The C code reads y.data for every row of X; a shorter y would be read
    # out of bounds.
    if y.shape[0] != n_obs:
        raise ValueError("X has {} rows but y has {} elements".format(
            n_obs, y.shape[0]))

    # get the number of classes
    classes = np.unique(y)
    n_class = classes.shape[0]

    # Initialize model and data; they are released in the finally clause on
    # every exit path, including exceptions.
    model = gensvm_init_model()
    data = gensvm_init_data()
    try:
        # Set the data
        set_data(data, X.data, y.data, X.shape, n_class)

        # Set the model
        set_model(model, p, lmd, kappa, epsilon, weight_idx, kernel_index,
                degree, gamma, coef, kernel_eigen_cutoff, max_iter,
                random_seed)

        # Check the parameters
        error_msg = check_model(model)
        if error_msg:
            raise ValueError(error_msg.decode('utf-8'))

        # Do the actual training without holding the GIL
        with nogil:
            gensvm_train(model, data, NULL)

        # copy the results
        V = np.empty((n_var+1, n_class-1))
        copy_V(V.data, model)

        # get other results from model
        iter_count = get_iter_count(model)
        training_error = get_training_error(model)
        fit_status = get_status(model)
        n_SV = gensvm_num_sv(model)
    finally:
        gensvm_free_model(model)
        free_data(data)

    return (V, n_SV, iter_count, training_error, fit_status)

def predict_wrap(
        np.ndarray[np.float64_t, ndim=2, mode='c'] X,
        np.ndarray[np.float64_t, ndim=2, mode='c'] V
        ):
    """
    Predict class labels for new data from a trained coefficient matrix.

    Parameters
    ----------
    X : C-contiguous float64 array of shape (n_test_obs, n_var)
        Data to predict labels for.
    V : C-contiguous float64 array of shape (n_var + 1, n_class - 1)
        Coefficient matrix, laid out as returned by ``train_wrap``.

    Returns
    -------
    predictions : integer array of shape (n_test_obs, )
        Labels written by ``gensvm_predict``.

    Raises
    ------
    ValueError
        If ``V.shape[0]`` is not ``X.shape[1] + 1``.
    """
    cdef long n_test_obs
    cdef long n_var
    cdef long n_class
    # output vector
    cdef np.ndarray[np.int_t, ndim=1, mode='c'] predictions

    n_test_obs = X.shape[0]
    n_var = X.shape[1]

    # V must have n_var + 1 rows (the shape train_wrap allocates); otherwise
    # the C code would index V out of bounds.
    if V.shape[0] != n_var + 1:
        raise ValueError(
            "V has {} rows but X has {} columns; expected {} rows".format(
                V.shape[0], n_var, n_var + 1))
    n_class = V.shape[1] + 1

    # np.int (an alias of builtin int) was removed in NumPy 1.24; np.int_ is
    # the dtype that matches the np.int_t buffer declared above.
    predictions = np.empty((n_test_obs, ), dtype=np.int_)

    # do the prediction
    with nogil:
        gensvm_predict(X.data, V.data, n_test_obs, n_var, n_class,
                predictions.data)

    return predictions

def set_verbosity_wrap(int verbosity):
    """
    Set how verbose the GenSVM library is.

    The given integer level is forwarded unchanged to the library's
    ``set_verbosity`` function.
    """
    cdef int level = verbosity
    set_verbosity(level)