Core Ideas:
- State Space Models (SSM) in Neural Networks: SSMs are used to model the relationship between input, output, and the internal state of a system. They are particularly effective in handling time-series data and sequences.
- From Continuous to Discrete: SSMs are adapted from continuous-time models (used for analog signals) to discrete-time models (for digital signals) using methods like the bilinear transform, making them suitable for processing discrete input sequences.
- Recurrent Representation and Efficiency: In their basic form, SSMs have a recurrent nature, which isn’t efficient for modern parallel processing hardware. This is overcome by transforming them into a discrete convolution form.
- Convolutional Approach: By representing SSMs as discrete convolutions, they can be computed more efficiently on modern hardware like GPUs, enabling faster processing of large datasets.
- SSM Neural Network Implementation: These concepts are implemented in neural networks using specialized layers and transformations, with applications in processing sequences, such as in time-series analysis or natural language processing.
Context and Relevance
- Advances in Sequence Modeling: Traditional models like RNNs (Recurrent Neural Networks) face challenges in handling long sequences due to their sequential nature and limitations in memory. SSMs offer a more efficient and scalable alternative.
- Efficiency and Parallelism: The transformation of SSMs into a form suitable for parallel processing aligns well with the capabilities of modern GPUs, leading to significant improvements in computational efficiency.
- Application in Deep Learning: SSMs find applications in various domains within deep learning, such as language modeling, signal processing, and even in complex tasks like forecasting and anomaly detection in time-series data.
- Theoretical Innovation: The integration of concepts like HiPPO matrices into SSMs represents a blend of theoretical innovation with practical application, enabling models to effectively capture and remember long-range dependencies in data.
- Future of Machine Learning: The exploration and implementation of SSMs in neural network architectures contribute to the ongoing evolution of machine learning, particularly in making models more efficient, scalable, and capable of handling complex, sequential data.
A recent interesting paper shows how innovative this design really is.
Notes on the Annotated S4
The goal is the efficient modeling of long sequences. We are going to build a new neural network layer based on State Space Models (SSMs).
State Space Model (SSM) - maps a 1-D input signal $u(t)$ to an N-D latent state $x(t)$ before projecting to a 1-D output signal $y(t)$:
$$x'(t) = A x(t) + B u(t)$$
$$y(t) = C x(t) + D u(t)$$
- $u(t)$ - the 1-D input signal; writing it as a function of $t$ emphasizes that the input varies over time
- $y(t)$ - the 1-D output signal
- $x(t)$ - the N-D latent state, the multi-dimensional internal state of the State Space Model
- With D = 0, the main idea is that the input $u(t)$ is run through B, while the internal state $x(t)$ is run through A; together these produce the new state $x'(t)$. That new state is then run through C, which gives the output $y(t)$.
- A, B, C, D are parameters learned via gradient descent.
- D is set to 0 here because $Du(t)$ is just a skip connection; we skip over it for now and add it back later.
Defining A, B, C
# Imports used throughout these notes (following the Annotated S4 setup)
import jax
import jax.numpy as np
from jax.numpy.linalg import inv, matrix_power
from jax.scipy.signal import convolve

# Initialize parameters A, B, C randomly
def random_SSM(rng, N):
    a_r, b_r, c_r = jax.random.split(rng, 3)
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))
    return A, B, C
Discrete-Time SSM: The Recurrent Representation
Idea: instead of the input being a continuous function $u(t)$, the input is a discrete sequence $(u_0, u_1, \dots)$ sampled from $u(t)$ with a step size $\Delta$. We cannot take A, B, C and leave them the way they are; we use a bilinear method to convert the continuous-time matrices into discrete approximations $\bar{A}, \bar{B}, \bar{C}$:
$$\bar{A} = (I - \tfrac{\Delta}{2} A)^{-1}(I + \tfrac{\Delta}{2} A), \qquad \bar{B} = (I - \tfrac{\Delta}{2} A)^{-1} \Delta B, \qquad \bar{C} = C$$
# Take randomly generated parameters A, B, C and use a bilinear transformation
# so that the parameters work with a discrete time control model
def discretize(A, B, C, step):
I = np.eye(A.shape[0])
BL = inv(I - (step / 2.0) * A)
Ab = BL @ (I + (step / 2.0) * A)
Bb = (BL * step) @ B
return Ab, Bb, C
Once we discretize the model, it can be viewed and calculated like an RNN
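Concretely, the recurrence implemented by scan_SSM below is:
$$x_k = \bar{A} x_{k-1} + \bar{B} u_k, \qquad y_k = \bar{C} x_k$$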
def scan_SSM(Ab, Bb, Cb, u, x0):
def step(x_k_1, u_k):
x_k = Ab @ x_k_1 + Bb @ u_k
y_k = Cb @ x_k
return x_k, y_k
return jax.lax.scan(step, x0, u)
Putting everything together to run the SSM
def run_SSM(A, B, C, u):
L = u.shape[0]
N = A.shape[0]
Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L)
# Run recurrence
return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))[1]
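A minimal usage sketch (the key value and sizes here are my own assumptions, just to show the shapes involved):
# Drive a random 4-state SSM with a length-16 input signal
rng = jax.random.PRNGKey(0)
ssm = random_SSM(rng, N=4)
u = jax.random.uniform(rng, (16,))
y = run_SSM(*ssm, u)  # y.shape == (16, 1): one output per input step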
Core idea: Recurrent Neural Networks are slow because we cannot do the calculations in parallel; we have to do them sequentially, one after another. Because convolutions are heavily optimized on modern hardware, people turn the RNN into a CNN by “unrolling” it, i.e. converting the recurrence into a discrete convolution.
- A discrete convolution blends two signals or sequences together to create a new sequence: the kernel is shifted across the input and, at each position, the overlapping values are multiplied and summed, so each output still retains a weighted trace of the inputs that came before it.
- With this concept implemented, we can exploit the parallelism of convolution operations.
Here are the new equations:
$$\bar{K} = (\bar{C}\bar{B},\ \bar{C}\bar{A}\bar{B},\ \dots,\ \bar{C}\bar{A}^{L-1}\bar{B}), \qquad y = \bar{K} * u$$
where $\bar{K}$ is the SSM convolution kernel.
<Warning: this naive kernel computation is unstable and won’t work for more than very small lengths.> “Note that this is a giant filter. It is the size of the entire sequence!”
# Compute the SSM convolution kernel K shown above:
# an array of L scalars built from the discretized parameters Ab, Bb, Cb
def K_conv(Ab, Bb, Cb, L):
return np.array(
[(Cb @ matrix_power(Ab, l) @ Bb).reshape() for l in range(L)]
)
The math deriving these equations is in the reference equations of the Annotated S4.
Since the main equation is $y = \bar{K} * u$, we still need to compute the convolution itself. We can use the Convolution Theorem together with the Fast Fourier Transform (FFT) to speed this up: FFT both sequences, multiply them pointwise, and inverse-FFT the result. To apply the theorem to this (non-circular) convolution we need to pad the input sequences with zeros and then unpad the output sequence. As the length gets longer, the FFT method becomes much more efficient than direct convolution.
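In symbols, with both sequences zero-padded to length 2L (this is exactly what causal_convolution below computes, truncated back to the first L entries):
$$\bar{K} * u = \mathcal{F}^{-1}\big(\mathcal{F}(u_{\text{pad}}) \cdot \mathcal{F}(\bar{K}_{\text{pad}})\big)$$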
# Fast computation of K * u using the FFT via the convolution theorem
# (pass nofft=True to fall back to direct convolution)
def causal_convolution(u, K, nofft=False):
if nofft:
return convolve(u, K, mode="full")[: u.shape[0]]
else:
assert K.shape[0] == u.shape[0]
ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
out = ud * Kd
return np.fft.irfft(out)[: u.shape[0]]
# Test that the recurrent (RNN) and convolutional (CNN) forms give the same output
def test_cnn_is_rnn(N=4, L=16, step=1.0 / 16):
ssm = random_SSM(rng, N)
u = jax.random.uniform(rng, (L,))
jax.random.split(rng, 3)
# RNN
rec = run_SSM(*ssm, u)
# CNN
ssmb = discretize(*ssm, step=step)
conv = causal_convolution(u, K_conv(*ssmb, L))
# Check
assert np.allclose(rec.ravel(), conv.ravel())
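The test relies on a global rng key; a minimal way to run it (the key value follows the Annotated S4 setup, but any key works):
rng = jax.random.PRNGKey(1)  # global key assumed by the test above
test_cnn_is_rnn()  # should pass silently: RNN and CNN paths agree up to float error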
SSM Neural Network
A discrete SSM defines a map from $\mathbb{R}^L \to \mathbb{R}^L$, i.e. a 1-D sequence-to-sequence map. We assume that we are going to be learning the parameters B and C, as well as the step size $\Delta$ and a scalar D. For the parameter A we will be using a HiPPO matrix (see Part 1b). We learn the step size $\Delta$ in log space.
def log_step_initializer(dt_min=0.001, dt_max=0.1):
def init(key, shape):
return jax.random.uniform(key, shape) * (
np.log(dt_max) - np.log(dt_min)
) + np.log(dt_min)
return init
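A quick sketch (with an assumed key) of what this initializer produces: a log-step that, once exponentiated, lands between dt_min and dt_max.
key = jax.random.PRNGKey(0)  # assumed key for illustration
log_step = log_step_initializer()(key, (1,))
step = np.exp(log_step)  # falls in [dt_min, dt_max) = [0.001, 0.1)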
Most of the work in the SSM layer is building the kernel (filter).
The self.decode flag specifies whether the SSMLayer runs in CNN mode or RNN mode.
# Flax and initializer imports assumed for the layers below
from flax import linen as nn
from jax.nn.initializers import lecun_normal

class SSMLayer(nn.Module):
N: int
l_max: int
decode: bool = False
def setup(self):
# SSM parameters
self.A = self.param("A", lecun_normal(), (self.N, self.N))
self.B = self.param("B", lecun_normal(), (self.N, 1))
self.C = self.param("C", lecun_normal(), (1, self.N))
self.D = self.param("D", nn.initializers.ones, (1,))
# Step parameter
self.log_step = self.param("log_step", log_step_initializer(), (1,))
step = np.exp(self.log_step)
self.ssm = discretize(self.A, self.B, self.C, step=step)
self.K = K_conv(*self.ssm, self.l_max)
# RNN cache for long sequences
self.x_k_1 = self.variable("cache", "cache_x_k", np.zeros, (self.N,))
def __call__(self, u):
if not self.decode:
# CNN Mode
return causal_convolution(u, self.K) + self.D * u
else:
# RNN Mode
x_k, y_s = scan_SSM(*self.ssm, u[:, np.newaxis], self.x_k_1.value)
if self.is_mutable_collection("cache"):
self.x_k_1.value = x_k
return y_s.reshape(-1).real + self.D * u
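A minimal sketch of exercising this layer on its own (before the vmap cloning below), in CNN mode; the key and sizes are assumptions for illustration:
# One scalar SSM layer on a length-16 sequence
layer = SSMLayer(N=4, l_max=16)
u = np.ones((16,))
variables = layer.init(jax.random.PRNGKey(0), u)  # contains "params" and "cache"
y = layer.apply(variables, u)  # y.shape == (16,)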
SSMs operate on scalars, so we make different, stacked copies of the layer (one per feature dimension) using nn.vmap:
def cloneLayer(layer):
return nn.vmap(
layer,
in_axes=1,
out_axes=1,
variable_axes={"params": 1, "cache": 1, "prime": 1},
split_rngs={"params": True},
)
SSMLayer = cloneLayer(SSMLayer)
The SSM layer can then be put into a standard NN. We also add a block that pairs a call to the SSM with dropout and a linear projection:
class SequenceBlock(nn.Module):
layer_cls: nn.Module
layer: dict # Hyperparameters of inner layer
dropout: float
d_model: int
prenorm: bool = True
glu: bool = True
training: bool = True
decode: bool = False
def setup(self):
self.seq = self.layer_cls(**self.layer, decode=self.decode)
self.norm = nn.LayerNorm()
self.out = nn.Dense(self.d_model)
if self.glu:
self.out2 = nn.Dense(self.d_model)
self.drop = nn.Dropout(
self.dropout,
broadcast_dims=[0],
deterministic=not self.training,
)
def __call__(self, x):
skip = x
if self.prenorm:
x = self.norm(x)
x = self.seq(x)
x = self.drop(nn.gelu(x))
if self.glu:
x = self.out(x) * jax.nn.sigmoid(self.out2(x))
else:
x = self.out(x)
x = skip + self.drop(x)
if not self.prenorm:
x = self.norm(x)
return x
We then stack a bunch of the blocks on top of each other to produce a stack of SSM layers.
class Embedding(nn.Embed):
num_embeddings: int
features: int
@nn.compact
def __call__(self, x):
y = nn.Embed(self.num_embeddings, self.features)(x[..., 0])
return np.where(x > 0, y, 0.0)
class StackedModel(nn.Module):
layer_cls: nn.Module
layer: dict # Extra arguments to pass into layer constructor
d_output: int
d_model: int
n_layers: int
prenorm: bool = True
dropout: float = 0.0
embedding: bool = False # Use nn.Embed instead of nn.Dense encoder
classification: bool = False
training: bool = True
decode: bool = False # Probably should be moved into layer_args
def setup(self):
if self.embedding:
self.encoder = Embedding(self.d_output, self.d_model)
else:
self.encoder = nn.Dense(self.d_model)
self.decoder = nn.Dense(self.d_output)
self.layers = [
SequenceBlock(
layer_cls=self.layer_cls,
layer=self.layer,
prenorm=self.prenorm,
d_model=self.d_model,
dropout=self.dropout,
training=self.training,
decode=self.decode,
)
for _ in range(self.n_layers)
]
def __call__(self, x):
if not self.classification:
if not self.embedding:
x = x / 255.0 # Normalize
if not self.decode:
x = np.pad(x[:-1], [(1, 0), (0, 0)])
x = self.encoder(x)
for layer in self.layers:
x = layer(x)
if self.classification:
x = np.mean(x, axis=0)
x = self.decoder(x)
return nn.log_softmax(x, axis=-1)
BatchStackedModel = nn.vmap(
StackedModel,
in_axes=0,
out_axes=0,
variable_axes={"params": None, "dropout": None, "cache": 0, "prime": None},
split_rngs={"params": False, "dropout": True},
)
The full code is listed in the Annotated S4.
Problems with SSMs
- Randomly initialized SSM does not perform well
- Computing it naively like we’ve done so far is really slow and memory inefficient
Part 1b: Addressing Long-Range Dependencies with HiPPO
Prior work found that basic SSMs don’t work well in practice because the gradients scale exponentially in the sequence length (the vanishing/exploding gradients problem). HiPPO theory comes in to help: the idea is to replace the randomly initialized A with a HiPPO matrix, which tries to memorize the history of the input. The HiPPO matrix itself is fairly complicated (look it up yourself if curious); the details aren’t essential for understanding S4.
Benefits of making A a HiPPO matrix:
- A only needs to be calculated once.
- Matrix aims to compress the past history into a state that has enough information to reconstruct the history.
Prior work found that simply moving from a random A to a HiPPO matrix was very successful.
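For reference, the matrix built below is the HiPPO-LegS matrix from the HiPPO paper:
$$A_{nk} = -\begin{cases} \sqrt{(2n+1)(2k+1)} & n > k \\ n + 1 & n = k \\ 0 & n < k \end{cases}$$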
def make_HiPPO(N):
P = np.sqrt(1 + 2 * np.arange(N))
A = P[:, np.newaxis] * P[np.newaxis, :]
A = np.tril(A) - np.diag(np.arange(N))
return -A
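For a feel of the structure, a quick sketch printing the 3x3 case (output values rounded by hand): the matrix is lower triangular with a negative diagonal.
print(make_HiPPO(3))
# [[-1.     0.     0.   ]
#  [-1.732 -2.     0.   ]
#  [-2.236 -3.873 -3.   ]]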
Diving deeper into HiPPO matrices
The HiPPO matrix works by tracking the coefficients of a Legendre polynomial expansion; these coefficients let the state approximate all of the previous history.
def example_legendre(N=8):
# Random hidden state as coefficients
import numpy as np
import numpy.polynomial.legendre
x = (np.random.rand(N) - 0.5) * 2
t = np.linspace(-1, 1, 100)
f = numpy.polynomial.legendre.Legendre(x)(t)
# Plot
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context("talk")
fig = plt.figure(figsize=(20, 10))
ax = fig.gca(projection="3d")
ax.plot(
np.linspace(-25, (N - 1) * 100 + 25, 100),
[0] * 100,
zs=-1,
zdir="x",
color="black",
)
ax.plot(t, f, zs=N * 100, zdir="y", c="r")
for i in range(N):
coef = [0] * N
coef[N - i - 1] = 1
ax.set_zlim(-4, 4)
ax.set_yticks([])
ax.set_zticks([])
# Plot basis function.
f = numpy.polynomial.legendre.Legendre(coef)(t)
ax.bar(
[100 * i],
[x[i]],
zs=-1,
zdir="x",
label="x%d" % i,
color="brown",
fill=False,
width=50,
)
ax.plot(t, f, zs=100 * i, zdir="y", c="b", alpha=0.5)
ax.view_init(elev=40.0, azim=-45)
fig.savefig("images/leg.png")
if False:
example_legendre()
Each element of x is a coefficient for one element of the Legendre series, shown as the blue functions. The intuition is that the HiPPO matrix updates these coefficients at each step so the state keeps approximating the history of the input.
S4 in Practice
The Annotated S4 closes with really cool experiments using S4 in practice.
Conclusions
My goal for this was to learn more about SSMs and S4. Going through the Annotated S4 was meant to build enough of the foundations to work through Mamba - Linear-Time Sequence Modeling with Selective State Spaces Notes, which has been getting all of the hype lately. Intuitively, the idea of a state, combined with all of the cognitive abilities we are now seeing, makes rational sense. Transformers only have so much “memory” in their context window, and each increase in the context window leads to a quadratic increase in compute; this is not sustainable or smart. Future innovations will blend state and attention to create a more sensible structure.