aboutsummaryrefslogtreecommitdiff
path: root/gensvm
diff options
context:
space:
mode:
authorGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 17:12:02 -0500
committerGertjan van den Burg <gertjanvandenburg@gmail.com>2019-03-06 17:12:02 -0500
commit93e953a930131a00f0329fdbbe8cfa575ca4c2ec (patch)
tree49a38883f5f959547769728fd8fb24b97ac90a6f /gensvm
parentTravis (#4) (diff)
downloadpygensvm-93e953a930131a00f0329fdbbe8cfa575ca4c2ec.tar.gz
pygensvm-93e953a930131a00f0329fdbbe8cfa575ca4c2ec.zip
Check to make sure predict arrays are contiguous
Diffstat (limited to 'gensvm')
-rw-r--r--gensvm/core.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/gensvm/core.py b/gensvm/core.py
index ce416d0..c35c18e 100644
--- a/gensvm/core.py
+++ b/gensvm/core.py
@@ -14,7 +14,7 @@ import warnings
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
from sklearn.preprocessing import LabelEncoder
-from sklearn.utils import check_X_y, check_random_state
+from sklearn.utils import check_array, check_X_y, check_random_state
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import check_is_fitted
@@ -370,6 +370,13 @@ class GenSVM(BaseEstimator, ClassifierMixin):
"Test and training data should have the same number of features"
)
+ # make sure arrays are C-contiguous
+ X = check_array(X, accept_sparse=False, dtype=np.float64, order="C")
+ if not trainX is None:
+ trainX = check_array(
+ trainX, accept_sparse=False, dtype=np.float64, order="C"
+ )
+
V = self.combined_coef_
if self.kernel == "linear":
predictions = wrapper.predict_wrap(X, V)