% Start from an empty workspace and no open figures.
clear all
close all

%% PARAMETERS
% Sizes of the data set and position of the model switch.
N       = 1500;   % number of training data points
N_test  = 500;    % number of test data points
Nswitch = 500;    % training samples 1..Nswitch use model 1, the rest model 2

% Linear filters of the two models (row vectors, one tap per embedding lag).
B1 = [1, 0.0668, -0.4764, 0.8070];   % model 1 linear filter
B2 = [1, -0.4326, 0.6656, -0.7153];  % model 2 linear filter

% Static nonlinearity applied to the filter output (Wiener system).
f = @tanh;

% SNR in dB.
% NOTE(review): this is practically 0 dB; with the noise variance computed
% as 10^(-SNR/10)*var(Y_nn) the noise power is almost equal to the signal
% power. Confirm that a value such as 20 dB was not intended.
SNR = 2e-7;

embedding = 4;    % time-embedding (number of delayed input samples per row)

% Algorithms to compare, one per cell. rls is defined outside this script;
% presumably 'c' and 'lambda' are its regularization and forgetting factor
% -- TODO confirm against the rls class.
setups = {rls(struct('c', 0.9, 'lambda', 0.99))};
%%%%%%%%%%%%%% PREPARE DATA %%%%%%%%%%%%%%%%%%
% Builds the time-embedded input matrix, the noisy training output of a
% Wiener system whose linear filter switches from B1 to B2 after Nswitch
% samples, and one noisy test output per model (same test inputs for both).

% The filters are applied as s_mem*B', so each filter needs exactly one
% coefficient per embedding column; the switch must lie inside the
% training range so that X1 and X2 together have N rows.
assert(numel(B1) == embedding && numel(B2) == embedding, ...
    'B1 and B2 must each have "embedding" coefficients.');
assert(Nswitch >= 0 && Nswitch <= N, ...
    'Nswitch must lie between 0 and N.');

% generate Gaussian input data s (randn: zero-mean, unit-variance normal;
% rand would produce uniform samples on [0,1] instead)
s = randn(N+N_test,1);
s_mem = zeros(N+N_test,embedding);
for i = 1:embedding
    % column i holds s delayed by i-1 samples; the first i-1 rows stay zero
    s_mem(i:N+N_test,i) = s(1:N+N_test-i+1); % time-embedding
end
s = s_mem(1:N+N_test,:); % input data, one embedded sample per row
s_train = s_mem(1:N,:); % input train data, one embedded sample per row
s_test = s_mem(N+1:N+N_test,:); % input test data, one embedded sample per row

% generate internal data x and output y
X1 = s_mem(1:Nswitch,:)*B1';   % filter output while model 1 is active
X2 = s_mem(Nswitch+1:N,:)*B2'; % filter output after the switch to model 2
X = [X1;X2];
Y_nn = f(X); % noiseless Y

% noise variance chosen relative to the variance of the noiseless output
vary = var(Y_nn);
noisevar = 10^(-SNR/10)*vary;
noise = sqrt(noisevar)*randn(N,1);
Y = Y_nn + noise; % noisy output data

% test outputs: the same test inputs passed through each model separately
X_test1 = s_mem(N+1:N+N_test,:)*B1';
X_test2 = s_mem(N+1:N+N_test,:)*B2';
noise_test1 = sqrt(noisevar)*randn(N_test,1);
noise_test2 = sqrt(noisevar)*randn(N_test,1);
Y_test1 = f(X_test1) + noise_test1; % noisy output test data, model 1
Y_test2 = f(X_test2) + noise_test2; % noisy output test data, model 2
%% RUN ALGORITHMS
% For every configured algorithm: before each training step, evaluate the
% current model on the whole test set and store the mean squared error,
% then train on the n-th input-output pair.
num_setup = length(setups);
MSE = zeros(N,num_setup);   % row n: test MSE before the n-th training step

for setup_ind = 1:num_setup
    kaf = setups{setup_ind};

    for n = 1:N
        % reference output comes from the model that generated sample n
        if n > Nswitch
            Y_test = Y_test2;
        else
            Y_test = Y_test1;
        end

        % test on test set
        Y_est = kaf.evaluate(s_test);
        err = Y_test - Y_est;
        MSE(n,setup_ind) = mean(err.^2);

        % train with one input-output pair
        kaf = kaf.train(s_train(n,:),Y(n));
    end
end
%%%%%%%%%%%%%%%% OUTPUT %%%%%%%%%%%%%%%%%%%%%%%
% Plot the test MSE in dB against the training iteration; plot draws one
% curve per column of MSE.
figure
plot(10*log10(MSE),'LineWidth',1);
xlabel('Number of Iteration','fontsize',18);
ylabel('Mean Square Error (dB)','fontsize',18);
% Keep the handle returned when the legend is created instead of calling
% legend a second time without arguments to look it up on the current axes.
h_legend = legend('RLS');
set(h_legend,'fontsize',22);