diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 22:28:49 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 22:28:49 -0500 |
| commit | 3ed51bb4ac5db6dfec92d10d6e302381c27849c8 (patch) | |
| tree | 24d1df8d8ef4bdf003ca35ae310a5804c28b455e /gensvm/cython_wrapper | |
| parent | Bugfix and test for predict method of GridSearch (diff) | |
| download | pygensvm-3ed51bb4ac5db6dfec92d10d6e302381c27849c8.tar.gz pygensvm-3ed51bb4ac5db6dfec92d10d6e302381c27849c8.zip | |
Add support for specifying sample weights (fixes #2)
Diffstat (limited to 'gensvm/cython_wrapper')
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pxd | 1 | ||||
| -rw-r--r-- | gensvm/cython_wrapper/wrapper.pyx | 4 |
2 files changed, 5 insertions, 0 deletions
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd index 5b097a6..d19d442 100644 --- a/gensvm/cython_wrapper/wrapper.pxd +++ b/gensvm/cython_wrapper/wrapper.pxd @@ -115,6 +115,7 @@ cdef extern from "gensvm_helper.c": double, double, double, double, long, long) void set_seed_model(GenModel *, double, double, double, double, int, int, double, double, double, double, long, long, char *, long, long) + void set_raw_weights(GenModel *, char *, int) void set_data(GenData *, char *, char *, np.npy_intp *, long) void set_task(GenTask *, int, GenData *, int, double, double, double, double, double, int, double, double, double, long) diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index 5c58fee..f98341b 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -33,6 +33,7 @@ def train_wrap( double kappa=0.0, double epsilon=1e-6, int weight_idx=1, + np.ndarray[np.float64_t, ndim=1, mode='c'] raw_weights=None, str kernel='linear', double gamma=1.0, double coef=0.0, @@ -74,6 +75,9 @@ def train_wrap( gensvm_free_model(seed_model) seed_model = NULL + if not raw_weights is None: + set_raw_weights(model, raw_weights.data, n_obs) + # Check the parameters error_msg = check_model(model) if error_msg: |
