diff options
| author | Gertjan van den Burg <burg@ese.eur.nl> | 2016-12-05 19:21:59 +0100 |
|---|---|---|
| committer | Gertjan van den Burg <burg@ese.eur.nl> | 2016-12-05 19:21:59 +0100 |
| commit | 3713989f5ea8747c2afca4d35e5f2da746f25b24 (patch) | |
| tree | 902398c9d05c122612bcc76fdd7bc82f04800b6f /tests/aux/test_kernel_cross.m | |
| parent | further unit tests for kernel module (diff) | |
| download | gensvm-3713989f5ea8747c2afca4d35e5f2da746f25b24.tar.gz gensvm-3713989f5ea8747c2afca4d35e5f2da746f25b24.zip | |
add octave testfiles to git
Diffstat (limited to 'tests/aux/test_kernel_cross.m')
| -rw-r--r-- | tests/aux/test_kernel_cross.m | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/tests/aux/test_kernel_cross.m b/tests/aux/test_kernel_cross.m new file mode 100644 index 0000000..5777afd --- /dev/null +++ b/tests/aux/test_kernel_cross.m @@ -0,0 +1,70 @@ + +function test_kernel_cross(kerneltype) + +clc; +rand('state', 123456); +more off; + +n_1 = 10; +n_2 = 5; +m = 3; + +X_1 = rand(n_1, m); +Z_1 = [ones(n_1, 1), X_1]; +X_2 = rand(n_2, m); +Z_2 = [ones(n_2, 1), X_2]; + +for ii=1:n_1 + for jj=1:m+1 + fprintf("matrix_set(data_1->RAW, data_1->m+1, %i, %i, %.16f);\n", ii-1, jj-1, Z_1(ii, jj)); + end +end + +fprintf('\n'); +for ii=1:n_2 + for jj=1:m+1 + fprintf("matrix_set(data_2->RAW, data_2->m+1, %i, %i, %.16f);\n", ii-1, jj-1, Z_2(ii, jj)); + end +end + + +K = zeros(n_2, n_1); +if strcmp(kerneltype, 'poly') + # Polynomial kernel + # (gamma * <x_1, x_2> + c)^d + gamma = 1.5; + c = 3.0; + d = 1.78; + + for ii=1:n_2 + for jj=1:n_1 + K(ii, jj) = (gamma * (X_2(ii, :) * X_1(jj, :)') + c)^d; + end + end +elseif strcmp(kerneltype, 'rbf') + # RBF kernel + # exp(-gamma * norm(x1 - x2)^2) + gamma = 0.348 + for ii=1:n_2 + for jj=1:n_1 + K(ii, jj) = exp(-gamma * sum((X_2(ii, :) - X_1(jj, :)).^2)); + end + end +elseif strcmp(kerneltype, 'sigmoid') + # Sigmoid kernel + # tanh(gamma * <x_1, x_2> + c) + gamma = 1.23; + c = 1.6; + for ii=1:n_2 + for jj=1:n_1 + K(ii, jj) = tanh(gamma * (X_2(ii, :) * X_1(jj, :)') + c); + end + end +end + +fprintf('\n'); +for ii=1:n_2 + for jj=1:n_1 + fprintf("mu_assert(fabs(matrix_get(K2, data_1->n, %i, %i) -\n %.16f) < eps,\n\"Incorrect K2 at %i, %i\");\n", ii-1, jj-1, K(ii, jj), ii-1, jj-1); + end +end
\ No newline at end of file |
