"""Jansen‑Rit neural mass model with empirical parameters and network support."""
from typing import Optional
import numpy as np
from mercurial.params.empirical import JansenRitParams
class JansenRitColumn:
    """
    Single Jansen‑Rit cortical column.

    Implements the Jansen & Rit (1995) neural mass model as six first-order
    ODEs on the state ``[y0, y0dot, y1, y1dot, y2, y2dot]``::

        y0'' = A a S(y1 - y2)            - 2 a y0' - a**2 y0
        y1'' = A a (C2 S(C1 y0) + p_ext) - 2 a y1' - a**2 y1
        y2'' = B b C4 S(C3 y0)           - 2 b y2' - b**2 y2

    where ``S`` is :meth:`sigmoid`. The EEG-like output returned by
    :meth:`get_EEG` is ``y1 - y2``.

    Parameters
    ----------
    params : JansenRitParams, optional
        Source of default parameter values; ``JansenRitParams()`` is used
        when omitted.
    **kwargs
        Per-instance overrides for any of ``a, b, A, B, C1, C2, C3, C4, e0,
        v0, r``, plus ``noise_amp`` (default 0.5), the amplitude of the
        additive noise applied in :meth:`step`. Other keys are ignored.
    """

    def __init__(self, params: Optional["JansenRitParams"] = None, **kwargs):
        if params is None:
            params = JansenRitParams()
        self.a = kwargs.get("a", params.a)
        self.b = kwargs.get("b", params.b)
        self.A = kwargs.get("A", params.A)
        self.B = kwargs.get("B", params.B)
        self.C1 = kwargs.get("C1", params.C1)
        self.C2 = kwargs.get("C2", params.C2)
        self.C3 = kwargs.get("C3", params.C3)
        self.C4 = kwargs.get("C4", params.C4)
        self.e0 = kwargs.get("e0", params.e0)
        self.v0 = kwargs.get("v0", params.v0)
        self.r = kwargs.get("r", params.r)
        self.noise_amp = kwargs.get("noise_amp", 0.5)
        self.rng = np.random.default_rng()
        # State: [y0, y0dot, y1, y1dot, y2, y2dot], started at the origin plus
        # a small uniform random offset in [0, 0.01).
        self.state = np.zeros(6) + 0.01 * self.rng.random(6)

    def sigmoid(self, v: float) -> float:
        """Map a potential ``v`` to a rate via ``2 e0 / (1 + exp(r (v0 - v)))``.

        The result lies in ``[0, 2 e0]`` and equals ``e0`` at ``v == v0``.
        When the exponent exceeds +/-500 the exact limit (0 or ``2 e0``) is
        returned instead of evaluating ``np.exp`` on a huge argument.
        """
        arg = self.r * (self.v0 - v)
        if arg > 500:
            return 0.0
        if arg < -500:
            return 2.0 * self.e0
        return 2.0 * self.e0 / (1.0 + np.exp(arg))

    def derivative(self, state: np.ndarray, p_ext: float = 0.0) -> np.ndarray:
        """Return the time derivative of ``state`` under external drive ``p_ext``.

        Parameters
        ----------
        state : np.ndarray
            Six values ``[y0, y0dot, y1, y1dot, y2, y2dot]``.
        p_ext : float
            External input; it enters only the ``y1`` equation.

        Returns
        -------
        np.ndarray
            ``[y0dot, y0'', y1dot, y1'', y2dot, y2'']``.
        """
        y0, y0dot, y1, y1dot, y2, y2dot = state
        a, b = self.a, self.b
        s_y1_minus_y2 = self.sigmoid(y1 - y2)
        # y0 is scaled by the connectivity constants C1 / C3 before the
        # sigmoid, as in the Jansen & Rit equations.
        s_c1_y0 = self.sigmoid(self.C1 * y0)
        s_c3_y0 = self.sigmoid(self.C3 * y0)
        dy0dot = self.A * a * s_y1_minus_y2 - 2.0 * a * y0dot - a**2 * y0
        dy1dot = self.A * a * (self.C2 * s_c1_y0 + p_ext) - 2.0 * a * y1dot - a**2 * y1
        dy2dot = self.B * b * self.C4 * s_c3_y0 - 2.0 * b * y2dot - b**2 * y2
        return np.array([y0dot, dy0dot, y1dot, dy1dot, y2dot, dy2dot])

    def step(self, dt: float, p_ext: float = 0.0) -> None:
        """Advance ``self.state`` in place by ``dt``.

        The deterministic part uses Heun's method (explicit trapezoidal
        predictor-corrector). Afterwards, Gaussian noise
        ``noise_amp * sqrt(dt) * N(0, 1)`` is added independently to all six
        state components.
        """
        k1 = self.derivative(self.state, p_ext)
        state_pred = self.state + k1 * dt
        k2 = self.derivative(state_pred, p_ext)
        self.state += (k1 + k2) * dt / 2.0
        self.state += self.noise_amp * np.sqrt(dt) * self.rng.normal(size=6)

    def get_EEG(self) -> float:
        """Return the EEG-like output ``y1 - y2`` of the current state."""
        return float(self.state[2] - self.state[4])
class JansenRitNetwork:
    """
    Network of coupled Jansen‑Rit columns.

    Every column is built from the same ``params`` and ``column_kwargs``. On
    each :meth:`step`, column ``i`` receives its external input plus the
    all-to-all diffusive coupling term
    ``coupling_strength * mean_j(x_j - x_i)``, where ``x`` holds every
    column's ``get_EEG()`` value sampled before any column is advanced.

    Parameters
    ----------
    n_columns : int
        Number of columns; must be at least 1.
    coupling_strength : float, optional
        Gain of the diffusive coupling term (default 0.0, i.e. uncoupled).
    params : JansenRitParams, optional
        Forwarded to every ``JansenRitColumn``.
    **column_kwargs
        Forwarded to every ``JansenRitColumn``.

    Raises
    ------
    ValueError
        If ``n_columns`` is less than 1.
    """

    def __init__(
        self,
        n_columns: int,
        coupling_strength: float = 0.0,
        params: Optional[JansenRitParams] = None,
        **column_kwargs,
    ):
        if n_columns < 1:
            # An empty network has no meaningful mean output.
            raise ValueError(f"n_columns must be at least 1, got {n_columns}")
        self.n = n_columns
        self.coupling = coupling_strength
        self.columns = [JansenRitColumn(params, **column_kwargs) for _ in range(n_columns)]

    def step(self, dt: float, external_inputs: Optional[np.ndarray] = None) -> None:
        """Advance every column by ``dt``.

        Parameters
        ----------
        dt : float
            Integration step forwarded to each column's ``step``.
        external_inputs : array-like, optional
            Per-column external drive of shape ``(n_columns,)``; a scalar or
            length-1 value is applied to all columns. ``None`` means zero.

        Raises
        ------
        ValueError
            If ``external_inputs`` cannot be broadcast to ``(n_columns,)``.
        """
        if external_inputs is None:
            drive = np.zeros(self.n)
        else:
            drive = np.broadcast_to(np.asarray(external_inputs, dtype=float), (self.n,))
        # Snapshot all outputs first so every column sees the same pre-step values.
        outputs = self.get_EEG_array()
        # mean_j(x_j - x_i) == mean(x) - x_i: one vectorised O(n) pass instead of
        # recomputing the mean for each column. Exactly zero when n == 1.
        total = drive + self.coupling * (outputs.mean() - outputs)
        for column, p_ext in zip(self.columns, total):
            column.step(dt, float(p_ext))

    def get_EEG_array(self) -> np.ndarray:
        """Return each column's ``get_EEG()`` value as a 1-D array of length ``n``."""
        return np.array([column.get_EEG() for column in self.columns])

    def get_mean_EEG(self) -> float:
        """Return the mean of :meth:`get_EEG_array` across all columns."""
        return float(np.mean(self.get_EEG_array()))