wide-residual-network
Category: Leetcode/Problem bank
Development tool: Lua
File size: 688KB
Downloads: 0
Upload date: 2017-01-13 01:44:19
Uploader: sh-1993
Description: Wide-residual network implementations. Best results for cifar10 (97.12%), cifar100 (84.19%), and other Kaggle challenges.
File list:
IMAGES (0, 2017-01-13)
IMAGES\cifar100_image.png (238365, 2017-01-13)
IMAGES\cifar10_image.png (231057, 2017-01-13)
IMAGES\img_356.lua (336, 2017-01-13)
IMAGES\svhn_image.png (180204, 2017-01-13)
INSTALL.md (4999, 2017-01-13)
SERVER.md (2819, 2017-01-13)
checkpoints.lua (3392, 2017-01-13)
dataloader.lua (3394, 2017-01-13)
datasets (0, 2017-01-13)
datasets\cifar10-gen.lua (2045, 2017-01-13)
datasets\cifar10.lua (1418, 2017-01-13)
datasets\cifar100-gen.lua (2060, 2017-01-13)
datasets\cifar100.lua (1630, 2017-01-13)
datasets\init.lua (1351, 2017-01-13)
datasets\svhn-gen.lua (1479, 2017-01-13)
datasets\svhn.lua (1393, 2017-01-13)
datasets\transforms.lua (4917, 2017-01-13)
ensemble.lua (2738, 2017-01-13)
logs (0, 2017-01-13)
logs\cifar10 (0, 2017-01-13)
logs\cifar10\wide-resnet-28x10 (0, 2017-01-13)
logs\cifar10\wide-resnet-28x10\log_1.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-28x10\log_2.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-28x10\log_3.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-28x10\log_4.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-28x10\log_5.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-28x20 (0, 2017-01-13)
logs\cifar10\wide-resnet-28x20\log_1.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-28x20\log_2.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-28x20\log_3.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-28x20\log_4.txt (403, 2017-01-13)
logs\cifar10\wide-resnet-28x20\log_5.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-40x10 (0, 2017-01-13)
logs\cifar10\wide-resnet-40x10\log_1.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-40x10\log_2.txt (482, 2017-01-13)
logs\cifar10\wide-resnet-40x10\log_3.txt (482, 2017-01-13)
... ...
# Wide Residual Networks Using Ensemble
Wide-residual network implementations for cifar10, cifar100, and other kaggle challenges
Torch Implementation of Sergey Zagoruyko's [Wide Residual Networks](https://arxiv.org/pdf/1605.07146v2.pdf).
To understand how width and depth affect wide residual networks,
several experiments were run with different width and depth settings.
**Increasing the number of filters (increasing width)** improved the model more
than making the model deeper.
Finally, simply averaging a few models trained with different parameter settings significantly increased both top-1 and top-5 accuracy. The ensemble test accuracy reached **97.12%** on CIFAR-10 and **84.19%** on CIFAR-100 using only **meanstd** normalization.
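
To make "width" and "depth" concrete, here is a minimal sketch of how a WRN name such as 28x10 maps to a layer layout. It assumes this repository follows the convention of the Wide Residual Networks paper for basic blocks (depth = 6n + 4, three groups of n blocks, base widths 16/32/64 multiplied by the widening factor); the helper name `wrn_config` is hypothetical and does not come from the repository.

```python
def wrn_config(depth, widen_factor):
    """Return (blocks_per_group, channel_widths) for a WRN of the given depth and width.

    Assumes the basic-block convention depth = 6n + 4 from the WRN paper.
    """
    if depth < 10 or (depth - 4) % 6 != 0:
        raise ValueError("depth must be of the form 6n + 4 with n >= 1")
    blocks_per_group = (depth - 4) // 6
    # The first convolution keeps 16 channels; the three groups are widened.
    widths = [16] + [base * widen_factor for base in (16, 32, 64)]
    return blocks_per_group, widths


# The four CIFAR configurations that appear in this README.
for depth, k in [(28, 10), (28, 20), (40, 10), (40, 14)]:
    n, widths = wrn_config(depth, k)
    print(f"WRN-{depth}x{k}: {n} blocks per group, channels {widths}")
```

Under this convention, going from 28x10 to 28x20 doubles every group's channel count while keeping 4 blocks per group, whereas going from 28x10 to 40x10 keeps the channels and raises the block count from 4 to 6.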
## Requirements
See the [installation instructions](INSTALL.md) for a step-by-step installation guide.
See the [server instructions](SERVER.md) for server setup.
- Install [Torch](http://torch.ch/docs/getting-started.html)
- Install [cuda-8.0](https://developer.nvidia.com/cuda-downloads)
- Install [cudnn v5.1](https://developer.nvidia.com/cudnn)
- Install luarocks packages
```bash
$ luarocks install cutorch
$ luarocks install xlua
$ luarocks install optnet
```
## Directories and datasets
- modelState : Directory where the best model is saved
- datasets : Data preparation and preprocessing directory
- networks : Directory containing the wide-residual network model structure files
- gen : Directory where the generated t7 file for each dataset is saved
- scripts : Directory containing the run scripts
## Best Results
CIFAR-10 top-1 accuracy reaches **97.12%** with plain average ensembling, without any per-model weight adjustments.
Tuning the weight given to each model may improve accuracy further.
The ensemble improves on the results of single WRNs.
Test error (%, random flip, **meanstd** normalization, median of 5 runs) on CIFAR:
| Dataset | network | Top1 Err(%) |
|:-----------:|:------------:|:------------:|
| CIFAR-10 | WRN-28x10 | 3.89 |
| CIFAR-10 | Ensemble-WRN | **2.88** |
| CIFAR-100 | WRN-28x10 | 18.85 |
| CIFAR-100 | Ensemble-WRN | **15.81** |
## How to run
To train on cifar10, cifar100, or svhn, run the corresponding script below.
```bash
$ ./scripts/[:dataset]_train.sh
# For example, if you want to train the model on cifar10, you simply type
$ ./scripts/cifar10_train.sh
```
To test your own trained model on cifar10, cifar100, or svhn, run the script below.
```bash
$ ./scripts/[:dataset]_test.sh
```
To ensemble multiple trained models with different parameters, follow the steps below.
```bash
$ vi ensemble.lua
# Type :32 in vi to move the cursor to line 32, then set one entry per model in each tensor
ens_depth = torch.Tensor({28, 28, 28, 28, 40, 40, 40})
ens_widen_factor = torch.Tensor({20, 20, 20, 20, 10, 14, 14})
ens_nExperiment = torch.Tensor({ 2, 3, 4, 5, 5, 4, 5})
```
After setting the parameters for your models, open [scripts/ensemble.sh](scripts/ensemble.sh).
```bash
$ vi scripts/ensemble.sh
# On the second line:
export dataset=[:dataset] # the dataset whose models you want to ensemble
export mode=[:mode] # one of 'avg', 'min', 'max'
```
Finally, run the script file.
```bash
$ ./scripts/ensemble.sh
```
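
The three modes can be illustrated with a small sketch. It is not a port of `ensemble.lua`: it assumes each model outputs a vector of class probabilities per sample and that `avg`, `min`, and `max` combine those vectors element-wise before taking the arg-max. The function name and the sample values are hypothetical.

```python
def ensemble(prob_sets, mode="avg"):
    """Combine per-model class probabilities for one sample.

    prob_sets: list of probability vectors, one per model.
    Returns (combined_scores, predicted_class).
    """
    combine = {
        "avg": lambda values: sum(values) / len(values),
        "min": min,
        "max": max,
    }
    if mode not in combine:
        raise ValueError("mode must be 'avg', 'min' or 'max'")
    # zip(*prob_sets) groups the scores that all models gave to the same class.
    scores = [combine[mode](class_values) for class_values in zip(*prob_sets)]
    predicted = max(range(len(scores)), key=scores.__getitem__)
    return scores, predicted


# Hypothetical softmax outputs of three models for one 3-class sample.
model_probs = [
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.5, 0.4, 0.1],
]
for mode in ("avg", "min", "max"):
    scores, predicted = ensemble(model_probs, mode)
    print(mode, scores, predicted)
```

In this toy case `avg` and `max` pick class 0 while `min` picks class 1, which shows that the mode can change the prediction when the models disagree.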
## Implementation Details
* CIFAR-10, CIFAR-100
| epoch | learning rate | weight decay | Optimizer | Momentum | Nesterov |
|:---------:|:-------------:|:-------------:|:---------:|:--------:|:--------:|
| 0 ~ 60 | 0.1 | 0.0005 | Momentum | 0.9 | true |
| 61 ~ 120 | 0.02 | 0.0005 | Momentum | 0.9 | true |
| 121 ~ 160 | 0.004 | 0.0005 | Momentum | 0.9 | true |
| 161 ~ 200 | 0.0008 | 0.0005 | Momentum | 0.9 | true |
* SVHN
| epoch | learning rate | weight decay | Optimizer | Momentum | Nesterov |
|:---------:|:-------------:|:-------------:|:---------:|:--------:|:--------:|
| 0 ~ 80 | 0.01 | 0.0005 | Momentum | 0.9 | true |
| 81 ~ 120 | 0.001 | 0.0005 | Momentum | 0.9 | true |
| 121 ~ 160 | 0.0001 | 0.0005 | Momentum | 0.9 | true |
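
The two step schedules above can be written as a lookup, as in the minimal sketch below. The values are copied from the tables; the names `SCHEDULES` and `learning_rate` are hypothetical and are not taken from the training scripts.

```python
# (last epoch of the range, learning rate), copied from the tables above.
SCHEDULES = {
    "cifar": [(60, 0.1), (120, 0.02), (160, 0.004), (200, 0.0008)],
    "svhn": [(80, 0.01), (120, 0.001), (160, 0.0001)],
}

# Settings shared by every row of both tables.
WEIGHT_DECAY = 0.0005
MOMENTUM = 0.9
NESTEROV = True


def learning_rate(epoch, dataset="cifar"):
    """Return the learning rate listed for `epoch` in the given schedule."""
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    for last_epoch, lr in SCHEDULES[dataset]:
        if epoch <= last_epoch:
            return lr
    raise ValueError("epoch is past the end of the schedule")


print(learning_rate(61))          # second CIFAR stage
print(learning_rate(81, "svhn"))  # second SVHN stage
```

Each CIFAR stage multiplies the previous rate by 0.2, and each SVHN stage multiplies it by 0.1.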
## CIFAR-10 Results
![alt tag](IMAGES/cifar10_image.png)
The table below reports test-set accuracy for models trained on the **CIFAR-10 dataset**.
**Accuracy is the average of 5 runs**
| network | dropout | preprocess | GPU:0 | GPU:1 | per epoch | accuracy(%) |
|:-----------------:|:-------:|:----------:|:-----:|:-----:|:------------:|:-----------:|
| pre-ResNet-1001 | 0 | meanstd | - | - | 3 min 25 sec | 95.08 |
| wide-resnet 28x10 | 0 | ZCA | 5.90G | - | 2 min 03 sec | 95.84 |
| wide-resnet 28x10 | 0 | meanstd | 5.90G | - | 2 min 03 sec | 96.01 |
| wide-resnet 28x10 | 0.3 | meanstd | 5.90G | - | 2 min 03 sec | 96.19 |
| wide-resnet 28x20 | 0.3 | meanstd | 8.13G | 6.93G | 4 min 10 sec | **96.52** |
| wide-resnet 40x10 | 0.3 | meanstd | 8.08G | - | 3 min 13 sec | 96.26 |
| wide-resnet 40x14 | 0.3 | meanstd | 7.37G | ***6G | 3 min 23 sec | 96.31 |
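
The `preprocess` column distinguishes ZCA whitening from **meanstd** normalization. The sketch below assumes that "meanstd" means per-channel standardization with statistics computed on the training set, which is the usual reading of the term; the repository's `datasets/transforms.lua` is not reproduced here, and the function names and toy data are hypothetical.

```python
import math


def channel_stats(images):
    """Compute per-channel mean and standard deviation over a set of images.

    images: list of images; each image is a list of channels;
    each channel is a flat list of pixel values.
    """
    means, stds = [], []
    for c in range(len(images[0])):
        pixels = [p for image in images for p in image[c]]
        mean = sum(pixels) / len(pixels)
        variance = sum((p - mean) ** 2 for p in pixels) / len(pixels)
        means.append(mean)
        stds.append(math.sqrt(variance))
    return means, stds


def normalize(image, means, stds):
    """Subtract the channel mean and divide by the channel std.

    A constant channel (std of 0) would need a guard in real code.
    """
    return [
        [(p - m) / s for p in channel]
        for channel, m, s in zip(image, means, stds)
    ]


# Toy training set: two images, two channels, two pixels per channel.
train = [
    [[0.0, 2.0], [10.0, 10.0]],
    [[4.0, 6.0], [30.0, 30.0]],
]
means, stds = channel_stats(train)
print(means, stds)
print(normalize(train[0], means, stds))
```

The same training-set statistics would be applied to test images, so that both splits are transformed identically.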
## CIFAR-100 Results
![alt tag](IMAGES/cifar100_image.png)
The table below reports test-set accuracy for models trained on the **CIFAR-100 dataset**.
**Accuracy is the average of 5 runs**
| network | dropout | preprocess | GPU:0 | GPU:1 | per epoch | Top1 acc(%)| Top5 acc(%) |
|:-----------------:|:-------:|:-----------:|:-----:|:-----:|:------------:|:----------:|:-----------:|
| pre-ResNet-1001 | 0 | meanstd | - | - | 3 min 25 sec | 77.29 | 93.44 |
| wide-resnet 28x10 | 0 | ZCA | 5.90G | - | 2 min 03 sec | 80.03 | 95.01 |
| wide-resnet 28x10 | 0 | meanstd | 5.90G | - | 2 min 03 sec | 81.01 | 95.44 |
| wide-resnet 28x10 | 0.3 | meanstd | 5.90G | - | 2 min 03 sec | 81.47 | 95.53 |
| wide-resnet 28x20 | 0.3 | meanstd | 8.13G | 6.93G | 4 min 05 sec | **82.43** | **96.02** |
| wide-resnet 40x10 | 0.3 | meanstd | 8.93G | - | 3 min 06 sec | 81.47 | 95.65 |
| wide-resnet 40x14 | 0.3 | meanstd | 7.39G | ***6G | 3 min 23 sec | 81.83 | 95.50 |
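
For reference, top-1 and top-5 accuracy count a sample as correct when the true label is among the 1 or 5 highest-scoring classes. The sketch below shows that computation on toy data; the function name and the scores are hypothetical and are not taken from the repository.

```python
def top_k_accuracy(scores, labels, k):
    """Percentage of samples whose true label is among the k highest scores.

    scores: list of per-class score vectors, one per sample.
    labels: list of true class indices.
    """
    hits = 0
    for sample_scores, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(
            range(len(sample_scores)),
            key=lambda c: sample_scores[c],
            reverse=True,
        )
        if label in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(labels)


# Toy example with 3 classes, so k is 1 or 2 instead of 1 or 5.
toy_scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
]
toy_labels = [1, 1, 0]
print(top_k_accuracy(toy_scores, toy_labels, 1))
print(top_k_accuracy(toy_scores, toy_labels, 2))
```

Top-k accuracy can only stay equal or rise as k grows, which is why the Top5 column is always higher than the Top1 column.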
## SVHN Results
![alt tag](IMAGES/svhn_image.png)
The table below reports test-set accuracy for models trained on the **SVHN dataset**.
**Accuracy is the result of 1 run**
| network | dropout | preprocess | GPU:0 | per epoch | Top1 acc(%)|
|:-----------------:|:-------:|:-----------:|:-----:|:-------------:|:----------:|
| wide-resnet 10x1 | 0.4 | meanstd | 0.91G | 1 min 37 sec | 93.815 |
| wide-resnet 10x8 | 0.4 | meanstd | 2.03G | 7 min 32 sec | 97.411 |
| wide-resnet 16x8 | 0.4 | meanstd | 2.92G | 14 min 8 sec | ***.229 |
| wide-resnet 22x8 | 0.4 | meanstd | 3.73G | 21 min 11 sec | ***.348 |