Core Ideas:

  1. State Space Models (SSM) in Neural Networks: SSMs are used to model the relationship between input, output, and the internal state of a system. They are particularly effective in handling time-series data and sequences.

  2. From Continuous to Discrete: SSMs are adapted from continuous-time models (used for analog signals) to discrete-time models (for digital signals) using methods like the bilinear transform, making them suitable for processing discrete input sequences.

  3. Recurrent Representation and Efficiency: In their basic form, SSMs have a recurrent nature, which isn’t efficient for modern parallel processing hardware. This is overcome by transforming them into a discrete convolution form.

  4. Convolutional Approach: By representing SSMs as discrete convolutions, they can be computed more efficiently on modern hardware like GPUs, enabling faster processing of large datasets.

  5. SSM Neural Network Implementation: These concepts are implemented in neural networks using specialized layers and transformations, with applications in processing sequences, such as in time-series analysis or natural language processing.

Context and Relevance

  • Advances in Sequence Modeling: Traditional models like RNNs (Recurrent Neural Networks) face challenges in handling long sequences due to their sequential nature and limitations in memory. SSMs offer a more efficient and scalable alternative.

  • Efficiency and Parallelism: The transformation of SSMs into a form suitable for parallel processing aligns well with the capabilities of modern GPUs, leading to significant improvements in computational efficiency.

  • Application in Deep Learning: SSMs find applications in various domains within deep learning, such as language modeling, signal processing, and even in complex tasks like forecasting and anomaly detection in time-series data.

  • Theoretical Innovation: The integration of concepts like HiPPO matrices into SSMs represents a blend of theoretical innovation with practical application, enabling models to effectively capture and remember long-range dependencies in data.

  • Future of Machine Learning: The exploration and implementation of SSMs in neural network architectures contribute to the ongoing evolution of machine learning, particularly in making models more efficient, scalable, and capable of handling complex, sequential data.

A recent interesting paper shows how innovative this design really is.

Notes on the Annotated S4

The goal is efficient modeling of long sequences. We are going to build a new neural network layer based on State Space Models (SSMs).

State Space Model (SSM) - maps a 1-D input signal to an N-D latent state before projecting back to a 1-D output signal

  • u(t) - input signal, 1-D
    • the (t) emphasizes that the input is a function of time
  • y(t) - output signal, 1-D
  • x(t) - latent state, N-D
    • the multi-dimensional internal state of the State Space Model
    • N-dimensional
    • With D = 0, the main idea is that we run the input u(t) through B, take the internal state x(t), and use A to generate the new state x’(t). That new state is run through C, which gives the output y(t) (formalized in the equations right after this list).
  • A, B, C, D are parameters learned via gradient descent
  • D = 0 for now; the Du(t) term can be viewed as a skip connection, so we simply skip over it
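
Written out, the continuous-time SSM that these bullets describe is:

$$x'(t) = A\,x(t) + B\,u(t), \qquad y(t) = C\,x(t) + D\,u(t)$$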

Defining A, B, C

import jax
import jax.numpy as np

# initialize the parameters A, B, C randomly
def random_SSM(rng, N):
    a_r, b_r, c_r = jax.random.split(rng, 3)
    A = jax.random.uniform(a_r, (N, N))
    B = jax.random.uniform(b_r, (N, 1))
    C = jax.random.uniform(c_r, (1, N))
    return A, B, C
Discrete-Time SSM: The Recurrent Representation

Idea: instead of the input being a continuous function u(t), we make the input a discrete sequence (u_0, u_1, ...). The input is discretized by a step size Δ.

We cannot take A, B, C and leave them the way they are. We have to use a bilinear method to convert the continuous state matrix A into a discrete approximation Ā.
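
Concretely, with step size Δ, the bilinear discretization computed by the code below is:

$$\bar{A} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\left(I + \tfrac{\Delta}{2}A\right), \qquad \bar{B} = \left(I - \tfrac{\Delta}{2}A\right)^{-1}\Delta B, \qquad \bar{C} = C$$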

from jax.numpy.linalg import inv

# Take randomly generated parameters A, B, C and apply a bilinear transformation
# so that the parameters work with a discrete-time model with step size `step`
def discretize(A, B, C, step):
    I = np.eye(A.shape[0])
    BL = inv(I - (step / 2.0) * A)
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C

Once we discretize the model, it can be viewed and calculated like an RNN.
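
Concretely, the step-by-step recurrence that scan_SSM implements is:

$$x_k = \bar{A}\,x_{k-1} + \bar{B}\,u_k, \qquad y_k = \bar{C}\,x_k$$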

def scan_SSM(Ab, Bb, Cb, u, x0):
    def step(x_k_1, u_k):
        x_k = Ab @ x_k_1 + Bb @ u_k
        y_k = Cb @ x_k
        return x_k, y_k
 
    return jax.lax.scan(step, x0, u)

Putting everything together to run the SSM

def run_SSM(A, B, C, u):
    L = u.shape[0]
    N = A.shape[0]
    Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L)
 
    # Run recurrence
    return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))[1]
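
As a quick usage sketch (assuming the imports and functions defined above; the key, state size, and input length here are arbitrary choices, not from the original post):

rng = jax.random.PRNGKey(0)          # arbitrary key
A, B, C = random_SSM(rng, N=4)       # random SSM with a 4-D latent state
u = jax.random.uniform(rng, (16,))   # length-16 input signal
y = run_SSM(A, B, C, u)              # shape (16, 1): one output per input step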

Core idea: Recurrent Neural Networks are slow because we cannot do multiple calculations in parallel; we have to do them sequentially, one after another. Because convolutions are heavily optimized on modern hardware, people try to turn the RNN into a CNN by “unrolling” it, i.e. by converting the recurrence into a discrete convolution.

  • Discrete convolution is like blending two signals or inputs together to create a new sequence. The idea is to combine them while still retaining some of each value's meaning, roughly like shifting the values and then encoding values into the opened gap (see the toy sketch right after this list).
    • With this concept implemented, we can use the parallelism of convolution operations.
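
A toy illustration of this blending (made-up numbers, using the jax.numpy import from above): each output element is a weighted sum of the current and past inputs.

# Toy causal convolution: y[k] = sum_j K[j] * u[k - j]
u = np.array([1.0, 2.0, 3.0, 4.0])     # input sequence
K = np.array([0.5, 0.25, 0.125, 0.0])  # filter / kernel
y = np.convolve(u, K, mode="full")[: u.shape[0]]
# y == [0.5, 1.25, 2.125, 3.0]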

Here are the new equations, where $\bar{K}$ is the SSM convolution kernel:

$$y_k = \bar{C}\bar{A}^{k}\bar{B}\,u_0 + \bar{C}\bar{A}^{k-1}\bar{B}\,u_1 + \dots + \bar{C}\bar{B}\,u_k$$

$$y = \bar{K} * u, \qquad \bar{K} = \big(\bar{C}\bar{B},\ \bar{C}\bar{A}\bar{B},\ \dots,\ \bar{C}\bar{A}^{L-1}\bar{B}\big)$$

(Warning: this naive code is unstable and won’t work for anything beyond small lengths.) “Note that this is a giant filter. It is the size of the entire sequence!”

from jax.numpy.linalg import matrix_power

# Compute the SSM convolution kernel K-bar shown above:
# the array of the L values (Cb @ Ab^l @ Bb) for l = 0 .. L-1
def K_conv(Ab, Bb, Cb, L):
    return np.array(
        [(Cb @ matrix_power(Ab, l) @ Bb).reshape() for l in range(L)]
    )

The math to derive those equations is worked through in the original post (reference equations).

Since the main equation is $y = \bar{K} * u$, we still need to compute this (very long) convolution. We can use the Fast Fourier Transform (FFT) and the convolution theorem to speed this up. To use the theorem, we pad the input sequences with zeros and then unpad the output sequence. As the length gets longer, this FFT method becomes much more efficient than direct convolution.

from jax.scipy.signal import convolve

# Fast computation of K-bar * u using the FFT and the convolution theorem
def causal_convolution(u, K, nofft=False):
    if nofft:
        return convolve(u, K, mode="full")[: u.shape[0]]
    else:
        assert K.shape[0] == u.shape[0]
        ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
        Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
        out = ud * Kd
        return np.fft.irfft(out)[: u.shape[0]]

# Test that the convolutional (CNN) form matches the recurrent (RNN) form.
# Assumes a global PRNG key, e.g. rng = jax.random.PRNGKey(1).
def test_cnn_is_rnn(N=4, L=16, step=1.0 / 16):
    ssm = random_SSM(rng, N)
    u = jax.random.uniform(rng, (L,))
    jax.random.split(rng, 3)
    # RNN
    rec = run_SSM(*ssm, u)
 
    # CNN
    ssmb = discretize(*ssm, step=step)
    conv = causal_convolution(u, K_conv(*ssmb, L))
 
    # Check
    assert np.allclose(rec.ravel(), conv.ravel())

SSM Neural Network

The discrete SSM defines a map from a 1-D input sequence to a 1-D output sequence. We assume that we are going to be learning the parameters B and C, as well as the step size Δ and a scalar D. For the parameter A we will use a HiPPO matrix. We learn the step size Δ in log space.

def log_step_initializer(dt_min=0.001, dt_max=0.1):
    def init(key, shape):
        return jax.random.uniform(key, shape) * (
            np.log(dt_max) - np.log(dt_min)
        ) + np.log(dt_min)
 
    return init
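
A quick sketch of what this buys us (the key and shape here are arbitrary, for illustration only): sampling uniformly in log space and exponentiating gives a step size back in [dt_min, dt_max].

key = jax.random.PRNGKey(0)                   # arbitrary key
log_step = log_step_initializer()(key, (1,))  # uniform in [log(1e-3), log(1e-1)]
step = np.exp(log_step)                       # lies in [0.001, 0.1]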

Most of the work in the SSM layer is building the kernel (filter).

The self.decode flag specifies whether the SSMLayer runs in CNN mode or RNN mode.

from flax import linen as nn
from jax.nn.initializers import lecun_normal

class SSMLayer(nn.Module):
    N: int
    l_max: int
    decode: bool = False
 
    def setup(self):
        # SSM parameters
        self.A = self.param("A", lecun_normal(), (self.N, self.N))
        self.B = self.param("B", lecun_normal(), (self.N, 1))
        self.C = self.param("C", lecun_normal(), (1, self.N))
        self.D = self.param("D", nn.initializers.ones, (1,))
 
        # Step parameter
        self.log_step = self.param("log_step", log_step_initializer(), (1,))
 
        step = np.exp(self.log_step)
        self.ssm = discretize(self.A, self.B, self.C, step=step)
        self.K = K_conv(*self.ssm, self.l_max)
 
        # RNN cache for long sequences
        self.x_k_1 = self.variable("cache", "cache_x_k", np.zeros, (self.N,))
 
    def __call__(self, u):
        if not self.decode:
            # CNN Mode
            return causal_convolution(u, self.K) + self.D * u
        else:
            # RNN Mode
            x_k, y_s = scan_SSM(*self.ssm, u[:, np.newaxis], self.x_k_1.value)
            if self.is_mutable_collection("cache"):
                self.x_k_1.value = x_k
            return y_s.reshape(-1).real + self.D * u

SSMs operate on scalar signals, so we make d_model different, stacked copies (one SSM per feature dimension).

def cloneLayer(layer):
    return nn.vmap(
        layer,
        in_axes=1,
        out_axes=1,
        variable_axes={"params": 1, "cache": 1, "prime": 1},
        split_rngs={"params": True},
    )
SSMLayer = cloneLayer(SSMLayer)
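
In other words, nn.vmap here clones the scalar SSM across the feature dimension: in_axes=1 / out_axes=1 map over axis 1 of a (length, d_model) input, variable_axes={"params": 1, ...} gives each copy its own parameters, and split_rngs={"params": True} makes sure the copies are initialized differently.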

The SSM layer can then be put into a standard NN. Here we add a block that pairs a call to an SSM with dropout and a linear projection.

class SequenceBlock(nn.Module):
    layer_cls: nn.Module
    layer: dict  # Hyperparameters of inner layer
    dropout: float
    d_model: int
    prenorm: bool = True
    glu: bool = True
    training: bool = True
    decode: bool = False
 
    def setup(self):
        self.seq = self.layer_cls(**self.layer, decode=self.decode)
        self.norm = nn.LayerNorm()
        self.out = nn.Dense(self.d_model)
        if self.glu:
            self.out2 = nn.Dense(self.d_model)
        self.drop = nn.Dropout(
            self.dropout,
            broadcast_dims=[0],
            deterministic=not self.training,
        )
 
    def __call__(self, x):
        skip = x
        if self.prenorm:
            x = self.norm(x)
        x = self.seq(x)
        x = self.drop(nn.gelu(x))
        if self.glu:
            x = self.out(x) * jax.nn.sigmoid(self.out2(x))
        else:
            x = self.out(x)
        x = skip + self.drop(x)
        if not self.prenorm:
            x = self.norm(x)
        return x

We then stack these blocks on top of each other to produce a deep stack of SSM layers.

class Embedding(nn.Embed):
    num_embeddings: int
    features: int
 
    @nn.compact
    def __call__(self, x):
        y = nn.Embed(self.num_embeddings, self.features)(x[..., 0])
        return np.where(x > 0, y, 0.0)

class StackedModel(nn.Module):
    layer_cls: nn.Module
    layer: dict  # Extra arguments to pass into layer constructor
    d_output: int
    d_model: int
    n_layers: int
    prenorm: bool = True
    dropout: float = 0.0
    embedding: bool = False  # Use nn.Embed instead of nn.Dense encoder
    classification: bool = False
    training: bool = True
    decode: bool = False  # Probably should be moved into layer_args
 
    def setup(self):
        if self.embedding:
            self.encoder = Embedding(self.d_output, self.d_model)
        else:
            self.encoder = nn.Dense(self.d_model)
        self.decoder = nn.Dense(self.d_output)
        self.layers = [
            SequenceBlock(
                layer_cls=self.layer_cls,
                layer=self.layer,
                prenorm=self.prenorm,
                d_model=self.d_model,
                dropout=self.dropout,
                training=self.training,
                decode=self.decode,
            )
            for _ in range(self.n_layers)
        ]
 
    def __call__(self, x):
        if not self.classification:
            if not self.embedding:
                x = x / 255.0  # Normalize
            if not self.decode:
                x = np.pad(x[:-1], [(1, 0), (0, 0)])
        x = self.encoder(x)
        for layer in self.layers:
            x = layer(x)
        if self.classification:
            x = np.mean(x, axis=0)
        x = self.decoder(x)
        return nn.log_softmax(x, axis=-1)

BatchStackedModel = nn.vmap(
    StackedModel,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None, "dropout": None, "cache": 0, "prime": None},
    split_rngs={"params": False, "dropout": True},
)
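
A usage sketch for wiring this up (the hyperparameters, shapes, and PRNG keys below are illustrative placeholders, not the original training configuration):

# Illustration only: a tiny batched model initialized on dummy continuous input
model = BatchStackedModel(
    layer_cls=SSMLayer,
    layer={"N": 4, "l_max": 16},   # inner SSMLayer hyperparameters
    d_output=10,
    d_model=8,
    n_layers=2,
)
rng = jax.random.PRNGKey(0)
x = np.zeros((2, 16, 1))           # (batch, sequence length, features)
variables = model.init({"params": rng, "dropout": rng}, x)
# variables["params"] holds the learnable weights; a "cache" collection is
# also created for RNN-mode decoding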

The full code is listed in the original Annotated S4 post.

Problems with SSMs

  1. Randomly initialized SSM does not perform well
  2. Computing it naively like we’ve done so far is really slow and memory inefficient

Part 1b: Addressing Long-Range Dependencies with HiPPO

Prior work found that SSMs don’t work well in practice because gradients scale exponentially in the sequence length (the vanishing/exploding gradient problem). HiPPO theory comes in to help. The idea is to use the HiPPO matrix for A, which tries to memorize the history of the input; the authors define this particular structured matrix as a HiPPO matrix. The full derivation of the HiPPO matrix is fairly involved; look it up yourself if interested, since it doesn’t seem to be essential to understanding S4.

Benefits of making A a HiPPO matrix:

  1. A only needs to be calculated once.
  2. The matrix aims to compress the past history into a state that keeps enough information to approximately reconstruct that history.

Prior work found that simply swapping A from a random matrix to a HiPPO matrix was very successful.

# Construct the (negated) N x N HiPPO matrix used for the transition matrix A
def make_HiPPO(N):
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A
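
For intuition, here is what this produces for a small (hypothetical) N = 3: a lower-triangular matrix with -(n+1) on the diagonal and -sqrt((2n+1)(2k+1)) below it.

hippo = make_HiPPO(3)
# hippo ≈ [[-1.  ,  0.  ,  0.  ],
#          [-1.73, -2.  ,  0.  ],
#          [-2.24, -3.87, -3.  ]]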

Diving deeper into HiPPO matrices

The HiPPO matrix is successful because it tracks the coefficients of a Legendre polynomial expansion. These coefficients let the state approximate all of the previous history.

def example_legendre(N=8):
    # Random hidden state as coefficients
    import numpy as np
    import numpy.polynomial.legendre
 
    x = (np.random.rand(N) - 0.5) * 2
    t = np.linspace(-1, 1, 100)
    f = numpy.polynomial.legendre.Legendre(x)(t)
 
    # Plot
    import matplotlib.pyplot as plt
    import seaborn
 
    seaborn.set_context("talk")
    fig = plt.figure(figsize=(20, 10))
    ax = fig.add_subplot(projection="3d")
    ax.plot(
        np.linspace(-25, (N - 1) * 100 + 25, 100),
        [0] * 100,
        zs=-1,
        zdir="x",
        color="black",
    )
    ax.plot(t, f, zs=N * 100, zdir="y", c="r")
    for i in range(N):
        coef = [0] * N
        coef[N - i - 1] = 1
        ax.set_zlim(-4, 4)
        ax.set_yticks([])
        ax.set_zticks([])
        # Plot basis function.
        f = numpy.polynomial.legendre.Legendre(coef)(t)
        ax.bar(
            [100 * i],
            [x[i]],
            zs=-1,
            zdir="x",
            label="x%d" % i,
            color="brown",
            fill=False,
            width=50,
        )
        ax.plot(t, f, zs=100 * i, zdir="y", c="b", alpha=0.5)
    ax.view_init(elev=40.0, azim=-45)
    fig.savefig("images/leg.png")
if False:
    example_legendre()

Each coefficient x_i corresponds to one element of the Legendre series, shown as the blue functions. The intuition is that the HiPPO matrix updates these coefficients at each step.

S4 in Practice

The original post includes really cool experiments using S4 in practice.

Conclusions

My goal here was to learn more about SSMs and S4. Going through the Annotated S4 was meant to build enough of the foundations to work through Mamba - Linear-Time Sequence Modeling with Selective State Spaces Notes, which has been getting all of the hype lately. Intuitively, the idea of a state, combined with all of the cognitive abilities we are now seeing, makes rational sense. Transformers only have so much “memory” in their context window, and each increase in context window length leads to a quadratic increase in compute and memory; this is not sustainable or smart. Future innovations will likely blend state and attention to create a more sensible structure.