aboutsummaryrefslogtreecommitdiff
path: root/gensvm/core.py
diff options
context:
space:
mode:
Diffstat (limited to 'gensvm/core.py')
-rw-r--r--gensvm/core.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index f0e7820..bfd5d9a 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -52,7 +52,7 @@ def _fit_gensvm(
weight_idx = {"raw": 0, "unit": 1, "group": 2}[weights]
# run the actual training
- raw_coef_, n_SV_, n_iter_, training_error_, status_ = wrapper.train_wrap(
+ raw_coef_, SVs_, n_iter_, training_error_, status_ = wrapper.train_wrap(
X,
y,
n_class,
@@ -91,8 +91,9 @@ def _fit_gensvm(
coef_ = raw_coef_[1:, :]
intercept_ = raw_coef_[0, :]
+ n_SV_ = sum(SVs_)
- return coef_, intercept_, n_iter_, n_SV_
+ return coef_, intercept_, n_iter_, n_SV_, SVs_
class GenSVM(BaseEstimator, ClassifierMixin):
@@ -177,6 +178,8 @@ class GenSVM(BaseEstimator, ClassifierMixin):
n_support_ : int
The number of support vectors that were found
+ SVs_ : array, shape = [n_observations, ]
+ Index vector that marks the support vectors (1 = SV, 0 = no SV)
See Also
--------
@@ -351,7 +354,7 @@ class GenSVM(BaseEstimator, ClassifierMixin):
)
)
- self.coef_, self.intercept_, self.n_iter_, self.n_support_ = _fit_gensvm(
+ self.coef_, self.intercept_, self.n_iter_, self.n_support_, self.SVs_ = _fit_gensvm(
X,
y,
n_class,