diff options
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/core.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index ce416d0..c35c18e 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -14,7 +14,7 @@ import warnings from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.exceptions import ConvergenceWarning, FitFailedWarning from sklearn.preprocessing import LabelEncoder -from sklearn.utils import check_X_y, check_random_state +from sklearn.utils import check_array, check_X_y, check_random_state from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import check_is_fitted @@ -370,6 +370,13 @@ class GenSVM(BaseEstimator, ClassifierMixin): "Test and training data should have the same number of features" ) + # make sure arrays are C-contiguous + X = check_array(X, accept_sparse=False, dtype=np.float64, order="C") + if not trainX is None: + trainX = check_array( + trainX, accept_sparse=False, dtype=np.float64, order="C" + ) + V = self.combined_coef_ if self.kernel == "linear": predictions = wrapper.predict_wrap(X, V) |
