Flux

From David's Wiki

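Flux is a machine learning library for Julia. The examples on this page use the Tracker-based API (Flux.Tracker), which newer Flux releases have replaced with Zygote.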


Usage

Basic Usage

using Flux;
using Flux.Tracker: update!;

# Create layers like this
inputSize = 10;
outputSize = 20;
myLayer = Dense(inputSize, outputSize, relu);
# You can get the weights and bias using
myLayer.W;
myLayer.b;
# You can call the layer to make a prediction
# (myData is a vector of length inputSize, or an inputSize x N matrix)
myLayer(myData);
# Equivalent to (relu is applied element-wise)
relu.(myLayer.W * myData .+ myLayer.b);


# Create networks like this
model = Chain(
  Dense(10, 20, relu),
  Dense(20, 2)
);

# Calling the model will pass data through each layer in order.
model(myData);

# Get the parameters for the whole model with
p = params(model);

# Calculate the gradient of the loss with respect to all parameters with
gs = Tracker.gradient(function()
  predicted = model(myData);
  loss = sum((predicted .- myLabels).^2);
  return loss;
end, params(model));
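# Gradient of the first layer's weights
gs[model[1].W];
# update!(x, Δ) adds Δ to the parameter x and clears its gradient,
# so pass the negative gradient scaled by the learning rate.
update!(model[1].W, -0.1 * gs[model[1].W]);
# Make sure to update all layers, e.g. by looping over every parameter:
for p in params(model)
  update!(p, -0.1 * gs[p]);
end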

# You can also define an optimizer and update using the optimizer.
# Tracker-era Flux releases spell the optimizer ADAM; newer releases renamed it to Adam.
opt = ADAM();
update!(opt, model[1].W, gs[model[1].W]);
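
What the layer call, the squared-error loss, and the manual update are doing can be sketched in plain Julia without Flux. This is only an illustration under stated assumptions: PlainDense, loss, gradients, and step! are hypothetical names that are not part of Flux, the input is a single vector rather than a batch, and the gradient is written out by hand for one ReLU layer instead of being computed by Tracker.

```julia
# Plain-Julia sketch of one Dense layer trained by gradient descent.
# All names here are hypothetical stand-ins, not Flux's API.

relu(x) = max(zero(x), x)

# Stand-in for a Dense layer: a weight matrix and a bias vector.
struct PlainDense
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Forward pass: relu applied element-wise to W * x .+ b.
(layer::PlainDense)(x) = relu.(layer.W * x .+ layer.b)

# Squared-error loss, as in the listing above.
loss(layer, x, y) = sum((layer(x) .- y) .^ 2)

# Hand-derived gradient of the loss for a single input vector x.
function gradients(layer::PlainDense, x, y)
    z = layer.W * x .+ layer.b          # pre-activation
    a = relu.(z)                        # layer output
    δ = 2 .* (a .- y) .* (z .> 0)       # dL/dz; ReLU blocks the gradient where z <= 0
    return δ * x', δ                    # dL/dW (outer product), dL/db
end

# One gradient-descent step: add -η * gradient to each parameter in place.
function step!(layer::PlainDense, x, y; η = 0.1)
    gW, gb = gradients(layer, x, y)
    layer.W .-= η .* gW
    layer.b .-= η .* gb
    return layer
end

# Usage: the loss should shrink after a step with a small enough η.
layer = PlainDense([0.5 -0.2; 0.1 0.3], [0.1, -0.1])
x = [1.0, 2.0]
y = [1.0, 0.0]
println("loss before: ", loss(layer, x, y))
step!(layer, x, y; η = 0.1)
println("loss after:  ", loss(layer, x, y))
```

The in-place `.-=` mirrors what update! does to a parameter's data. With Flux, the gradient comes from Tracker.gradient rather than a hand-written formula, so the same pattern extends to a Chain of any depth.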