A linear layer in artificial neural networks is an implementation of

$$
y = W^T x + b,
$$

where
- $y$ is the layer's output,
- $x$ is an input vector,
- $W$ is the weight matrix and
- $b$ is the bias.
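As a quick sanity check of the shapes involved, here is a minimal NumPy sketch of this equation for a single input vector; the sizes `n_inputs = 3` and `n_outputs = 2` are arbitrary assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

n_inputs, n_outputs = 3, 2                   # arbitrary example sizes
W = rng.normal(size=(n_inputs, n_outputs))   # weight matrix
b = np.zeros(n_outputs)                      # bias vector
x = rng.normal(size=n_inputs)                # one input vector

y = W.T @ x + b                              # y = W^T x + b
print(y.shape)                               # (2,): one value per output
```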
## Learning the layer
To train the network, we will need to compute gradients, because they drive the stochastic gradient descent used for training. To compute gradients for an arbitrary neural network, we use the back-propagation algorithm. This algorithm needs two functions for every layer: forward and backward.
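To see the role those gradients play, here is a toy sketch of stochastic gradient descent on a simple quadratic loss; the loss $L(w) = \lVert w - \text{target}\rVert^2$ and the learning rate `lr = 0.1` are illustrative assumptions, standing in for a real network's loss:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=3)          # parameters (e.g. one row of a weight matrix)
target = np.array([1.0, -2.0, 0.5])
lr = 0.1                        # learning rate

for step in range(100):
    grad = 2.0 * (w - target)   # gradient of L w.r.t. w
    w -= lr * grad              # the SGD update needs exactly this gradient

print(np.allclose(w, target))   # True: gradient descent found the minimum
```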
### Forward pass
The forward pass computes a forward message, which is the output of the linear layer. We can use `numpy` to compute the forward message for a batch of inputs described as a matrix $X$:
```python
return np.dot(X, self.W) + self.b
```

What the code actually does is compute
$$
Y = XW + B = \begin{bmatrix} x_1^T W + b^T \\ x_2^T W + b^T \\ \vdots \\ x_{\text{n\_samples}}^T W + b^T \end{bmatrix}, \qquad \text{where } X = \begin{bmatrix} x_1^T \\ x_2^T \\ \vdots \\ x_{\text{n\_samples}}^T \end{bmatrix}, \quad B = \begin{bmatrix} b^T \\ b^T \\ \vdots \\ b^T \end{bmatrix}.
$$

The output dimensions of $Y$ are therefore ${\text{n\_samples} \times \text{n\_outputs}}$. Note how the innocent looking `+ self.b` does what `numpy` calls [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html).
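A tiny, self-contained sketch of that broadcasting behaviour (the shapes are arbitrary assumptions):

```python
import numpy as np

X = np.ones((4, 3))          # 4 samples, 3 inputs
W = np.ones((3, 2))          # 3 inputs, 2 outputs
b = np.array([10.0, 20.0])   # one bias per output

Y = np.dot(X, W) + b         # b (shape (2,)) is broadcast over all 4 rows
print(Y.shape)               # (4, 2) == (n_samples, n_outputs)
print(Y[0])                  # [13. 23.]; every row gets the same bias
```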
### Backward pass

Following what we derived in [[Training a neural network]], the backward pass computes the partial derivatives using the following layer's backward message $\delta$ and the cached input $X$. We need to derive the partial derivatives w.r.t. the weights $W$ and the bias $b$, and we need the backward message for the preceding layers. Since $Y_{ij} = \sum_k X_{ik} W_{kj} + b_j$, the entry-wise derivatives are

$$
\frac{\partial Y_{ij}}{\partial W_{mn}} = X_{im}\,[j = n], \qquad \frac{\partial Y_{ij}}{\partial b_n} = [j = n], \qquad \frac{\partial Y_{ij}}{\partial X_{mn}} = [i = m]\, W_{nj},
$$

and the chain rule with the backward message $\delta$ collapses the sums into matrix products: $X^T \delta$ for the weights, the column sums of $\delta$ for the bias, and $\delta W^T$ for the backward message. In the implementation, the gradients w.r.t. $W$ and $b$ are additionally averaged over the batch:

```python
self.dW = np.dot(self.X.T, dY) / self.X.shape[0]
self.db = dY.mean(axis=0)
dX = np.dot(dY, self.W.T)
```

## Initialization

The weights are drawn from a zero-mean normal distribution whose scale depends on the layer's fan-in, and the bias starts at zero. The first variant matches He initialization ($\sqrt{2 / \text{n\_inputs}}$, suited to ReLU activations):

```python
scale = np.sqrt(2.0 / self.n_inputs)
self.W = self.rng.normal(0.0, scale, (self.n_inputs, self.n_outputs))
self.b = np.zeros(self.n_outputs)
```

The second matches LeCun initialization ($\sqrt{1 / \text{n\_inputs}}$, a common choice for e.g. tanh activations):

```python
scale = np.sqrt(1.0 / self.n_inputs)
self.W = self.rng.normal(0.0, scale, (self.n_inputs, self.n_outputs))
self.b = np.zeros(self.n_outputs)
```
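Putting the pieces together, here is a minimal sketch of the whole layer as a class, with a numerical gradient check as a usage example. The class name `Linear`, the method layout, and the gradient-check harness are assumptions for illustration; only the forward, backward, and initialization lines come from the snippets above:

```python
import numpy as np

class Linear:
    """Sketch of the full layer, assembled from the snippets above."""

    def __init__(self, n_inputs, n_outputs, seed=0):
        self.n_inputs, self.n_outputs = n_inputs, n_outputs
        self.rng = np.random.default_rng(seed)
        scale = np.sqrt(2.0 / self.n_inputs)          # He initialization
        self.W = self.rng.normal(0.0, scale, (self.n_inputs, self.n_outputs))
        self.b = np.zeros(self.n_outputs)

    def forward(self, X):
        self.X = X                                    # cache input for backward
        return np.dot(X, self.W) + self.b

    def backward(self, dY):
        # dY holds per-sample gradients; dW and db are averaged over the batch.
        self.dW = np.dot(self.X.T, dY) / self.X.shape[0]
        self.db = dY.mean(axis=0)
        return np.dot(dY, self.W.T)                   # backward message dX

# Usage: numerically check dW on the toy loss L = Y.mean().
layer = Linear(3, 2)
X = np.random.default_rng(1).normal(size=(5, 3))
Y = layer.forward(X)

# Per-sample loss l_i = Y[i].mean(), so dl_i/dY_ij = 1 / n_outputs.
dY = np.ones_like(Y) / Y.shape[1]
layer.backward(dY)

eps = 1e-6
num_dW = np.zeros_like(layer.W)
for i in range(layer.W.shape[0]):
    for j in range(layer.W.shape[1]):
        layer.W[i, j] += eps
        L_plus = layer.forward(X).mean()
        layer.W[i, j] -= 2 * eps
        L_minus = layer.forward(X).mean()
        layer.W[i, j] += eps
        num_dW[i, j] = (L_plus - L_minus) / (2 * eps)

print(np.allclose(layer.dW, num_dW, atol=1e-8))   # True
```

The check passes because the batch-averaged `dW` matches the gradient of the batch-averaged loss; this averaging convention is why `dY` must carry per-sample gradients rather than gradients that are already averaged.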