diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-05-30 17:31:17 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-05-30 17:31:17 +0100 |
| commit | 9ad95f7075cf6e03b27207f004fce36909ebbb53 (patch) | |
| tree | 27b6dd7425a19c1cc2d99dc046f46cb9ec55863c /gensvm/core.py | |
| parent | use green for testing instead of nose (diff) | |
| download | pygensvm-9ad95f7075cf6e03b27207f004fce36909ebbb53.tar.gz pygensvm-9ad95f7075cf6e03b27207f004fce36909ebbb53.zip | |
Add support for returning the support vectors (#5)
Diffstat (limited to 'gensvm/core.py')
| -rw-r--r-- | gensvm/core.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index f0e7820..bfd5d9a 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -52,7 +52,7 @@ def _fit_gensvm( weight_idx = {"raw": 0, "unit": 1, "group": 2}[weights] # run the actual training - raw_coef_, n_SV_, n_iter_, training_error_, status_ = wrapper.train_wrap( + raw_coef_, SVs_, n_iter_, training_error_, status_ = wrapper.train_wrap( X, y, n_class, @@ -91,8 +91,9 @@ def _fit_gensvm( coef_ = raw_coef_[1:, :] intercept_ = raw_coef_[0, :] + n_SV_ = sum(SVs_) - return coef_, intercept_, n_iter_, n_SV_ + return coef_, intercept_, n_iter_, n_SV_, SVs_ class GenSVM(BaseEstimator, ClassifierMixin): @@ -177,6 +178,8 @@ class GenSVM(BaseEstimator, ClassifierMixin): n_support_ : int The number of support vectors that were found + SVs_ : array, shape = [n_observations, ] + Index vector that marks the support vectors (1 = SV, 0 = no SV) See Also -------- @@ -351,7 +354,7 @@ class GenSVM(BaseEstimator, ClassifierMixin): ) ) - self.coef_, self.intercept_, self.n_iter_, self.n_support_ = _fit_gensvm( + self.coef_, self.intercept_, self.n_iter_, self.n_support_, self.SVs_ = _fit_gensvm( X, y, n_class, |
