"""Hebbian learning and plasticity for neural models (Section 25)."""
from typing import Dict, Optional
import numpy as np
from mercurial.params.empirical import HebbianParams, WilsonCowanParams
class PlasticWilsonCowan:
    """
    Wilson‑Cowan population with plastic excitatory and inhibitory weights.

    Uses empirical parameters for Wilson‑Cowan and Hebbian learning.

    Each coupling weight ``w_xy`` follows
    ``dw_xy/dt = eta_xy * pre * post - gamma_xy * w_xy`` and is clamped to
    ``[0, w_max]`` after every step. The activities ``E`` and ``I`` are
    clamped to ``[0, 1]``.
    """

    def __init__(
        self,
        wc_params: Optional["WilsonCowanParams"] = None,
        hebb_params: Optional["HebbianParams"] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        wc_params : WilsonCowanParams, optional
            Empirical Wilson‑Cowan parameters. A default instance is created
            when omitted.
        hebb_params : HebbianParams, optional
            Empirical Hebbian learning parameters. A default instance is
            created when omitted.
        **kwargs
            Overrides for individual parameters (``tau_e``, ``w_ee``,
            ``eta_ee``, ``gamma_ii``, ``sigma``, ...). Two extra keys are
            recognised:

            ``seed``
                Seed (int, ``SeedSequence`` or ``Generator``) for the noise
                generator. ``None`` (the default) keeps the previous,
                non-reproducible behaviour.
            ``w_max``
                Upper clamp applied to all four weights (default 30.0).

        Raises
        ------
        ValueError
            If ``w_max`` is negative.
        """
        if wc_params is None:
            wc_params = WilsonCowanParams()
        if hebb_params is None:
            hebb_params = HebbianParams()

        # Wilson‑Cowan parameters
        self.tau_e = kwargs.get("tau_e", wc_params.tau_e)
        self.tau_i = kwargs.get("tau_i", wc_params.tau_i)
        # Coefficients of the (1 - r * activity) gating term in step(). They
        # are not read from wc_params; 0.8 is this module's own default.
        self.r_e = kwargs.get("r_e", 0.8)
        self.r_i = kwargs.get("r_i", 0.8)
        self.w_ee = kwargs.get("w_ee", wc_params.w_ee)
        self.w_ei = kwargs.get("w_ei", wc_params.w_ei)
        self.w_ie = kwargs.get("w_ie", wc_params.w_ie)
        self.w_ii = kwargs.get("w_ii", wc_params.w_ii)
        self.a_e = kwargs.get("a_e", wc_params.a_e)
        self.theta_e = kwargs.get("theta_e", wc_params.mu_e)
        self.a_i = kwargs.get("a_i", wc_params.a_i)
        self.theta_i = kwargs.get("theta_i", wc_params.mu_i)
        self.sigma = kwargs.get("sigma", 0.01)

        # Hebbian learning rates and decay (one shared default for all four)
        self.eta_ee = kwargs.get("eta_ee", hebb_params.learning_rate)
        self.eta_ei = kwargs.get("eta_ei", hebb_params.learning_rate)
        self.eta_ie = kwargs.get("eta_ie", hebb_params.learning_rate)
        self.eta_ii = kwargs.get("eta_ii", hebb_params.learning_rate)
        self.gamma_ee = kwargs.get("gamma_ee", hebb_params.decay)
        self.gamma_ei = kwargs.get("gamma_ei", hebb_params.decay)
        self.gamma_ie = kwargs.get("gamma_ie", hebb_params.decay)
        self.gamma_ii = kwargs.get("gamma_ii", hebb_params.decay)

        # Upper bound for the plastic weights (previously hard-coded to 30.0).
        self.w_max = kwargs.get("w_max", 30.0)
        if self.w_max < 0:
            raise ValueError("w_max must be non-negative")

        # default_rng(None) draws fresh OS entropy, matching the old behaviour.
        self.rng = np.random.default_rng(kwargs.get("seed"))
        self.E = 0.1
        self.I = 0.1

    def sigmoid(self, x: float, a: float, theta: float) -> float:
        """
        Return the logistic response ``1 / (1 + exp(-a * (x - theta)))``.

        Arguments beyond ±500 are saturated to 1.0 / 0.0 so ``np.exp`` never
        overflows.
        """
        arg = a * (x - theta)
        if arg > 500:
            return 1.0
        if arg < -500:
            return 0.0
        return 1.0 / (1.0 + np.exp(-arg))

    def step(self, dt: float, P_ext: float = 0.0, Q_ext: float = 0.0) -> None:
        """
        Advance activities and weights by one Euler–Maruyama step.

        Parameters
        ----------
        dt : float
            Time step; must be non-negative.
        P_ext : float
            External input to the excitatory population.
        Q_ext : float
            External input to the inhibitory population.

        Raises
        ------
        ValueError
            If ``dt`` is negative. Otherwise ``sqrt(dt)`` would be NaN and
            silently poison the state.
        """
        if dt < 0:
            raise ValueError("dt must be non-negative")

        # Population derivatives, evaluated at the pre-step state and weights.
        input_e = self.w_ee * self.E - self.w_ei * self.I + P_ext
        S_e = self.sigmoid(input_e, self.a_e, self.theta_e)
        dE = (-self.E + (1 - self.r_e * self.E) * S_e) / self.tau_e
        input_i = self.w_ie * self.E - self.w_ii * self.I + Q_ext
        S_i = self.sigmoid(input_i, self.a_i, self.theta_i)
        dI = (-self.I + (1 - self.r_i * self.I) * S_i) / self.tau_i

        # Hebbian weight updates: correlation term minus linear decay.
        dw_ee = self.eta_ee * self.E * self.E - self.gamma_ee * self.w_ee
        dw_ei = self.eta_ei * self.E * self.I - self.gamma_ei * self.w_ei
        dw_ie = self.eta_ie * self.I * self.E - self.gamma_ie * self.w_ie
        dw_ii = self.eta_ii * self.I * self.I - self.gamma_ii * self.w_ii

        # Euler updates; additive Gaussian noise on the activities only.
        # The E draw happens before the I draw.
        noise_scale = self.sigma * np.sqrt(dt)
        self.E += dE * dt + noise_scale * self.rng.normal()
        self.I += dI * dt + noise_scale * self.rng.normal()
        self.w_ee += dw_ee * dt
        self.w_ei += dw_ei * dt
        self.w_ie += dw_ie * dt
        self.w_ii += dw_ii * dt

        # Clamp activities to [0, 1] and weights to [0, w_max].
        self.E = np.clip(self.E, 0.0, 1.0)
        self.I = np.clip(self.I, 0.0, 1.0)
        self.w_ee = np.clip(self.w_ee, 0.0, self.w_max)
        self.w_ei = np.clip(self.w_ei, 0.0, self.w_max)
        self.w_ie = np.clip(self.w_ie, 0.0, self.w_max)
        self.w_ii = np.clip(self.w_ii, 0.0, self.w_max)

    def get_weights(self) -> Dict[str, float]:
        """Return the current coupling weights keyed by name."""
        return {
            "w_ee": self.w_ee,
            "w_ei": self.w_ei,
            "w_ie": self.w_ie,
            "w_ii": self.w_ii,
        }
class PlasticJansenRit:
    """
    Jansen‑Rit column with plastic connectivity constants.

    Uses empirical parameters. The state vector is
    ``[y0, y0', y1, y1', y2, y2']``:

    - ``y0`` is driven by the pyramidal firing rate ``S(y1 - y2)``.
    - ``y1`` is the excitatory PSP fed back onto the pyramidal population,
      driven by ``C2 * S(C1 * y0)`` plus the external input.
    - ``y2`` is the inhibitory PSP, driven by ``C4 * S(C3 * y0)``.

    The connectivity constants ``C1``–``C4`` follow a Hebbian rule with
    linear decay and are floored at 0.1 after every step.
    """

    def __init__(
        self,
        jr_params=None,
        hebb_params: Optional["HebbianParams"] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        jr_params : JansenRitParams, optional
            Empirical Jansen‑Rit parameters. A default instance is created
            when omitted.
        hebb_params : HebbianParams, optional
            Empirical Hebbian learning parameters. A default instance is
            created when omitted.
        **kwargs
            Overrides for individual parameters (``a``, ``A``, ``C1``,
            ``eta_C1``, ``gamma_C``, ``noise_amp``, ...). ``seed`` (int,
            ``SeedSequence`` or ``Generator``) seeds the noise generator.
            ``None`` (the default) keeps the previous non-reproducible
            behaviour.
        """
        if jr_params is None:
            # Local import kept from the original layout; it is only needed
            # when the caller does not supply parameters.
            from mercurial.params.empirical import JansenRitParams

            jr_params = JansenRitParams()
        if hebb_params is None:
            hebb_params = HebbianParams()

        self.a = kwargs.get("a", jr_params.a)
        self.b = kwargs.get("b", jr_params.b)
        self.A = kwargs.get("A", jr_params.A)
        self.B = kwargs.get("B", jr_params.B)
        self.C1 = kwargs.get("C1", jr_params.C1)
        self.C2 = kwargs.get("C2", jr_params.C2)
        self.C3 = kwargs.get("C3", jr_params.C3)
        self.C4 = kwargs.get("C4", jr_params.C4)
        self.e0 = kwargs.get("e0", jr_params.e0)
        self.v0 = kwargs.get("v0", jr_params.v0)
        self.r = kwargs.get("r", jr_params.r)
        self.noise_amp = kwargs.get("noise_amp", 0.5)

        # Hebbian learning rates (per constant) and a shared decay rate
        self.eta_C1 = kwargs.get("eta_C1", hebb_params.learning_rate)
        self.eta_C2 = kwargs.get("eta_C2", hebb_params.learning_rate)
        self.eta_C3 = kwargs.get("eta_C3", hebb_params.learning_rate)
        self.eta_C4 = kwargs.get("eta_C4", hebb_params.learning_rate)
        self.gamma_C = kwargs.get("gamma_C", hebb_params.decay)

        # default_rng(None) draws fresh OS entropy, matching the old behaviour.
        self.rng = np.random.default_rng(kwargs.get("seed"))
        # Start near the origin with a small random perturbation.
        self.state = np.zeros(6) + 0.01 * self.rng.random(6)

    def sigmoid(self, v: float) -> float:
        """
        Return the Jansen–Rit sigmoid ``2*e0 / (1 + exp(r * (v0 - v)))``.

        Arguments beyond ±500 are saturated so ``np.exp`` never overflows.
        """
        arg = self.r * (self.v0 - v)
        if arg > 500:
            return 0.0
        if arg < -500:
            return 2.0 * self.e0
        return 2.0 * self.e0 / (1.0 + np.exp(arg))

    def derivative(self, state: np.ndarray, p_ext: float = 0.0) -> np.ndarray:
        """
        Right-hand side of the Jansen–Rit ODE system.

        Parameters
        ----------
        state : np.ndarray
            ``[y0, y0', y1, y1', y2, y2']``.
        p_ext : float
            External input. It enters only the ``y1`` (excitatory
            interneuron) equation, as in the canonical model.

        Returns
        -------
        np.ndarray
            Time derivative of ``state``, in the same layout.
        """
        y0, y0dot, y1, y1dot, y2, y2dot = state
        # Pyramidal firing rate, from the net membrane potential y1 - y2.
        S_pyr = self.sigmoid(y1 - y2)
        # Interneuron firing rates. The pyramidal output is scaled by C1
        # (excitatory) and C3 (inhibitory) before the sigmoid, so these
        # constants actually shape the dynamics.
        S_exc = self.sigmoid(self.C1 * y0)
        S_inh = self.sigmoid(self.C3 * y0)

        dy0dot = self.A * self.a * S_pyr - 2 * self.a * y0dot - self.a**2 * y0
        dy1dot = (
            self.A * self.a * (p_ext + self.C2 * S_exc)
            - 2 * self.a * y1dot
            - self.a**2 * y1
        )
        dy2dot = (
            self.B * self.b * self.C4 * S_inh
            - 2 * self.b * y2dot
            - self.b**2 * y2
        )
        return np.array([y0dot, dy0dot, y1dot, dy1dot, y2dot, dy2dot])

    def step(self, dt: float, p_ext: float = 0.0) -> None:
        """
        Advance the column by one step of size ``dt``.

        The step proceeds in three stages:

        1. The deterministic part is integrated with Heun's method
           (explicit trapezoidal rule).
        2. Gaussian noise scaled by ``noise_amp * sqrt(dt)`` is added to
           every state component.
        3. C1–C4 receive an Euler-integrated Hebbian update and are then
           floored at 0.1.

        Raises
        ------
        ValueError
            If ``dt`` is negative. Otherwise ``sqrt(dt)`` would be NaN and
            silently poison the state.
        """
        if dt < 0:
            raise ValueError("dt must be non-negative")

        k1 = self.derivative(self.state, p_ext)
        state_pred = self.state + k1 * dt
        k2 = self.derivative(state_pred, p_ext)
        self.state += (k1 + k2) * dt / 2.0
        # NOTE(review): noise is injected into all six components (PSPs and
        # their derivatives), not only through the input p(t); kept as-is.
        self.state += self.noise_amp * np.sqrt(dt) * self.rng.normal(size=6)

        # Hebbian updates for connectivity constants (pre * post - decay),
        # computed from the post-step state.
        y0, y1, y2 = self.state[0], self.state[2], self.state[4]
        dC1 = self.eta_C1 * y0 * y1 - self.gamma_C * self.C1
        dC2 = self.eta_C2 * y1 * y0 - self.gamma_C * self.C2
        dC3 = self.eta_C3 * y0 * y2 - self.gamma_C * self.C3
        dC4 = self.eta_C4 * y2 * y0 - self.gamma_C * self.C4
        self.C1 += dC1 * dt
        self.C2 += dC2 * dt
        self.C3 += dC3 * dt
        self.C4 += dC4 * dt

        # Keep the constants strictly positive (floor of 0.1).
        self.C1 = max(0.1, self.C1)
        self.C2 = max(0.1, self.C2)
        self.C3 = max(0.1, self.C3)
        self.C4 = max(0.1, self.C4)

    def get_EEG(self) -> float:
        """
        Return the EEG proxy ``y1 - y2``.

        In the Jansen–Rit model the pyramidal membrane potential ``y1 - y2``
        is the model output. Earlier versions returned ``y0`` here instead.
        """
        return float(self.state[2] - self.state[4])