| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 22:28:49 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 22:28:49 -0500 |
| commit | 3ed51bb4ac5db6dfec92d10d6e302381c27849c8 | |
| tree | 24d1df8d8ef4bdf003ca35ae310a5804c28b455e /gensvm | |
| parent | Bugfix and test for predict method of GridSearch | |
Add support for specifying sample weights (fixes #2)
Diffstat (limited to 'gensvm')
| File | Mode | Lines changed |
|---|---|---|
| gensvm/core.py | -rw-r--r-- | 34 |
| gensvm/cython_wrapper/wrapper.pxd | -rw-r--r-- | 1 |
| gensvm/cython_wrapper/wrapper.pyx | -rw-r--r-- | 4 |
3 files changed, 36 insertions, 3 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index c35c18e..dab2368 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -30,6 +30,7 @@ def _fit_gensvm(
     kappa,
     epsilon,
     weights,
+    sample_weight,
     kernel,
     gamma,
     coef,
@@ -48,7 +49,7 @@ def _fit_gensvm(
     wrapper.set_verbosity_wrap(verbose)
 
     # convert the weight index
-    weight_idx = 1 if weights == "unit" else 2
+    weight_idx = {"raw": 0, "unit": 1, "group": 2}[weights]
 
     # run the actual training
     raw_coef_, n_SV_, n_iter_, training_error_, status_ = wrapper.train_wrap(
@@ -60,6 +61,7 @@ def _fit_gensvm(
         kappa,
         epsilon,
         weight_idx,
+        sample_weight,
         kernel,
         gamma,
         coef,
@@ -116,6 +118,10 @@ class GenSVM(BaseEstimator, ClassifierMixin):
         Type of sample weights to use. Options are 'unit' for unit weights
         and 'group' for group size correction weights (equation 4 in the
         paper).
+        It is also possible to provide an explicit vector of sample weights
+        through the :func:`~GenSVM.fit` method. If so, it will override the
+        setting provided here.
+
     kernel : string, optional (default='linear')
         Specify the kernel type to use in the classifier. It must be one of
         'linear', 'poly', 'rbf', or 'sigmoid'.
@@ -242,7 +248,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
         self.random_state = random_state
         self.max_iter = max_iter
 
-    def fit(self, X, y, seed_V=None):
+    def fit(self, X, y, sample_weight=None, seed_V=None):
         """Fit the GenSVM model on the given data
 
         The model can be fit with or without a seed matrix (``seed_V``). This
@@ -257,6 +263,11 @@ class GenSVM(BaseEstimator, ClassifierMixin):
         y : array, shape = (n_observations, )
             The label vector, labels can be numbers or strings.
 
+        sample_weight : array, shape = (n_observations, )
+            Array of weights that are assigned to individual samples. If not
+            provided, then the weight specification in the constructor is used
+            ('unit' or 'group').
+
         seed_V : array, shape = (n_features+1, n_classes-1), optional
             Seed coefficient array to use as a warm start for the
             optimization. It can for instance be the :attr:`combined_coef_
@@ -281,6 +292,21 @@ class GenSVM(BaseEstimator, ClassifierMixin):
         X, y_org = check_X_y(
             X, y, accept_sparse=False, dtype=np.float64, order="C"
         )
+        if not sample_weight is None:
+            sample_weight = check_array(
+                sample_weight,
+                accept_sparse=False,
+                ensure_2d=False,
+                dtype=np.float64,
+                order="C",
+            )
+            if not len(sample_weight) == X.shape[0]:
+                raise ValueError(
+                    "sample weight array must have the same number of observations as X"
+                )
+            weights = "raw"
+        else:
+            weights = self.weights
 
         y_type = type_of_target(y_org)
         if y_type not in ["binary", "multiclass"]:
@@ -330,7 +356,8 @@ class GenSVM(BaseEstimator, ClassifierMixin):
             self.lmd,
             self.kappa,
             self.epsilon,
-            self.weights,
+            weights,
+            sample_weight,
             self.kernel,
             gamma,
             self.coef,
@@ -358,6 +385,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
 
         Returns
         -------
         y_pred : array, shape = (n_samples, )
+            Predicted class labels of the data in X.
 
         """
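For reference, a minimal usage sketch of the new keyword argument added above. The import path and the toy data are assumptions for illustration, not part of the commit:

```python
import numpy as np
from gensvm import GenSVM  # GenSVM is defined in gensvm/core.py

rng = np.random.RandomState(1)
X = rng.randn(30, 4)
y = rng.randint(0, 3, size=30)

# Explicit per-sample weights: a 1-D array with one entry per row of X.
# A length mismatch raises the ValueError added in fit() above.
w = np.ones(X.shape[0])
w[y == 0] = 2.0  # e.g. emphasize one class

clf = GenSVM(weights="unit")
clf.fit(X, y, sample_weight=w)  # explicit weights override the 'unit' setting ("raw")
y_pred = clf.predict(X)

clf.fit(X, y)  # without sample_weight, the constructor's 'unit'/'group' choice applies as before
```

Passing the weights through ``fit`` rather than the constructor follows the scikit-learn convention of treating ``sample_weight`` as data-dependent input.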
""" diff --git a/gensvm/cython_wrapper/wrapper.pxd b/gensvm/cython_wrapper/wrapper.pxd index 5b097a6..d19d442 100644 --- a/gensvm/cython_wrapper/wrapper.pxd +++ b/gensvm/cython_wrapper/wrapper.pxd @@ -115,6 +115,7 @@ cdef extern from "gensvm_helper.c": double, double, double, double, long, long) void set_seed_model(GenModel *, double, double, double, double, int, int, double, double, double, double, long, long, char *, long, long) + void set_raw_weights(GenModel *, char *, int) void set_data(GenData *, char *, char *, np.npy_intp *, long) void set_task(GenTask *, int, GenData *, int, double, double, double, double, double, int, double, double, double, long) diff --git a/gensvm/cython_wrapper/wrapper.pyx b/gensvm/cython_wrapper/wrapper.pyx index 5c58fee..f98341b 100644 --- a/gensvm/cython_wrapper/wrapper.pyx +++ b/gensvm/cython_wrapper/wrapper.pyx @@ -33,6 +33,7 @@ def train_wrap( double kappa=0.0, double epsilon=1e-6, int weight_idx=1, + np.ndarray[np.float64_t, ndim=1, mode='c'] raw_weights=None, str kernel='linear', double gamma=1.0, double coef=0.0, @@ -74,6 +75,9 @@ def train_wrap( gensvm_free_model(seed_model) seed_model = NULL + if not raw_weights is None: + set_raw_weights(model, raw_weights.data, n_obs) + # Check the parameters error_msg = check_model(model) if error_msg: |
