From 47116a4682edb1f22d00da06802cc3eff40bf5bd Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Thu, 30 May 2019 18:39:05 +0100 Subject: Update documentation --- docs/cls_gridsearch.rst | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) (limited to 'docs/cls_gridsearch.rst') diff --git a/docs/cls_gridsearch.rst b/docs/cls_gridsearch.rst index 8708123..6a2c05e 100644 --- a/docs/cls_gridsearch.rst +++ b/docs/cls_gridsearch.rst @@ -1,5 +1,5 @@ -.. py:class:: GenSVMGridSearchCV(param_grid, scoring=None, iid=True, cv=None, refit=True, verbose=0, return_train_score=True) +.. py:class:: GenSVMGridSearchCV(param_grid='tiny', scoring=None, iid=True, cv=None, refit=True, verbose=0, return_train_score=True) :noindex: :module: gensvm.gridsearch @@ -17,10 +17,15 @@ was needed to benefit from the fast low-level C implementation of grid search in the GenSVM library. - :param param_grid: Dictionary of parameter names (strings) as keys and lists of parameter - settings to evaluate as values, or a list of such dicts. The GenSVM - model will be evaluated at all combinations of the parameters. - :type param_grid: dict or list of dicts + :param param_grid: If a string, it must be either 'tiny', 'small', or 'full' to load the + predefined parameter grids (see the functions :func:`load_grid_tiny`, + :func:`load_grid_small`, and :func:`load_grid_full`). + + Otherwise, a dictionary of parameter names (strings) as keys and lists + of parameter settings to evaluate as values, or a list of such dicts. + The GenSVM model will be evaluated at all combinations of the + parameters. + :type param_grid: string, dict, or list of dicts :param scoring: A single string (see :ref:`scoring_parameter`) or a callable (see :ref:`scoring`) to evaluate the predictions on the test set. @@ -40,7 +45,7 @@ :param cv: Determines the cross-validation splitting strategy. Possible inputs for cv are: - - None, to use the default 3-fold cross validation, + - None, to use the default 5-fold cross validation, - integer, to specify the number of folds in a `(Stratified)KFold`, - An object to be used as a cross-validation generator. - An iterable yielding train, test splits. @@ -51,6 +56,12 @@ Refer to the `scikit-learn User Guide on cross validation`_ for the various strategies that can be used here. + + NOTE: At the moment, the ShuffleSplit and StratifiedShuffleSplit are + not supported in this class. If you need these, you can use the GenSVM + classifier directly with the GridSearchCV object from scikit-learn. + (these methods require significant changes in the low-level library + before they can be supported). :type cv: int, cross-validation generator or an iterable, optional :param refit: Refit the GenSVM estimator with the best found parameters on the whole dataset. @@ -240,7 +251,7 @@ :rtype: object - .. py:method:: GenSVMGridSearchCV.predict(X) + .. py:method:: GenSVMGridSearchCV.predict(X, trainX=None) :noindex: :module: gensvm.gridsearch @@ -249,6 +260,9 @@ :param X: Test data, where n_samples is the number of observations and n_features is the number of features. :type X: array-like, shape = (n_samples, n_features) + :param trainX: Only for nonlinear prediction with kernels: the training data used + to train the model. + :type trainX: array, shape = [n_train_samples, n_features] :returns: **y_pred** -- Predicted class labels of the data in X. :rtype: array-like, shape = (n_samples, ) -- cgit v1.2.3