Goro is a high-level machine learning library for Go built on Gorgonia. It aims to have the same feel as Keras.
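To illustrate what a Keras-like "sequential" API means, here is a minimal, dependency-free sketch: layers share one interface, and a Sequential container feeds each layer's output into the next. The names Layer, Dense, ReLU, and Sequential are hypothetical and are not Goro's API. Goro itself builds Gorgonia computation graphs rather than looping over plain slices.

```go
package main

import "fmt"

// Layer is a hypothetical interface: each layer maps an input vector to an output vector.
type Layer interface {
	Forward(x []float64) []float64
}

// Dense is a hypothetical fully connected layer computing y = W·x + b.
type Dense struct {
	W [][]float64 // W[i][j] connects input j to output i
	B []float64   // one bias per output
}

func (d Dense) Forward(x []float64) []float64 {
	y := make([]float64, len(d.W))
	for i, row := range d.W {
		sum := d.B[i]
		for j, w := range row {
			sum += w * x[j]
		}
		y[i] = sum
	}
	return y
}

// ReLU is a hypothetical element-wise activation that zeroes negative values.
type ReLU struct{}

func (ReLU) Forward(x []float64) []float64 {
	y := make([]float64, len(x)) // zero-initialized, so non-positive inputs stay 0
	for i, v := range x {
		if v > 0 {
			y[i] = v
		}
	}
	return y
}

// Sequential chains layers in the order they were added.
type Sequential struct {
	Name   string
	layers []Layer
}

// AddLayers appends layers to the end of the stack.
func (s *Sequential) AddLayers(ls ...Layer) { s.layers = append(s.layers, ls...) }

// Predict runs the input through every layer in order.
func (s *Sequential) Predict(x []float64) []float64 {
	for _, l := range s.layers {
		x = l.Forward(x)
	}
	return x
}

func main() {
	m := &Sequential{Name: "toy"}
	m.AddLayers(
		Dense{W: [][]float64{{1, -1}, {0.5, 0.5}}, B: []float64{0, 0}},
		ReLU{},
	)
	fmt.Println(m.Predict([]float64{2, 4})) // [0 3]
}
```

The Dense layer maps [2 4] to [-2 3], and ReLU clips the negative value, so the model prints [0 3].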
Usage
import (
	. "github.com/aunum/goro/pkg/v1/model"
	"github.com/aunum/goro/pkg/v1/layer"
	g "gorgonia.org/gorgonia"
)

// create the 'x' input, e.g. an MNIST image
x := NewInput("x", []int{1, 28, 28})

// create the 'y' or expected output, e.g. labels
y := NewInput("y", []int{10})

// create a new sequential model with the name 'mnist'
model, _ := NewSequential("mnist")

// add layers to the model
model.AddLayers(
	layer.Conv2D{Input: 1, Output: 32, Width: 3, Height: 3},
	layer.MaxPooling2D{},
	layer.Conv2D{Input: 32, Output: 64, Width: 3, Height: 3},
	layer.MaxPooling2D{},
	layer.Conv2D{Input: 64, Output: 128, Width: 3, Height: 3},
	layer.MaxPooling2D{},
	layer.Flatten{},
	layer.FC{Input: 128 * 3 * 3, Output: 100},
	layer.FC{Input: 100, Output: 10, Activation: layer.Softmax},
)

// pick an optimizer
optimizer := g.NewRMSPropSolver()

// compile the model with options; CrossEntropy comes from the dot-imported model package
model.Compile(x, y,
	WithOptimizer(optimizer),
	WithLoss(CrossEntropy),
	WithBatchSize(100),
)

// fit the model
model.Fit(xTrain, yTrain)

// use the model to predict an 'x'
prediction, _ := model.Predict(xTest)

// fit the model with a batch
model.FitBatch(xTrainBatch, yTrainBatch)

// use the model to predict a batch of 'x'
prediction, _ = model.PredictBatch(xTestBatch)
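The first FC layer's Input of 128*3*3 is the size of the flattened feature map after the three conv/pool stages. The sketch below reproduces that arithmetic under stated assumptions not confirmed by this README: each Conv2D uses stride 1 and padding 1 (so a 3x3 kernel keeps the side length), and each MaxPooling2D uses a 2x2 window with stride 2 and no padding. The helpers outSize and flattenSize are hypothetical.

```go
package main

import "fmt"

// outSize applies the standard convolution/pooling size formula:
// out = (in + 2*pad - kernel)/stride + 1, using integer (floor) division.
func outSize(in, kernel, stride, pad int) int {
	return (in+2*pad-kernel)/stride + 1
}

// flattenSize walks conv + max-pool blocks over a square image and returns
// how many values Flatten would produce.
// Assumed hyperparameters: conv 3x3, stride 1, pad 1; pool 2x2, stride 2, pad 0.
func flattenSize(side int, channels []int) int {
	for range channels {
		side = outSize(side, 3, 1, 1) // conv keeps the side length
		side = outSize(side, 2, 2, 0) // pooling halves it, rounding down
	}
	last := channels[len(channels)-1]
	return last * side * side
}

func main() {
	// 28 -> 14 -> 7 -> 3 across the three pooling stages
	fmt.Println(flattenSize(28, []int{32, 64, 128})) // 1152, i.e. 128*3*3
}
```

If your layers use different padding or strides, change the arguments passed to outSize accordingly.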
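WithBatchSize(100) and FitBatch imply that training data is fed in fixed-size chunks. The sketch below shows one way to cut a flat, row-major dataset into batches of 100 rows before converting each one to whatever tensor type FitBatch expects. The helper makeBatches is hypothetical, and dropping an incomplete final batch is an assumption: it presumes a model compiled with a fixed batch size will not accept a smaller one.

```go
package main

import "fmt"

// makeBatches splits a flat, row-major dataset (rows of length width) into
// batches of exactly batchSize rows. Leftover rows that cannot fill a final
// batch are dropped. Each batch is a sub-slice of data, so nothing is copied.
func makeBatches(data []float64, width, batchSize int) [][]float64 {
	rows := len(data) / width
	var out [][]float64
	for start := 0; start+batchSize <= rows; start += batchSize {
		out = append(out, data[start*width:(start+batchSize)*width])
	}
	return out
}

func main() {
	const width = 28 * 28 // one flattened MNIST image
	data := make([]float64, 250*width)
	bs := makeBatches(data, width, 100)
	fmt.Println(len(bs), len(bs[0])/width) // 2 100
}
```

Because the batches share memory with data, copy a batch first if you need to modify it independently.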
Examples
See the examples folder for example implementations.
There are many examples in the reinforcement learning library Gold.
Docs
Each package contains a README explaining its usage; see also GoDoc.
Contributing
Please open an MR for any issues or feature requests.