author     Gertjan van den Burg <gertjanvandenburg@gmail.com>  2019-03-06 22:28:49 -0500
committer  Gertjan van den Burg <gertjanvandenburg@gmail.com>  2019-03-06 22:28:49 -0500
commit     3ed51bb4ac5db6dfec92d10d6e302381c27849c8 (patch)
tree       24d1df8d8ef4bdf003ca35ae310a5804c28b455e /gensvm
parent     Bugfix and test for predict method of GridSearch (diff)
download   pygensvm-3ed51bb4ac5db6dfec92d10d6e302381c27849c8.tar.gz
           pygensvm-3ed51bb4ac5db6dfec92d10d6e302381c27849c8.zip
Add support for specifying sample weights (fixes #2)
Diffstat (limited to 'gensvm')
-rw-r--r--  gensvm/core.py                     34
-rw-r--r--  gensvm/cython_wrapper/wrapper.pxd   1
-rw-r--r--  gensvm/cython_wrapper/wrapper.pyx   4
3 files changed, 36 insertions, 3 deletions
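
With this change an explicit weight vector can be passed directly to fit(), overriding the 'unit'/'group' setting from the constructor. A minimal usage sketch (not part of the commit; the dataset and the weight values are illustrative only):

    import numpy as np
    from sklearn.datasets import load_iris
    from gensvm import GenSVM

    X, y = load_iris(return_X_y=True)

    # An explicit weight vector overrides the 'unit'/'group' setting
    # chosen in the constructor.
    sample_weight = np.ones(X.shape[0])
    sample_weight[y == 2] = 2.0  # e.g. emphasize one class

    clf = GenSVM()
    clf.fit(X, y, sample_weight=sample_weight)
    y_pred = clf.predict(X)
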
diff --git a/gensvm/core.py b/gensvm/core.py
index c35c18e..dab2368 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -30,6 +30,7 @@ def _fit_gensvm(
kappa,
epsilon,
weights,
+ sample_weight,
kernel,
gamma,
coef,
@@ -48,7 +49,7 @@ def _fit_gensvm(
wrapper.set_verbosity_wrap(verbose)
# convert the weight index
- weight_idx = 1 if weights == "unit" else 2
+ weight_idx = {"raw": 0, "unit": 1, "group": 2}[weights]
# run the actual training
raw_coef_, n_SV_, n_iter_, training_error_, status_ = wrapper.train_wrap(
@@ -60,6 +61,7 @@ def _fit_gensvm(
kappa,
epsilon,
weight_idx,
+ sample_weight,
kernel,
gamma,
coef,
@@ -116,6 +118,10 @@ class GenSVM(BaseEstimator, ClassifierMixin):
Type of sample weights to use. Options are 'unit' for unit weights and
'group' for group size correction weights (equation 4 in the paper).
+ It is also possible to provide an explicit vector of sample weights
+ through the :func:`~GenSVM.fit` method. If so, it will override the
+ setting provided here.
+
kernel : string, optional (default='linear')
Specify the kernel type to use in the classifier. It must be one of
'linear', 'poly', 'rbf', or 'sigmoid'.
@@ -242,7 +248,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
self.random_state = random_state
self.max_iter = max_iter
- def fit(self, X, y, seed_V=None):
+ def fit(self, X, y, sample_weight=None, seed_V=None):
"""Fit the GenSVM model on the given data
The model can be fit with or without a seed matrix (``seed_V``). This
@@ -257,6 +263,11 @@ class GenSVM(BaseEstimator, ClassifierMixin):
y : array, shape = (n_observations, )
The label vector, labels can be numbers or strings.
+ sample_weight : array, shape = (n_observations, )
+ Array of weights that are assigned to individual samples. If not
+ provided, then the weight specification in the constructor is used
+ ('unit' or 'group').
+
seed_V : array, shape = (n_features+1, n_classes-1), optional
Seed coefficient array to use as a warm start for the optimization.
It can for instance be the :attr:`combined_coef_
@@ -281,6 +292,21 @@ class GenSVM(BaseEstimator, ClassifierMixin):
X, y_org = check_X_y(
X, y, accept_sparse=False, dtype=np.float64, order="C"
)
+ if not sample_weight is None:
+ sample_weight = check_array(
+ sample_weight,
+ accept_sparse=False,
+ ensure_2d=False,
+ dtype=np.float64,
+ order="C",
+ )
+ if not len(sample_weight) == X.shape[0]:
+ raise ValueError(
+ "sample weight array must have the same number of observations as X"
+ )
+ weights = "raw"
+ else:
+ weights = self.weights
y_type = type_of_target(y_org)
if y_type not in ["binary", "multiclass"]:
@@ -330,7 +356,8 @@ class GenSVM(BaseEstimator, ClassifierMixin):
self.lmd,
self.kappa,
self.epsilon,
- self.weights,
+ weights,
+ sample_weight,
self.kernel,
gamma,
self.coef,
@@ -358,6 +385,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
Returns
-------
y_pred : array, shape = (n_samples, )
+ Predicted class labels of the data in X.
"""
diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd
index 5b097a6..d19d442 100644
--- a/gensvm/cython_wrapper/wrapper.pxd
+++ b/gensvm/cython_wrapper/wrapper.pxd
@@ -115,6 +115,7 @@ cdef extern from "gensvm_helper.c":
double, double, double, double, long, long)
void set_seed_model(GenModel *, double, double, double, double, int, int,
double, double, double, double, long, long, char *, long, long)
+ void set_raw_weights(GenModel *, char *, int)
void set_data(GenData *, char *, char *, np.npy_intp *, long)
void set_task(GenTask *, int, GenData *, int, double, double, double,
double, double, int, double, double, double, long)
diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx
index 5c58fee..f98341b 100644
--- a/gensvm/cython_wrapper/wrapper.pyx
+++ b/gensvm/cython_wrapper/wrapper.pyx
@@ -33,6 +33,7 @@ def train_wrap(
double kappa=0.0,
double epsilon=1e-6,
int weight_idx=1,
+ np.ndarray[np.float64_t, ndim=1, mode='c'] raw_weights=None,
str kernel='linear',
double gamma=1.0,
double coef=0.0,
@@ -74,6 +75,9 @@ def train_wrap(
gensvm_free_model(seed_model)
seed_model = NULL
+ if not raw_weights is None:
+ set_raw_weights(model, raw_weights.data, n_obs)
+
# Check the parameters
error_msg = check_model(model)
if error_msg:
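
The Cython buffer declaration above expects a C-contiguous one-dimensional float64 array of length n_obs; the check_array call in core.py handles the conversion, so any array-like works at the Python level. A short sketch, assuming inverse class-frequency weights on synthetic data:

    import numpy as np
    from gensvm import GenSVM

    rng = np.random.RandomState(0)
    X = rng.randn(200, 5)
    y = rng.randint(0, 3, size=200)

    # Plain Python list of inverse class-frequency weights; fit() converts
    # it to the C-contiguous float64 array the Cython layer expects.
    counts = np.bincount(y)
    sample_weight = [1.0 / counts[label] for label in y]

    clf = GenSVM(weights="unit")  # overridden because sample_weight is given
    clf.fit(X, y, sample_weight=sample_weight)
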