Lesson 7: Fitting PRF models to visual space

"""
Different flavors of visual population receptive field models
==============================================================

In this example script we will try out increasingly complex models for
visual population receptive fields (PRFs). We will start with a simple
Gaussian PRF model, and then add more complexity step by step.

"""

# %%
# Load data
# ---------
# First we load the data. We will use the Szinte (2024) dataset.
from braincoder.utils.data import load_szinte2024
import matplotlib.pyplot as plt

data = load_szinte2024()

# This is the visual stimulus ("design matrix")
paradigm = data['stimulus']
grid_coordinates = data['grid_coordinates']

# This is the fMRI response data
d = data['v1_timeseries']
d.index.name = 'frame'
tr = data['tr']
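# %%
# A quick sanity check on what we just loaded (a sketch; the exact shapes
# depend on the dataset version): `paradigm` is a frame-by-pixel design
# matrix, `grid_coordinates` gives the visual-field position of every pixel,
# and `d` contains one fMRI time series per voxel.
print('Stimulus design matrix:', paradigm.shape)
print('Grid coordinates:', grid_coordinates.shape)
print('fMRI data (frames x voxels):', d.shape)
print('TR (seconds):', tr)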


# %%
# Simple 2D Gaussian Receptive Field model
# ----------------------------------------
# Now we set up a simple Gaussian PRF model
from braincoder.models import GaussianPRF2DWithHRF
from braincoder.hrf import SPMHRFModel
hrf_model = SPMHRFModel(tr=tr)
model_gauss = GaussianPRF2DWithHRF(data=d, paradigm=paradigm, hrf_model=hrf_model, grid_coordinates=grid_coordinates)
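# %%
# Before fitting anything, we can sanity-check the model by generating a
# predicted time course for a single, hand-picked PRF. The parameter names
# below (x, y, sd, baseline, amplitude) are the ones used in the grid search
# further down; if your braincoder version uses different labels, adjust
# accordingly.
import pandas as pd
example_pars = pd.DataFrame({'x': [0.0], 'y': [0.0], 'sd': [1.0],
                             'baseline': [0.0], 'amplitude': [1.0]})
model_gauss.predict(parameters=example_pars).plot()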

# %%
# And a parameter fitter...
from braincoder.optimize import ParameterFitter
par_fitter = ParameterFitter(model=model_gauss, data=d, paradigm=paradigm)


# %%
# Now we run a relatively coarse grid search to find good starting
# parameters for the subsequent gradient descent.

import numpy as np
x = np.linspace(-8, 8, 10)
y = np.linspace(-4, 4, 10)
sd = np.linspace(0.1, 4, 10)

# We start with a grid search using a correlation cost, so the amplitude
# and baseline parameters do not influence the results.
# We will optimize them afterwards using OLS.
baseline = [0.0]
amplitude = [1.0]

# Now we can do the grid search
pars_gauss_grid = par_fitter.fit_grid(x, y, sd, baseline, amplitude, correlation_cost=True)

# And refine the baseline and amplitude parameters using OLS
pars_gauss_ols = par_fitter.refine_baseline_and_amplitude(pars_gauss_grid)
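# %%
# Note that the grid search is exhaustive: it evaluates every combination of
# the candidate values above, i.e. 10 x 10 x 10 = 1,000 candidate PRFs per
# voxel, and keeps the best-correlating one.
print(f'Number of candidate PRFs per voxel: {len(x) * len(y) * len(sd)}')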


# %%
# Here we plot the resulting R2 values of the grid search (after the OLS
# refinement of baseline and amplitude).
r2_gauss_ols = par_fitter.get_rsq(pars_gauss_ols)

import seaborn as sns
sns.kdeplot(r2_gauss_ols, fill=True)
sns.despine()
# %%
# We can substantially improve the fit by using gradient descent optimization.
pars_gauss_gd = par_fitter.fit(init_pars=pars_gauss_ols, max_n_iterations=1000)

# %%
r2_gauss_gd = par_fitter.get_rsq(pars_gauss_gd)
sns.kdeplot(r2_gauss_gd, fill=True)

import pandas as pd
r2 = pd.concat((r2_gauss_ols, r2_gauss_gd), keys=['r2_ols', 'r2_gd'], axis=1)

# %%
# Clearly, the gradient descent optimization improves the fit substantially.
sns.relplot(x='r2_ols', y='r2_gd', data=r2.reset_index(), kind='scatter')
plt.plot([0, 1], [0, 1], 'k--')
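# %%
# The same comparison in numbers: the median R2 before and after gradient
# descent.
print(r2.median())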
# %%
# Fit HRFs
# --------
# The canonical (SPM) HRF we used above is often not a great fit to actual
# data. To better account for this, we can optimize the HRF parameters per voxel.
# We first initialize a GaussianPRF-model with a flexible HRF.
model_hrf = GaussianPRF2DWithHRF(data=d, paradigm=paradigm, hrf_model=hrf_model,
                             grid_coordinates=grid_coordinates, flexible_hrf_parameters=True)

par_fitter_hrf = ParameterFitter(model=model_hrf, data=d, paradigm=paradigm)

# We initialize hrf_delay and hrf_dispersion at the standard SPM values
pars_gauss_gd['hrf_delay'] = 6
pars_gauss_gd['hrf_dispersion'] = 1

pars_gauss_hrf = par_fitter_hrf.fit(init_pars=pars_gauss_gd, max_n_iterations=1000)

# %%
r2_gauss_hrf = par_fitter_hrf.get_rsq(pars_gauss_hrf)

r2 = pd.concat((r2_gauss_gd, r2_gauss_hrf), keys=['r2_gd', 'r2_hrf'], axis=1)
sns.relplot(x='r2_gd', y='r2_hrf', data=r2.reset_index(), kind='scatter')
plt.plot([0, 1], [0, 1], 'k--')

# %%
# Here we plot the predicted time courses of the original model
# and the model with the optimized HRFs for the 9 voxels where the fit
# improved the most. You can clearly see that, in general, the fitted
# HRFs have shorter delays than the default setting.
improvement = r2_gauss_hrf - r2_gauss_gd
largest_improvements = improvement.sort_values(ascending=False).index[:9]
pred_gauss_gd = model_gauss.predict(parameters=pars_gauss_gd)
pred_gauss_hrf = model_hrf.predict(parameters=pars_gauss_hrf)
pred = pd.concat((d.loc[:, largest_improvements], pred_gauss_gd.loc[:, largest_improvements], pred_gauss_hrf.loc[:, largest_improvements]), axis=1, keys=['data', 'gauss', 'gauss+hrf'], names=['model'])

tmp = pred.stack(['model', 'source']).to_frame('value')
sns.relplot(x='frame', y='value', hue='model', col='source', data=tmp.reset_index(), kind='line', col_wrap=3)
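# %%
# We can back up the claim about shorter delays numerically by looking at the
# distribution of the fitted HRF delays (we initialized them at 6 seconds).
print(pars_gauss_hrf['hrf_delay'].describe())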


# %%
# Fit a Difference of Gaussians model
# -----------------------------------
# Now we will try to fit a Difference of Gaussians model. This model
# has two Gaussian receptive fields, one excitatory and one inhibitory.
# The inhibitory receptive field is subtracted from the excitatory one.
# The resulting receptive field is then convolved with the HRF.
from braincoder.models import DifferenceOfGaussiansPRF2DWithHRF
model_dog = DifferenceOfGaussiansPRF2DWithHRF(data=d, paradigm=paradigm, hrf_model=hrf_model,
                                         grid_coordinates=grid_coordinates, flexible_hrf_parameters=True)

pars_dog_init = pars_gauss_hrf.copy()
# This is the relative amplitude of the inhibitory receptive field
# compared to the excitatory one.
pars_dog_init['srf_amplitude'] = 0.1

# This is the relative size of the inhibitory receptive field
# compared to the excitatory one.
pars_dog_init['srf_size'] = 2.
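# %%
# To build some intuition for these two parameters, here is a purely
# illustrative 1D sketch (not part of the actual fit), using the initial
# values we just set (srf_amplitude = 0.1, srf_size = 2): a centre Gaussian
# minus a wider, weaker surround Gaussian yields the classic
# centre-surround profile.
xx = np.linspace(-10, 10, 201)
centre = np.exp(-xx**2 / (2 * 1.0**2))
surround = 0.1 * np.exp(-xx**2 / (2 * (2. * 1.0)**2))
plt.figure()
plt.plot(xx, centre - surround)
plt.title('Illustrative 1D difference-of-Gaussians profile')
sns.despine()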

# Let's set up a new parameter fitter.
par_fitter_dog = ParameterFitter(model=model_dog, data=d, paradigm=paradigm)

# Note how, for now, we are not optimizing the HRF parameters.
pars_dog = par_fitter_dog.fit(init_pars=pars_dog_init, max_n_iterations=1000,
                              fixed_pars=['hrf_delay', 'hrf_dispersion'])

# Now we optimize _with_ the HRF parameters
pars_dog_hrf = par_fitter_dog.fit(init_pars=pars_dog, max_n_iterations=1000)

r2_dog_hrf = par_fitter_dog.get_rsq(pars_dog_hrf)

sns.relplot(x='r2_hrf', y='r2_dog_hrf', data=pd.concat((r2_gauss_hrf, r2_dog_hrf), axis=1,
                                                       keys=['r2_hrf', 'r2_dog_hrf']).reset_index(), kind='scatter')
# %%
# Here, we plot the predicted time courses of the difference-of-Gaussians
# model versus the original Gaussian model for the 9 voxels where the fit
# improved the most.
improvement = r2_dog_hrf - r2_gauss_hrf
largest_improvements = improvement.sort_values(ascending=False).index[:9]
pred_dog_hrf = model_dog.predict(parameters=pars_dog_hrf)
pred = pd.concat((d.loc[:, largest_improvements], pred_gauss_hrf.loc[:, largest_improvements], pred_dog_hrf.loc[:, largest_improvements]), axis=1, keys=['data', 'gauss+hrf', 'dog+hrf'], names=['model'])

tmp = pred.stack(['model', 'source']).to_frame('value')
sns.relplot(x='frame', y='value', hue='model', col='source', data=tmp.reset_index(), kind='line', col_wrap=3,
            palette=['k'] + sns.color_palette(),
            hue_order=['data', 'gauss+hrf', 'dog+hrf'])


# %%
# Divisive Normalization PRF model
# --------------------------------
# The most complex model we will fit is the DN-PRF model (Aqil et al., 2021).
# This model has a Gaussian excitatory receptive field and a Gaussian
# suppressive (inhibitory) receptive field. The excitatory drive, plus a
# neural baseline, is divided by the suppressive drive, plus a surround
# baseline.
# The resulting receptive field is then convolved with the HRF.
from braincoder.models import DivisiveNormalizationGaussianPRF2DWithHRF
model_dn = DivisiveNormalizationGaussianPRF2DWithHRF(data=d,
                                              paradigm=paradigm,
                                              hrf_model=hrf_model,
                                              grid_coordinates=grid_coordinates,
                                              flexible_hrf_parameters=True)

pars_dn_init = pars_dog_hrf.copy()
pars_dn_init['srf_amplitude'] = 0.01
pars_dn_init['srf_size'] = 2.
pars_dn_init['baseline'] = 0.0
pars_dn_init['neural_baseline'] = 1.0
pars_dn_init['surround_baseline'] = 1.0
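# %%
# To make the role of these parameters concrete, here is an illustrative
# scalar version of the divisive-normalization equation, roughly following
# the form in Aqil et al. (2021). This is a sketch of the equation's shape,
# not the exact braincoder implementation.
def dn_response_sketch(prf_drive, srf_drive, amplitude=1.0, srf_amplitude=0.01,
                       neural_baseline=1.0, surround_baseline=1.0):
    # (excitatory drive + neural baseline) divided by
    # (suppressive drive + surround baseline); the resting ratio is
    # subtracted so that a blank stimulus yields a response of zero.
    numerator = amplitude * prf_drive + neural_baseline
    denominator = srf_amplitude * srf_drive + surround_baseline
    return numerator / denominator - neural_baseline / surround_baseline

print(dn_response_sketch(prf_drive=1.0, srf_drive=1.0))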

par_fitter_dn = ParameterFitter(model=model_dn, data=d, paradigm=paradigm)
# First, fit without optimizing the HRF parameters
pars_dn = par_fitter_dn.fit(init_pars=pars_dn_init, max_n_iterations=1000, fixed_pars=['hrf_delay', 'hrf_dispersion'])

# Then, refit with the HRF parameters free
pars_dn = par_fitter_dn.fit(init_pars=pars_dn, max_n_iterations=1000)

# %%
# Again, let's plot the R2 improvements
r2_dn = par_fitter_dn.get_rsq(pars_dn)
sns.relplot(x='r2_dog_hrf', y='r2_dn', data=pd.concat((r2_dog_hrf, r2_dn), axis=1,
                                                       keys=['r2_dog_hrf', 'r2_dn']).reset_index(), kind='scatter')

plt.plot([0, 1], [0, 1], 'k--')

# %%
improvement = r2_dn - r2_dog_hrf
largest_improvements = improvement.sort_values(ascending=False).index[:9]

pred_dn = model_dn.predict(parameters=pars_dn)
pred = pd.concat((d.loc[:, largest_improvements], pred_dog_hrf.loc[:, largest_improvements], pred_dn.loc[:, largest_improvements]), axis=1, keys=['data', 'dog+hrf', 'dn+hrf'], names=['model'])

tmp = pred.stack(['model', 'source']).to_frame('value')
sns.relplot(x='frame', y='value', hue='model', col='source', data=tmp.reset_index(), kind='line', col_wrap=3,
            palette=['k'] + sns.color_palette(),
            hue_order=['data', 'dog+hrf', 'dn+hrf'])
# %%
# Decoding
# --------
# We can also use the fitted models to decode the stimulus from the
# fMRI response. Let's compare our simplest model versus our most
# complex model.

# First we fit the noise models
from braincoder.optimize import ResidualFitter, StimulusFitter

# Let's first get the grid coordinates and paradigm at a slightly lower
# resolution, which keeps the decoding computationally tractable.
data = load_szinte2024(resize_factor=2.5)
grid_coordinates = data['grid_coordinates']
paradigm = data['stimulus']

# Select the 200 voxels with the highest R2, restricted to PRFs with sd > 0.5
best_voxels_gauss = r2_gauss_gd[pars_gauss_gd['sd'] > 0.5].sort_values(ascending=False).index[:200]

model_gauss = GaussianPRF2DWithHRF(data=d[best_voxels_gauss],
                                   hrf_model=hrf_model,
                                   grid_coordinates=grid_coordinates.astype(np.float32),
                                   parameters=pars_gauss_gd.loc[best_voxels_gauss].astype(np.float32))

resid_fitter_gauss = ResidualFitter(model=model_gauss, data=d.loc[:, best_voxels_gauss],
                                    paradigm=paradigm.astype(np.float32), parameters=pars_gauss_gd.loc[best_voxels_gauss].astype(np.float32))
omega_gauss, _ = resid_fitter_gauss.fit()
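# %%
# `omega_gauss` is the estimated noise covariance matrix across the selected
# voxels, so (assuming the fit succeeded) it should be 200 x 200 here.
print(omega_gauss.shape)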




# %%
# Same selection, now based on the DN fits
best_voxels_dn = r2_dn[pars_dn['sd'] > 0.5].sort_values(ascending=False).index[:200]

model_dn = DivisiveNormalizationGaussianPRF2DWithHRF(data=d[best_voxels_dn], 
                                              hrf_model=hrf_model,
                                              grid_coordinates=grid_coordinates.astype(np.float32),
                                              parameters=pars_dn.loc[best_voxels_dn].astype(np.float32))

resid_fitter_dn = ResidualFitter(model=model_dn, data=d.loc[:, best_voxels_dn],
                                    paradigm=paradigm, parameters=pars_dn.loc[best_voxels_dn])

omega_dn, _ = resid_fitter_dn.fit()

# %%
# Decoded stimulus: Gaussian model
# ================================
# Now we can decode the stimulus from the fMRI responses.
stim_fitter_gauss = StimulusFitter(model=model_gauss, data=d.loc[:, best_voxels_gauss], omega=omega_gauss)
stim_gauss = stim_fitter_gauss.fit(l2_norm=0.01, learning_rate=0.01, max_n_iterations=1000)

# %%
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

def play_reconstruction(reconstructed_stimulus):
    # Make a movie of the decoded stimulus
    vmin, vmax = 0.0, np.quantile(reconstructed_stimulus.values.ravel(), 0.99)

    # Function that draws a single frame
    def update(frame):
        plt.clf()  # Clear the current figure
        plt.imshow(reconstructed_stimulus.stack('y').loc[frame].iloc[::-1, :],
                   cmap='viridis', vmin=vmin, vmax=vmax)
        plt.axis('off')
        plt.title(f"Frame {frame}")

    # Create the animation
    fig = plt.figure()
    ani = FuncAnimation(fig, update, frames=range(paradigm.shape[0]), interval=100)

    return HTML(ani.to_html5_video())

play_reconstruction(stim_gauss)


# %%
# Decoded stimulus: DN model
# ==========================
stim_fitter_dn = StimulusFitter(model=model_dn, data=d.loc[:, best_voxels_dn], omega=omega_dn)
stim_dn = stim_fitter_dn.fit(l2_norm=0.01, learning_rate=0.01, max_n_iterations=1000)

# %%
play_reconstruction(stim_dn)
# As you can see, the DN model works a lot better than the Gaussian model. ;)
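# %%
# A rough quantitative check of that impression (a sketch; it assumes the
# reconstructions and the downsampled paradigm share the same frame-by-pixel
# layout): correlate each reconstruction with the actual stimulus.
for label, stim in [('Gaussian', stim_gauss), ('DN', stim_dn)]:
    r = np.corrcoef(np.asarray(stim, dtype=float).ravel(),
                    np.asarray(paradigm, dtype=float).ravel())[0, 1]
    print(f'{label} model reconstruction vs. true stimulus: r = {r:.2f}')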

# %%