diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 17:12:02 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 17:12:02 -0500 |
| commit | 93e953a930131a00f0329fdbbe8cfa575ca4c2ec (patch) | |
| tree | 49a38883f5f959547769728fd8fb24b97ac90a6f /gensvm | |
| parent | Travis (#4) (diff) | |
| download | pygensvm-93e953a930131a00f0329fdbbe8cfa575ca4c2ec.tar.gz pygensvm-93e953a930131a00f0329fdbbe8cfa575ca4c2ec.zip | |
Check to make sure predict arrays are contiguous
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/core.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/gensvm/core.py b/gensvm/core.py index ce416d0..c35c18e 100644 --- a/gensvm/core.py +++ b/gensvm/core.py @@ -14,7 +14,7 @@ import warnings from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.exceptions import ConvergenceWarning, FitFailedWarning from sklearn.preprocessing import LabelEncoder -from sklearn.utils import check_X_y, check_random_state +from sklearn.utils import check_array, check_X_y, check_random_state from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import check_is_fitted @@ -370,6 +370,13 @@ class GenSVM(BaseEstimator, ClassifierMixin): "Test and training data should have the same number of features" ) + # make sure arrays are C-contiguous + X = check_array(X, accept_sparse=False, dtype=np.float64, order="C") + if not trainX is None: + trainX = check_array( + trainX, accept_sparse=False, dtype=np.float64, order="C" + ) + V = self.combined_coef_ if self.kernel == "linear": predictions = wrapper.predict_wrap(X, V) |
