Decoding of stimuli from neural data

Here we will simulate neural data given a ground-truth encoding model and try to decode the stimulus from the data.
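
Throughout, the encoding model is a one-dimensional Gaussian PRF: each unit responds to a stimulus value with a Gaussian tuning curve. As a minimal sketch of the assumed response function (braincoder's GaussianPRF should implement something equivalent, but check the library's documentation):

# Assumed form of the Gaussian PRF response: a Gaussian bump around the
# unit's preferred stimulus value mu, scaled and shifted per unit
import numpy as np

def gaussian_prf_response(s, mu, sd, amplitude, baseline):
    return baseline + amplitude * np.exp(-0.5 * ((s - mu) / sd) ** 2)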

# Set up a neural model
from braincoder.models import GaussianPRF
import numpy as np
import pandas as pd
import scipy.stats as ss

# Set up 20 sets of random PRF parameters
n = 20
n_trials = 50
noise = 1.

mu = np.random.rand(n) * 100
sd = np.random.rand(n) * 45 + 5
amplitude = np.random.rand(n) * 5
baseline = np.random.rand(n) * 2 - 1

parameters = pd.DataFrame({'mu':mu, 'sd':sd, 'amplitude':amplitude, 'baseline':baseline})

# The paradigm consists of random integer-valued stimuli between 1 and 100
paradigm = np.ceil(np.random.rand(n_trials) * 100)

model = GaussianPRF(parameters=parameters)
data = model.simulate(paradigm=paradigm, noise=noise)
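
The simulated data form a DataFrame with, as far as the layout goes, one row per trial and one column per simulated unit; a quick sanity check (the expected shape is an assumption based on the setup above):

# Quick sanity check on the simulated data's layout (assumed: rows are
# trials, columns are the n simulated units)
print(data.shape)  # expected: (n_trials, n) == (50, 20)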
# Now we fit back the PRF parameters
from braincoder.optimize import ParameterFitter, ResidualFitter
fitter = ParameterFitter(model, data, paradigm)
mu_grid = np.arange(0, 100, 5)
sd_grid = np.arange(5, 50, 5)

grid_pars = fitter.fit_grid(mu_grid, sd_grid, [1.0], [0.0], use_correlation_cost=True, progressbar=False)
grid_pars = fitter.refine_baseline_and_amplitude(grid_pars)

for par in ['mu', 'sd', 'amplitude', 'baseline']:
    print(f'Correlation between grid-fitted *{par}* and ground truth: {ss.pearsonr(grid_pars[par], parameters[par])[0]:0.2f}')

gd_pars = fitter.fit(init_pars=grid_pars, progressbar=False)

for par in ['mu', 'sd', 'amplitude', 'baseline']:
    print(f'Correlation between gradient descent-fitted *{par}* and ground truth: {ss.pearsonr(gd_pars[par], parameters[par])[0]:0.2f}')
Working with chunk size of 666666
Using correlation cost!

Correlation between grid-fitted *mu* and ground truth: 0.96
Correlation between grid-fitted *sd* and ground truth: 0.67
Correlation between grid-fitted *amplitude* and ground truth: 0.90
Correlation between grid-fitted *baseline* and ground truth: 0.72
Number of problematic voxels (mask): 0
Number of voxels remaining (mask): 20
Correlation between gradient descent-fitted *mu* and ground truth: 0.96
Correlation between gradient descent-fitted *sd* and ground truth: 0.67
Correlation between gradient descent-fitted *amplitude* and ground truth: 0.90
Correlation between gradient descent-fitted *baseline* and ground truth: 0.72
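
The grid stage works because the correlation cost is invariant to each unit's amplitude and baseline: only mu and sd need to be searched, and refine_baseline_and_amplitude recovers the remaining two parameters afterwards. As a rough illustration of the idea (a hypothetical helper, not braincoder's API):

# Hypothetical illustration of a correlation cost for one unit and one
# candidate (mu, sd); correlation ignores amplitude and baseline, so a
# unit-amplitude, zero-baseline prediction suffices at this stage
def correlation_cost(unit_data, mu, sd, paradigm):
    prediction = np.exp(-0.5 * ((paradigm - mu) / sd) ** 2)
    return 1 - np.corrcoef(prediction, unit_data)[0, 1]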
# Now we fit the noise covariance matrix (omega) of the residuals
stimulus_range = np.arange(1, 101).astype(np.float32)  # candidate stimulus values covering the full 1-100 range

model.init_pseudoWWT(stimulus_range=stimulus_range, parameters=gd_pars)
resid_fitter = ResidualFitter(model, data, paradigm, gd_pars)
omega, dof = resid_fitter.fit(progressbar=False)
init_tau: 0.7828128933906555, 1.15021812915802
USING A PSEUDO-WWT!
WWT max: 1122.9664306640625
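
omega is the fitted covariance of the noise across units (and dof the degrees of freedom of the multivariate t noise model). As a rough sanity check, omega should resemble the empirical covariance of the training residuals; this sketch assumes the model exposes a predict(paradigm=..., parameters=...) method, so adjust to the actual API if needed:

# Rough sanity check (assumed API): compare the fitted omega to the
# empirical covariance of the residuals around the model's predictions
predictions = model.predict(paradigm=paradigm, parameters=gd_pars)
residuals = data - predictions.values
empirical_cov = np.cov(residuals.values.T)
print(np.corrcoef(np.asarray(omega).ravel(), empirical_cov.ravel())[0, 1])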
# Now we simulate unseen test data:
test_paradigm = np.ceil(np.random.rand(n_trials) * 100)
test_data = model.simulate(paradigm=test_paradigm, noise=noise)

# And decode the test paradigm
posterior = model.get_stimulus_pdf(test_data, stimulus_range, model.parameters, omega=omega, dof=dof)
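
Conceptually, get_stimulus_pdf inverts the encoding model with Bayes' rule: for every trial it evaluates the likelihood of the measured pattern under the fitted noise model at each candidate stimulus, then normalizes over the stimulus grid. A minimal sketch of this idea, assuming a flat stimulus prior and a multivariate t noise model (predict_pattern is a hypothetical helper returning the predicted pattern for stimulus s; this is not braincoder's actual implementation):

# Minimal sketch of Bayesian decoding for a single trial (hypothetical
# helper names, flat prior over the stimulus grid)
def decode_trial(pattern, stimulus_range, predict_pattern, omega, dof):
    log_lik = np.array([ss.multivariate_t(loc=predict_pattern(s), shape=omega, df=dof).logpdf(pattern)
                        for s in stimulus_range])
    p = np.exp(log_lik - log_lik.max())     # subtract max for numerical stability
    return p / np.trapz(p, stimulus_range)  # normalize to a proper density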
# Finally, we make some plots to see how well the decoder did
import matplotlib.pyplot as plt
import seaborn as sns

# Select the posteriors of the first nine test trials, indexed by their ground-truth stimulus
tmp = posterior.set_index(pd.Series(test_paradigm, name='ground truth'), append=True).loc[:8].stack().to_frame('p')

g = sns.FacetGrid(tmp.reset_index(), col='frame', col_wrap=3)

g.map(plt.plot, 'stimulus', 'p', color='k')

def plot_ground_truth(data, **kwargs):
    # Each facet holds a single ground-truth value; mark it with a dashed line
    plt.axvline(data.mean(), c='k', ls='--', **kwargs)
g.map(plot_ground_truth, 'ground truth')
g.set(xlabel='Stimulus value', ylabel='Posterior probability density')
[Figure: posterior densities over the stimulus for the first nine test trials (panels frame = 0 through frame = 8); the dashed vertical line marks the ground-truth stimulus.]
# Let's look at the summary statistics of the posteriors
def get_posterior_stats(posterior, normalize=True):
    posterior = posterior.copy()
    if normalize:
        posterior = posterior.div(np.trapz(posterior, posterior.columns, axis=1), axis=0)

    # Integrate over the posterior to get the expectation (posterior mean)
    E = np.trapz(posterior * posterior.columns.values[np.newaxis, :], posterior.columns, axis=1)

    # Integrate the absolute distance to the posterior mean to get a measure of
    # the posterior's spread (the expected absolute deviation, labeled 'sd' here)
    sd = np.trapz(np.abs(E[:, np.newaxis] - posterior.columns.astype(float).values[np.newaxis, :]) * posterior, posterior.columns, axis=1)

    stats = pd.DataFrame({'E': E, 'sd': sd}, index=posterior.index)
    return stats

posterior_stats = get_posterior_stats(posterior)

# Let's see how far the posterior mean is from the ground truth
plt.errorbar(test_paradigm, posterior_stats['E'], posterior_stats['sd'], fmt='o')
plt.plot([0, 100], [0,100], c='k', ls='--')

plt.xlabel('Ground truth')
plt.ylabel('Mean posterior')

# Let's see how the absolute error depends on the spread (sd) of the posterior
error = test_paradigm - posterior_stats['E']
error_abs = np.abs(error)
error_abs.name = 'error'

sns.lmplot(x='sd', y='error', data=posterior_stats.join(error_abs))

plt.xlabel('Standard deviation of posterior')
plt.ylabel('Absolute error')
[Figure: absolute decoding error versus posterior spread, with a linear regression fit.]
# Now, let's try to find the MAP estimate using gradient descent
from braincoder.optimize import StimulusFitter
stimulus_fitter = StimulusFitter(model=model, data=test_data, omega=omega)

# We start with a very coarse grid search, so we are sure we are in the right ballpark
estimated_stimuli_grid = stimulus_fitter.fit_grid(np.arange(1, 100, 5))
# We can then refine the estimate using gradient descent
estimated_stimuli_gd = stimulus_fitter.fit(init_pars=estimated_stimuli_grid, progressbar=False)

# Let's see how well we did
plt.scatter(test_paradigm, estimated_stimuli_grid, alpha=.5, label='MAP (grid search)')
plt.scatter(test_paradigm, estimated_stimuli_gd, alpha=.5, label='MAP (gradient descent)')
plt.scatter(test_paradigm, posterior_stats['E'], alpha=.5, label='Mean posterior')
plt.plot([0, 100], [0,100], c='k', ls='--', label='Identity line')
plt.legend()
[Figure: decoded stimulus estimates versus the ground-truth test paradigm for the grid-search MAP, gradient-descent MAP, and posterior-mean estimators, with the identity line.]
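
To put numbers on the comparison, one could summarize each estimator's decoding error, for instance as a root-mean-square error (a hypothetical follow-up, not part of the original script):

# Hypothetical follow-up: RMSE of each stimulus estimator against the
# ground-truth test paradigm
for name, est in [('MAP (grid search)', estimated_stimuli_grid),
                  ('MAP (gradient descent)', estimated_stimuli_gd),
                  ('Mean posterior', posterior_stats['E'])]:
    rmse = np.sqrt(np.mean((np.asarray(est).squeeze() - test_paradigm) ** 2))
    print(f'RMSE {name}: {rmse:.2f}')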
