"""2D neural field with empirical lateral connection parameters."""
from typing import Optional, Tuple
import numpy as np
from scipy.fft import fft2, ifft2
from mercurial.core.wilson_cowan import WilsonCowanPopulation
class MexicanHatKernel:
    """Mexican-hat kernel with empirically derived amplitudes and widths.

    The kernel is a difference of two isotropic Gaussians::

        K(r) = A_e * exp(-r**2 / (2 * sigma_e**2)) - A_i * exp(-r**2 / (2 * sigma_i**2))

    sampled on a square grid of offsets ``-(size // 2) .. size // 2`` (in units
    of ``dx``) centred on the origin.
    """

    def __init__(
        self,
        size: int,
        dx: float,
        A_e: float = 1.0,
        sigma_e: float = 0.5,
        A_i: float = 0.8,
        sigma_i: float = 1.5,
    ):
        """
        Parameters
        ----------
        size : int
            Kernel size (odd, typically 2*radius+1). Offsets run symmetrically
            from ``-(size // 2)`` to ``size // 2``, so an even value yields a
            ``(size + 1, size + 1)`` kernel.
        dx : float
            Grid spacing (mm).
        A_e, sigma_e : float
            Excitatory amplitude and width (mm).
        A_i, sigma_i : float
            Inhibitory amplitude and width (mm).
        """
        self.size = size
        self.dx = dx
        self.A_e = A_e
        self.sigma_e = sigma_e
        self.A_i = A_i
        self.sigma_i = sigma_i
        self._compute_kernel()

    def _compute_kernel(self) -> None:
        """Sample the difference of Gaussians on the offset grid into ``self.kernel``."""
        half = self.size // 2
        offsets = np.arange(-half, half + 1) * self.dx
        X, Y = np.meshgrid(offsets, offsets)
        r2 = X**2 + Y**2
        excitatory = self.A_e * np.exp(-r2 / (2 * self.sigma_e**2))
        inhibitory = self.A_i * np.exp(-r2 / (2 * self.sigma_i**2))
        self.kernel = excitatory - inhibitory

    def get_kernel(self) -> np.ndarray:
        """Return the sampled kernel (centre sample at index ``shape // 2``)."""
        return self.kernel

    def get_ft(self, shape: Tuple[int, int]) -> np.ndarray:
        """Return the 2D FFT of the kernel embedded in a periodic grid of ``shape``.

        The kernel's centre sample (zero offset) is placed at index ``(0, 0)``
        and negative offsets wrap to the end of each axis. Multiplying the
        result by ``fft2(field)`` and inverting therefore gives a circular
        convolution that is centred on each source point rather than shifted.
        Taps that extend beyond the grid wrap periodically and are summed into
        the aliased cell, so grids smaller than the kernel are supported.

        Parameters
        ----------
        shape : (int, int)
            Shape of the field the kernel will be convolved with.

        Returns
        -------
        np.ndarray
            Complex array of ``shape`` (output of ``fft2``).
        """
        n_rows, n_cols = shape
        k_rows, k_cols = self.kernel.shape
        # Map kernel sample j (offset j - k//2) onto its periodic grid index.
        row_idx = (np.arange(k_rows) - k_rows // 2) % n_rows
        col_idx = (np.arange(k_cols) - k_cols // 2) % n_cols
        wrapped = np.zeros((n_rows, n_cols))
        # add.at accumulates when several taps alias onto one cell (small grids).
        np.add.at(wrapped, np.ix_(row_idx, col_idx), self.kernel)
        return fft2(wrapped)
class NeuralField2D:
    """
    2D neural field with empirical lateral connectivity parameters.

    Every grid point owns a ``WilsonCowanPopulation`` that supplies its local
    weights, sigmoid gains/thresholds, refractory factors and time constants.
    Points interact through lateral inputs obtained by circular (FFT-based,
    periodic-boundary) convolution of the excitatory field ``E``.
    """

    def __init__(
        self,
        nx: int,
        ny: int,
        dx: float = 0.5,
        dy: float = 0.5,
        wc_params: Optional[dict] = None,
        kernel_ee: Optional["MexicanHatKernel"] = None,
        kernel_ie: Optional["MexicanHatKernel"] = None,
        *,
        seed: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        nx, ny : int
            Grid dimensions.
        dx, dy : float
            Spatial step (mm). Default 0.5 mm approximates cortical column spacing.
            The default kernels are sampled with ``dx`` only; ``dy`` is stored
            but not used to build them.
        wc_params : dict, optional
            Keyword arguments for WilsonCowanPopulation (one instance per point).
        kernel_ee : MexicanHatKernel, optional
            Lateral kernel applied to ``E`` and fed into the E input. Default:
            ``MexicanHatKernel(size=15, dx=dx)`` (A_e=1.0, σ_e=0.5 mm, A_i=0.8,
            σ_i=1.5 mm).
        kernel_ie : MexicanHatKernel, optional
            Lateral kernel applied to ``E`` and fed into the I input. Default:
            ``MexicanHatKernel(size=15, dx=dx, A_e=0.0, A_i=0.5, sigma_i=1.5)``,
            i.e. a purely negative Gaussian.
        seed : int, optional
            Seed for the random resting-state initialisation. ``None`` (the
            default) leaves the generator unseeded, so each run differs.
        """
        self.nx = nx
        self.ny = ny
        self.dx = dx
        self.dy = dy
        self.wc_params = wc_params or {}
        if kernel_ee is None:
            kernel_ee = MexicanHatKernel(size=15, dx=dx)
        if kernel_ie is None:
            # NOTE(review): A_e=0 makes the E->I lateral input purely negative;
            # confirm this sign convention is intended.
            kernel_ie = MexicanHatKernel(size=15, dx=dx, A_e=0.0, A_i=0.5, sigma_i=1.5)
        self.kernel_ee = kernel_ee
        self.kernel_ie = kernel_ie
        self.ft_ee = self.kernel_ee.get_ft((nx, ny))
        self.ft_ie = self.kernel_ie.get_ft((nx, ny))
        self.wc_populations = [WilsonCowanPopulation(**self.wc_params) for _ in range(nx * ny)]
        self._init_resting_state(seed)
        # NOTE: one (nx, ny) snapshot of each field is appended per step();
        # the histories grow without bound.
        self.E_history = []
        self.I_history = []

    def _init_resting_state(self, seed: Optional[int] = None) -> None:
        """Set E and I to 0.05 plus Gaussian noise (std 0.02), clipped to [0, 1]."""
        rng = np.random.default_rng(seed)
        grid = (self.nx, self.ny)
        # E is drawn before I so a given seed always yields the same pair.
        self.E = np.clip(0.05 + 0.02 * rng.normal(size=grid), 0.0, 1.0)
        self.I = np.clip(0.05 + 0.02 * rng.normal(size=grid), 0.0, 1.0)

    def _spatial_convolution(self, field: np.ndarray, ft_kernel: np.ndarray) -> np.ndarray:
        """Circularly convolve ``field`` with the kernel whose FFT is ``ft_kernel``."""
        return np.real(ifft2(fft2(field) * ft_kernel))

    def _as_grid(self, value: Optional[np.ndarray]) -> np.ndarray:
        """Coerce an external input to an ``(nx, ny)`` float grid.

        ``None`` gives zeros. An input with exactly ``nx * ny`` elements is
        reshaped in row-major order, matching the previous ``flatten`` behaviour.
        Anything else, such as a scalar or a row/column, is broadcast.

        Raises
        ------
        ValueError
            If ``value`` cannot be broadcast to ``(nx, ny)``.
        """
        grid = (self.nx, self.ny)
        if value is None:
            return np.zeros(grid)
        arr = np.asarray(value, dtype=float)
        if arr.size == self.nx * self.ny:
            return arr.reshape(grid)
        return np.broadcast_to(arr, grid)

    def step(
        self, dt: float, P_ext: Optional[np.ndarray] = None, Q_ext: Optional[np.ndarray] = None
    ) -> None:
        """Advance E and I by one forward-Euler step of size ``dt``.

        For every grid point, the drive to the excitatory population is its
        local input plus the ``E * kernel_ee`` lateral term. The drive to the
        inhibitory population is its local input plus the ``E * kernel_ie``
        lateral term. Each drive passes through the point's
        ``WilsonCowanPopulation.sigmoid``. The updated fields are clipped to
        [0, 1] and appended to ``E_history`` / ``I_history``.

        Parameters
        ----------
        dt : float
            Time step.
        P_ext, Q_ext : array-like or scalar, optional
            External drive to E and I respectively. ``None`` means zero drive;
            otherwise the value must have ``nx * ny`` elements or broadcast to
            ``(nx, ny)``.

        Raises
        ------
        ValueError
            If ``P_ext`` or ``Q_ext`` cannot be broadcast to the grid.
        """
        n_points = self.nx * self.ny
        P_flat = self._as_grid(P_ext).ravel()
        Q_flat = self._as_grid(Q_ext).ravel()
        # Both lateral terms are driven by the excitatory field.
        conv_ee_flat = self._spatial_convolution(self.E, self.ft_ee).ravel()
        conv_ie_flat = self._spatial_convolution(self.E, self.ft_ie).ravel()
        E_flat = self.E.ravel()
        I_flat = self.I.ravel()
        new_E = np.zeros_like(E_flat)
        new_I = np.zeros_like(I_flat)
        for i in range(n_points):
            wc = self.wc_populations[i]
            E_i = E_flat[i]
            I_i = I_flat[i]
            input_e = wc.w_ee * E_i - wc.w_ei * I_i + P_flat[i] + conv_ee_flat[i]
            S_e = wc.sigmoid(input_e, wc.a_e, wc.theta_e)
            dE = (-E_i + (1 - wc.r_e * E_i) * S_e) / wc.tau_e
            input_i = wc.w_ie * E_i - wc.w_ii * I_i + Q_flat[i] + conv_ie_flat[i]
            S_i = wc.sigmoid(input_i, wc.a_i, wc.theta_i)
            dI = (-I_i + (1 - wc.r_i * I_i) * S_i) / wc.tau_i
            new_E[i] = E_i + dE * dt
            new_I[i] = I_i + dI * dt
        self.E = np.clip(new_E.reshape(self.nx, self.ny), 0.0, 1.0)
        self.I = np.clip(new_I.reshape(self.nx, self.ny), 0.0, 1.0)
        self.E_history.append(self.E.copy())
        self.I_history.append(self.I.copy())

    def get_phase_field(self) -> np.ndarray:
        """Return the phase angle of each 2D Fourier coefficient of ``E``.

        NOTE(review): these are phases of spatial Fourier modes, not per-site
        oscillator phases; confirm that this is the intended "phase field".
        """
        return np.angle(fft2(self.E))

    def order_parameter(self) -> float:
        """Return ``|mean(exp(1j * phase))|`` over the phases from ``get_phase_field``."""
        phases = self.get_phase_field()
        z = np.mean(np.exp(1j * phases))
        return float(np.abs(z))