sudachen/go-dnn

This is an old version of dnn for Golang, and the repository is archived and read-only. Please see the new, updated version at https://github.com/go-ml-dev/nn.

The example below defines a small convolutional network, trains it on MNIST with the Gym helper, saves the learned parameters, then rebinds the network and checks its accuracy on the test set:

import (
	"github.com/sudachen/go-dnn/data/mnist"
	"github.com/sudachen/go-dnn/fu"
	"github.com/sudachen/go-dnn/mx"
	"github.com/sudachen/go-dnn/ng"
	"github.com/sudachen/go-dnn/nn"
	"gotest.tools/assert"
	"testing"
	"time"
)

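// mnistConv0 is a small convolutional network for 28x28 MNIST digits:
// two convolution + max-pooling stages followed by two fully connected layers.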
var mnistConv0 = nn.Connect(
	&nn.Convolution{Channels: 24, Kernel: mx.Dim(3, 3), Activation: nn.ReLU},
	&nn.MaxPool{Kernel: mx.Dim(2, 2), Stride: mx.Dim(2, 2)},
	&nn.Convolution{Channels: 32, Kernel: mx.Dim(5, 5), Activation: nn.ReLU, BatchNorm: true},
	&nn.MaxPool{Kernel: mx.Dim(2, 2), Stride: mx.Dim(2, 2)},
	&nn.FullyConnected{Size: 32, Activation: nn.Swish, BatchNorm: true},
	&nn.FullyConnected{Size: 10, Activation: nn.Softmax})

func Test_mnistConv0(t *testing.T) {

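	// Gym configures the optimizer, loss, input shape, dataset, and stop criteria
	// for the training loop.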
	gym := &ng.Gym{
		Optimizer: &nn.Adam{Lr: .001},
		Loss:      &nn.LabelCrossEntropyLoss{},
		Input:     mx.Dim(32, 1, 28, 28),
		Epochs:    5,
		Verbose:   ng.Printing,
		Every:     1 * time.Second,
		Dataset:   &mnist.Dataset{},
		Metric:    &ng.Classification{Accuracy: 0.98},
		Seed:      42,
	}

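	// Train on CPU; the run stops once 0.98 accuracy is reached or after 5 epochs.
	// Then save the learned parameters for later reuse.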
	acc, params, err := gym.Train(mx.CPU, mnistConv0)
	assert.NilError(t, err)
	assert.Assert(t, acc >= 0.98)
	err = params.Save(fu.CacheFile("tests/mnistConv0.params"))
	assert.NilError(t, err)

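	// Rebind the same network with batch size 10, reload the saved parameters,
	// and measure accuracy over the MNIST test set.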
	net, err := nn.Bind(mx.CPU, mnistConv0, mx.Dim(10, 1, 28, 28), nil)
	assert.NilError(t, err)
	err = net.LoadParamsFile(fu.CacheFile("tests/mnistConv0.params"), false)
	assert.NilError(t, err)
	_ = net.PrintSummary(false)

	ok, err := ng.Measure(net, &mnist.Dataset{}, &ng.Classification{Accuracy: 0.98}, ng.Printing)
	assert.NilError(t, err)
	assert.Assert(t, ok)
}
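
Running the test prints the network summary, the training progress, and the final evaluation: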
Network Identity: 158cf5bd604e12e7bd438084e135703bd89dc10f
Symbol              | Operation            | Output        |  Params #
----------------------------------------------------------------------
_input              | null                 | (32,1,28,28)  |         0
Convolution01       | Convolution((3,3)//) | (32,24,26,26) |       240
Convolution01$A     | Activation(relu)     | (32,24,26,26) |         0
MaxPool02           | Pooling(max)         | (32,24,13,13) |         0
Convolution03       | Convolution((5,5)//) | (32,32,9,9)   |     19232
Convolution03$BN    | BatchNorm            | (32,32,9,9)   |       128
Convolution03$A     | Activation(relu)     | (32,32,9,9)   |         0
MaxPool04           | Pooling(max)         | (32,32,4,4)   |         0
FullyConnected05    | FullyConnected       | (32,32)       |     16416
FullyConnected05$BN | BatchNorm            | (32,32)       |       128
sigmoid@sym07       | sigmoid              | (32,32)       |         0
FullyConnected05$A  | elemwise_mul         | (32,32)       |         0
FullyConnected06    | FullyConnected       | (32,10)       |       330
FullyConnected06$A  | SoftmaxActivation()  | (32,10)       |         0
BlockGrad@sym08     | BlockGrad            | (32,10)       |         0
make_loss@sym09     | make_loss            | (32,10)       |         0
pick@sym10          | pick                 | (32,1)        |         0
log@sym11           | log                  | (32,1)        |         0
_mul_scalar@sym12   | _mul_scalar          | (32,1)        |         0
mean@sym13          | mean                 | (1)           |         0
make_loss@sym14     | make_loss            | (1)           |         0
----------------------------------------------------------------------
Total params: 36474
[000] batch: 389, loss: 0.09991227
[000] batch: 1074, loss: 0.055281825
[000] batch: 1855, loss: 0.0760978
[000] metric: 0.988, final loss: 0.0515
Achieved reqired metric
Symbol              | Operation            | Output        |  Params #
----------------------------------------------------------------------
_input              | null                 | (10,1,28,28)  |         0
Convolution01       | Convolution((3,3)//) | (10,24,26,26) |       240
Convolution01$A     | Activation(relu)     | (10,24,26,26) |         0
MaxPool02           | Pooling(max)         | (10,24,13,13) |         0
Convolution03       | Convolution((5,5)//) | (10,32,9,9)   |     19232
Convolution03$BN    | BatchNorm            | (10,32,9,9)   |       128
Convolution03$A     | Activation(relu)     | (10,32,9,9)   |         0
MaxPool04           | Pooling(max)         | (10,32,4,4)   |         0
FullyConnected05    | FullyConnected       | (10,32)       |     16416
FullyConnected05$BN | BatchNorm            | (10,32)       |       128
sigmoid@sym07       | sigmoid              | (10,32)       |         0
FullyConnected05$A  | elemwise_mul         | (10,32)       |         0
FullyConnected06    | FullyConnected       | (10,10)       |       330
FullyConnected06$A  | SoftmaxActivation()  | (10,10)       |         0
----------------------------------------------------------------------
Total params: 36474
Accuracy over 1000*10 batchs: 0.988
--- PASS: Test_mnistConv0 (6.51s)
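
For reference, the same calls can be used outside of go test. Below is a minimal, untested sketch of a standalone evaluation program: it reuses only API calls that appear in the example above, and it assumes a tests/mnistConv0.params file was already produced by a training run like the one shown.

package main

import (
	"fmt"

	"github.com/sudachen/go-dnn/data/mnist"
	"github.com/sudachen/go-dnn/fu"
	"github.com/sudachen/go-dnn/mx"
	"github.com/sudachen/go-dnn/ng"
	"github.com/sudachen/go-dnn/nn"
)

// The same network definition as in the test above.
var mnistConv0 = nn.Connect(
	&nn.Convolution{Channels: 24, Kernel: mx.Dim(3, 3), Activation: nn.ReLU},
	&nn.MaxPool{Kernel: mx.Dim(2, 2), Stride: mx.Dim(2, 2)},
	&nn.Convolution{Channels: 32, Kernel: mx.Dim(5, 5), Activation: nn.ReLU, BatchNorm: true},
	&nn.MaxPool{Kernel: mx.Dim(2, 2), Stride: mx.Dim(2, 2)},
	&nn.FullyConnected{Size: 32, Activation: nn.Swish, BatchNorm: true},
	&nn.FullyConnected{Size: 10, Activation: nn.Softmax})

func main() {
	// Bind the network for evaluation: batches of 10 grayscale 28x28 images on CPU.
	net, err := nn.Bind(mx.CPU, mnistConv0, mx.Dim(10, 1, 28, 28), nil)
	if err != nil {
		panic(err)
	}
	// Load parameters saved by a previous training run (assumed to exist).
	if err = net.LoadParamsFile(fu.CacheFile("tests/mnistConv0.params"), false); err != nil {
		panic(err)
	}
	// Measure classification accuracy over the MNIST test set.
	ok, err := ng.Measure(net, &mnist.Dataset{}, &ng.Classification{Accuracy: 0.98}, ng.Printing)
	if err != nil {
		panic(err)
	}
	fmt.Println("achieved 0.98 accuracy:", ok)
}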

About

Deep Neural Networks for Golang (powered by MXNet). The new, updated version: https://github.com/go-ml-dev/nn
