ml_algo

Category: Mathematical computation
Development tool: Dart
File size: 0KB
Downloads: 0
Upload date: 2022-12-30 18:03:29
Uploader: sh-1993
Description: Machine learning algorithms in the Dart programming language

File list:
CHANGELOG.md (19198, 2024-01-02)
CNAME (15, 2024-01-02)
LICENSE (1299, 2024-01-02)
_config.yml (25, 2024-01-02)
analysis_options.yaml (456, 2024-01-02)
benchmark/ (0, 2024-01-02)
benchmark/cross_validator.dart (1094, 2024-01-02)
benchmark/data/ (0, 2024-01-02)
benchmark/data/sample_data.json (1857683, 2024-01-02)
benchmark/decision_tree_classifier.dart (1289, 2024-01-02)
benchmark/inverse_logit_link_function.dart (969, 2024-01-02)
benchmark/kd_tree/ (0, 2024-01-02)
benchmark/kd_tree/kd_tree_building.dart (816, 2024-01-02)
benchmark/kd_tree/kd_tree_querying.dart (1235, 2024-01-02)
benchmark/knn_regressor.dart (1427, 2024-01-02)
benchmark/knn_solver.dart (1314, 2024-01-02)
benchmark/lasso_regressor.dart (855, 2024-01-02)
benchmark/linear_regressor.dart (1130, 2024-01-02)
benchmark/linear_regressor_gradient_descent.dart (1233, 2024-01-02)
benchmark/logistic_regressor_gradient.dart (1039, 2024-01-02)
benchmark/main.dart (429, 2024-01-02)
benchmark/random_binary_projection_searcher/ (0, 2024-01-02)
benchmark/random_binary_projection_searcher/searcher_building.dart (957, 2024-01-02)
benchmark/random_binary_projection_searcher/searcher_querying.dart (1569, 2024-01-02)
benchmark/sgd_regressor.dart (1115, 2024-01-02)
build.sh (45, 2024-01-02)
build.yaml (330, 2024-01-02)
e2e/ (0, 2024-01-02)
e2e/_datasets/ (0, 2024-01-02)
e2e/_datasets/advertising.csv (6150, 2024-01-02)
e2e/_datasets/housing.csv (41296, 2024-01-02)
e2e/decision_tree_classifier/ (0, 2024-01-02)
e2e/decision_tree_classifier/decision_tree_classifier_save_as_svg_test.dart (1316, 2024-01-02)
e2e/decision_tree_classifier/decision_tree_classifier_serialization_test.dart (857, 2024-01-02)
... ...

[![Build Status](https://github.com/gyrdym/ml_algo/workflows/CI%20pipeline/badge.svg)](https://github.com/gyrdym/ml_algo/actions?query=branch%3Amaster+) [![Coverage Status](https://coveralls.io/repos/github/gyrdym/ml_algo/badge.svg?branch=master)](https://coveralls.io/github/gyrdym/ml_algo?branch=master) [![pub package](https://img.shields.io/pub/v/ml_algo.svg)](https://pub.dartlang.org/packages/ml_algo) [![Gitter Chat](https://badges.gitter.im/gyrdym/gyrdym.svg)](https://gitter.im/gyrdym/)

# Machine learning algorithms for Dart developers - ml_algo library

The library is a part of the ecosystem:

- [ml_algo library](https://github.com/gyrdym/ml_algo) - implementation of popular machine learning algorithms
- [ml_preprocessing library](https://github.com/gyrdym/ml_preprocessing) - a library for data preprocessing
- [ml_linalg library](https://github.com/gyrdym/ml_linalg) - a library for linear algebra
- [ml_dataframe library](https://github.com/gyrdym/ml_dataframe) - a library for storing and manipulating data

**Table of contents**

- [What is ml_algo for](#what-is-ml_algo-for)
- [The library content](#the-library-content)
- [Examples](#examples)
  - [Logistic regression](#logistic-regression)
  - [Linear regression](#linear-regression)
  - [Decision tree-based classification](#decision-tree-based-classification)
  - [KDTree-based data retrieval](#kdtree-based-data-retrieval)
- [Models retraining](#models-retraining)
- [Notes on gradient-based optimisation algorithms](#a-couple-of-words-about-linear-models-which-use-gradient-optimisation-methods)
- [Helpful articles on algorithms standing behind the library](#helpful-articles-on-algorithms-standing-behind-the-library)
- [Contacts](#contacts)

## What is ml_algo for?

The main purpose of the library is to provide native Dart implementations of machine learning algorithms to those who are interested in both the Dart language and data science. The library targets the Dart VM and Flutter; it cannot be used in web applications.

## The library content

- #### Model selection
  - [CrossValidator](https://pub.dev/documentation/ml_algo/latest/ml_algo/CrossValidator-class.html). A factory that creates instances of cross-validators. Cross-validation lets you tune the [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) of machine learning algorithms by assessing prediction quality on different parts of a dataset.

- #### Classification algorithms
  - [LogisticRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor-class.html). A class that performs linear binary classification of data. To use this kind of classifier, your data has to be [linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
  - [LogisticRegressor.SGD](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor/LogisticRegressor.SGD.html). Implementation of the logistic regression algorithm based on stochastic gradient descent with L2 regularisation. To use this kind of classifier, your data has to be [linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
  - [LogisticRegressor.BGD](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor/LogisticRegressor.BGD.html). Implementation of the logistic regression algorithm based on batch gradient descent with L2 regularisation. To use this kind of classifier, your data has to be [linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
  - [LogisticRegressor.newton](https://pub.dev/documentation/ml_algo/latest/ml_algo/LogisticRegressor/LogisticRegressor.newton.html). Implementation of the logistic regression algorithm based on the Newton-Raphson method with L2 regularisation. To use this kind of classifier, your data has to be [linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
  - [SoftmaxRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/SoftmaxRegressor-class.html). A class that performs linear multiclass classification of data. To use this kind of classifier, your data has to be [linearly separable](https://en.wikipedia.org/wiki/Linear_separability).
  - [DecisionTreeClassifier](https://pub.dev/documentation/ml_algo/latest/ml_algo/DecisionTreeClassifier-class.html). A class that performs classification using decision trees. It may work with data that has non-linear patterns.
  - [KnnClassifier](https://pub.dev/documentation/ml_algo/latest/ml_algo/KnnClassifier-class.html). A class that performs classification using the `k nearest neighbours` algorithm - it makes predictions based on the `k` closest observations to the given one.

- #### Regression algorithms
  - [LinearRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor-class.html). A general class for finding a linear pattern in training data and predicting outcomes as real numbers.
  - [LinearRegressor.lasso](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor/LinearRegressor.lasso.html). Implementation of the linear regression algorithm based on coordinate descent with lasso regularisation.
  - [LinearRegressor.SGD](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor/LinearRegressor.SGD.html). Implementation of the linear regression algorithm based on stochastic gradient descent with L2 regularisation.
  - [LinearRegressor.BGD](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor/LinearRegressor.BGD.html). Implementation of the linear regression algorithm based on batch gradient descent with L2 regularisation.
  - [LinearRegressor.newton](https://pub.dev/documentation/ml_algo/latest/ml_algo/LinearRegressor/LinearRegressor.newton.html). Implementation of the linear regression algorithm based on the Newton-Raphson method with L2 regularisation.
  - [KnnRegressor](https://pub.dev/documentation/ml_algo/latest/ml_algo/KnnRegressor-class.html). A class that makes a prediction for each new observation based on the `k` closest observations from the training data. It may catch non-linear patterns in the data.

- #### Clustering and retrieval algorithms
  - [KDTree](https://pub.dev/documentation/ml_algo/latest/kd_tree/KDTree-class.html). An algorithm for efficient data retrieval.
  - **Locality sensitive hashing.** A family of algorithms that randomly partition all reference data points into different bins, which makes it possible to perform efficient k-nearest-neighbours search, since there is no need to search for the neighbours through the entire data.
    The family is represented by the following classes:
    - [RandomBinaryProjectionSearcher](https://pub.dev/documentation/ml_algo/latest/ml_algo/RandomBinaryProjectionSearcher-class.html)

For more information on the library's API, please visit the [API reference](https://pub.dev/documentation/ml_algo/latest/ml_algo/ml_algo-library.html).
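The KDTree-based retrieval walkthrough listed in the table of contents is not included in this excerpt, so here is a minimal sketch of how such a search might look. It assumes the `kd_tree` library exposes a `KDTree` constructor that accepts a `DataFrame` of reference points and a `query` method that takes a query point (a `Vector` from ml_linalg) and the number of neighbours to return; please check the API reference above for the exact signatures.

```dart
import 'package:ml_algo/kd_tree.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/vector.dart';

void main() {
  // A tiny set of reference points; in practice this would be your dataset.
  final points = DataFrame([
    [1.0, 2.0, 3.0],
    [10.0, 20.0, 30.0],
    [11.0, 19.0, 31.0],
  ], headerExists: false);

  // Build the tree once, then query it as many times as needed.
  final tree = KDTree(points);

  // Find the 2 nearest neighbours of the given point.
  final neighbours = tree.query(Vector.fromList([10.0, 20.0, 29.0]), 2);

  print(neighbours);
}
```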
## Examples

### Logistic regression

Let's classify records from a well-known dataset - [Pima Indians Diabetes Database](https://www.kaggle.com/uciml/pima-indians-diabetes-database) - via [Logistic regressor](https://github.com/gyrdym/ml_algo/blob/master/lib/src/classifier/logistic_regressor/logistic_regressor.dart).

**Important note:** Please pay attention to the problems that the classifiers and regressors exposed by the library solve. For example, [Logistic regressor](https://github.com/gyrdym/ml_algo/blob/master/lib/src/classifier/logistic_regressor/logistic_regressor.dart) solves only **binary classification** problems, which means you can't use this classifier with a dataset that has more than two classes. To find out more about the regressors and classifiers, please refer to the [API documentation](https://pub.dev/documentation/ml_algo/latest/ml_algo/ml_algo-library.html) of the package.

Import all the necessary packages. First, make sure you have the `ml_preprocessing` and `ml_dataframe` packages in your dependencies:

```yaml
dependencies:
  ml_dataframe: ^1.5.0
  ml_preprocessing: ^7.0.2
```

We need these packages to parse raw data so we can use it further. For more details, please visit the [ml_preprocessing](https://github.com/gyrdym/ml_preprocessing) repository page.

**Important note:** Regressors and classifiers exposed by the library do not handle strings, booleans and nulls; they can only deal with numbers. You must convert all such values in your dataset to numbers; please refer to the [ml_preprocessing](https://github.com/gyrdym/ml_preprocessing) library to find out more about data preprocessing.

```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';
```

### Read a dataset's file

We have 2 options here:

- Download the dataset from [Pima Indians Diabetes Database](https://www.kaggle.com/uciml/pima-indians-diabetes-database).

#### For a desktop application:

Just provide a proper path to your downloaded file and use the factory function `fromCsv` from the `ml_dataframe` package to read the file:

```dart
final samples = await fromCsv('datasets/pima_indians_diabetes_database.csv');
```

#### For a Flutter application:

You need to add the dataset to the Flutter assets by adding the following config to pubspec.yaml:

```yaml
flutter:
  assets:
    - assets/datasets/pima_indians_diabetes_database.csv
```

You need to create the assets directory in the file system and put the dataset's file there. After that you can access the dataset:

```dart
import 'package:flutter/services.dart' show rootBundle;
import 'package:ml_dataframe/ml_dataframe.dart';

void main() async {
  final rawCsvContent = await rootBundle.loadString('assets/datasets/pima_indians_diabetes_database.csv');
  final samples = DataFrame.fromRawCsv(rawCsvContent);
}
```
- Or we may simply use the [getPimaIndiansDiabetesDataFrame](https://pub.dev/documentation/ml_dataframe/latest/ml_dataframe/getPimaIndiansDiabetesDataFrame.html) function from the [ml_dataframe](https://pub.dev/packages/ml_dataframe) package. The function returns a ready-to-use [DataFrame](https://pub.dev/documentation/ml_dataframe/latest/ml_dataframe/DataFrame-class.html) instance filled with the `Pima Indians Diabetes Database` data.
```dart
import 'package:ml_dataframe/ml_dataframe.dart';

void main() {
  final samples = getPimaIndiansDiabetesDataFrame();
}
```
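Whichever option you choose, it can be useful to check that the data was loaded as expected. A minimal sketch using the `header` and `rows` properties of `DataFrame` (the same properties are used later in this guide to inspect predictions); the column names in the comments are simply what the Pima dataset is expected to contain:

```dart
// Quick sanity check of the loaded data.
print(samples.header);       // expected to include Pregnancies, Glucose, ..., Outcome
print(samples.rows.take(3)); // the first three observations
print(samples.rows.length);  // expected: 768
```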
### Prepare datasets for training and testing

The data in this file consists of 768 records with 8 features each. The 9th column is the label column; it contains either 0 or 1 in each row. This column is our target - we should predict a class label for each observation. The column's name is `Outcome`. Let's store it:

```dart
final targetColumnName = 'Outcome';
```

Now it's time to prepare the data splits. Since we have a smallish dataset (only 768 records), we can't afford to split the data into just train and test sets and evaluate the model on them; the best approach in our case is cross-validation. With that in mind, let's split the data using the library's [splitData](https://github.com/gyrdym/ml_algo/blob/master/lib/src/model_selection/split_data.dart) function:

```dart
final splits = splitData(samples, [0.7]);
final validationData = splits[0];
final testData = splits[1];
```

`splitData` accepts a `DataFrame` instance as the first argument and a list of ratios as the second one. Now we have 70% of our data as a validation set and 30% as a test set for evaluating the generalization error.

### Set up a model selection algorithm

Then we may create an instance of the `CrossValidator` class to fit the [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) of our model. We should pass the validation data (our `validationData` variable) and the number of folds to the CrossValidator constructor.

```dart
final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
```

Let's create a factory for the classifier with the desired hyperparameters. After the cross-validation, we will decide whether the selected hyperparameters are good enough or not:

```dart
final createClassifier = (DataFrame samples) => LogisticRegressor(
  samples,
  targetColumnName,
);
```

If we want to evaluate the learning process more thoroughly, we may pass the `collectLearningData` argument to the classifier constructor:

```dart
final createClassifier = (DataFrame samples) => LogisticRegressor(
  samples,
  targetColumnName,
  collectLearningData: true,
);
```

This argument activates collection of the cost per optimization iteration, and you can inspect the cost values right after the model is created.

### Evaluate the performance of the model

Assume we chose the perfect hyperparameters. In order to validate this hypothesis, let's use the CrossValidator instance created before:

```dart
final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
```

Since the CrossValidator instance returns a [Vector](https://github.com/gyrdym/ml_linalg/blob/master/lib/vector.dart) of scores as the result of our predictor evaluation, we may reduce all the collected scores to a single number in any way we like; for instance, we may use the Vector's `mean` method:

```dart
final accuracy = scores.mean();
```

Let's print the score:

```dart
print('accuracy on k fold validation: ${accuracy.toStringAsFixed(2)}');
```

We can see something like this:

```
accuracy on k fold validation: 0.75
```
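The walkthrough above evaluates a single hyperparameter configuration. In practice, you would typically compare several candidates with the same validator and keep the best one. A minimal sketch (to be placed inside an async `main`, like the full programs below); it assumes `iterationsLimit`, which the trained model exposes as a read-only field further down, is also accepted as a named constructor parameter:

```dart
// Hypothetical hyperparameter sweep; `iterationsLimit` is assumed to be a
// valid named parameter of the LogisticRegressor constructor.
final candidateLimits = [50, 100, 200];

for (final limit in candidateLimits) {
  final createClassifier = (DataFrame samples) => LogisticRegressor(
    samples,
    targetColumnName,
    iterationsLimit: limit,
  );
  final scores = await validator.evaluate(createClassifier, MetricType.accuracy);

  print('iterationsLimit=$limit: accuracy=${scores.mean().toStringAsFixed(2)}');
}
```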
Let's assess our hyperparameters on the test set in order to evaluate the model's generalization error:

```dart
final testSplits = splitData(testData, [0.8]);
final classifier = createClassifier(testSplits[0]);
final finalScore = classifier.assess(testSplits[1], MetricType.accuracy);
```

The final score will be something like:

```dart
print(finalScore.toStringAsFixed(2)); // approx. 0.75
```

If we specified the `collectLearningData` parameter, we may inspect the cost per iteration in order to see how the cost changed from iteration to iteration during the learning process:

```dart
print(classifier.costPerIteration);
```

### Write the model to a json file

It seems our model has good generalization ability, which means we may use it in the future. To do so, we may store the model in a file as JSON:

```dart
await classifier.saveAsJson('diabetes_classifier.json');
```

After that we can simply read the model from the file and make predictions:

```dart
import 'dart:io';

import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

void main() async {
  // ...
  final fileName = 'diabetes_classifier.json';
  final file = File(fileName);
  final encodedModel = await file.readAsString();
  final classifier = LogisticRegressor.fromJson(encodedModel);
  final unlabelledData = await fromCsv('some_unlabelled_data.csv');
  final prediction = classifier.predict(unlabelledData);

  print(prediction.header); // ('class variable (0 or 1)')
  print(prediction.rows);   // [
                            //   (1),
                            //   (0),
                            //   (0),
                            //   (1),
                            //   ...,
                            //   (1),
                            // ]
  // ...
}
```

Please note that all the hyperparameters that we used to generate the model are persisted as the model's read-only fields, and we can access them anytime:

```dart
print(classifier.iterationsLimit);
print(classifier.probabilityThreshold);
// and so on
```
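The reloading snippet above uses `dart:io` and targets a desktop application. In a Flutter app, one option is to ship a pre-trained model as an asset and load it with `rootBundle`, the same way the dataset was loaded earlier. A minimal sketch, assuming `diabetes_classifier.json` has been registered under `assets/` in pubspec.yaml:

```dart
import 'package:flutter/services.dart' show rootBundle;
import 'package:ml_algo/ml_algo.dart';

Future<LogisticRegressor> loadClassifier() async {
  // Read the serialized model bundled with the app.
  final encodedModel = await rootBundle.loadString('assets/diabetes_classifier.json');

  // Restore the classifier from its JSON representation.
  return LogisticRegressor.fromJson(encodedModel);
}
```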
All the code for a desktop application:

```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';

void main() async {
  // Another option - to use a toy dataset:
  // final samples = getPimaIndiansDiabetesDataFrame();
  final samples = await fromCsv('datasets/pima_indians_diabetes_database.csv', headerExists: true);
  final targetColumnName = 'Outcome';
  final splits = splitData(samples, [0.7]);
  final validationData = splits[0];
  final testData = splits[1];
  final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
  final createClassifier = (DataFrame samples) => LogisticRegressor(
    samples,
    targetColumnName,
  );
  final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
  final accuracy = scores.mean();

  print('accuracy on k fold validation: ${accuracy.toStringAsFixed(2)}');

  final testSplits = splitData(testData, [0.8]);
  final classifier = createClassifier(testSplits[0]);
  final finalScore = classifier.assess(testSplits[1], MetricType.accuracy);

  print(finalScore.toStringAsFixed(2));

  await classifier.saveAsJson('diabetes_classifier.json');
}
```
All the code for a Flutter application:

```dart
import 'package:flutter/services.dart' show rootBundle;
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';

void main() async {
  final rawCsvContent = await rootBundle.loadString('assets/datasets/pima_indians_diabetes_database.csv');
  // Another option - to use a toy dataset:
  // final samples = getPimaIndiansDiabetesDataFrame();
  final samples = DataFrame.fromRawCsv(rawCsvContent);
  final targetColumnName = 'Outcome';
  final splits = splitData(samples, [0.7]);
  final validationData = splits[0];
  final testData = splits[1];
  final validator = CrossValidator.kFold(validationData, numberOfFolds: 5);
  final createClassifier = (DataFrame samples) => LogisticRegressor(
    samples,
    targetColumnName,
  );
  final scores = await validator.evaluate(createClassifier, MetricType.accuracy);
  final accuracy = scores.mean();

  print('accuracy on k fold validation: ${accuracy.toStringAsFixed(2)}');

  final testSplits = splitData(testData, [0.8]);
  final classifier = createClassifier(testSplits[0]);
  final finalScore = classifier.assess(testSplits[1], MetricType.accuracy);

  print(finalScore.toStringAsFixed(2));

  await classifier.saveAsJson('diabetes_classifier.json');
}
```
### Linear regression

Let's try to predict house prices using linear regression and the famous [Boston Housing](https://www.kaggle.com/c/boston-housing) dataset. The dataset contains 13 independent variables and 1 dependent variable - `medv`, which is the target (you can find the dataset in [e2e/_datasets/housing.csv](https://github.com/gyrdym/ml_algo/blob/master/e2e/_datasets/housing.csv)).

Again, first we need to download the file and create a dataframe. The dataset has no header row, so we may either use the auto-generated header or provide our own. Let's use the auto-generated header in our example:

#### For a desktop application:

Just provide a proper path to your downloaded file and use the factory function `fromCsv` from the `ml_dataframe` package to read the file:

```dart
final samples = await fromCsv('datasets/housing.csv', headerExists: false, columnDelimiter: ' ');
```

#### For a Flutter application:

You need to add the dataset to the Flutter assets by adding the following config to pubspec.yaml:

```yaml
flutter:
  assets:
    - assets/datasets/housing.csv
```

You need to create the assets dire ... ...
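The excerpt is cut off here. For reference, here is a minimal sketch of how the rest of the flow might look in a desktop application, following the same pattern as the logistic regression example above. Using `samples.header.last` to pick the target column and `MetricType.mape` as the evaluation metric are assumptions, not taken from the original text:

```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

void main() async {
  final samples = await fromCsv('datasets/housing.csv', headerExists: false, columnDelimiter: ' ');

  // With the auto-generated header, the last column is the target (medv).
  final targetName = samples.header.last;

  // 80% of the data for training, 20% for testing.
  final splits = splitData(samples, [0.8]);
  final model = LinearRegressor(splits[0], targetName);

  // Assumes MetricType.mape is available; rmse would be another option.
  final error = model.assess(splits[1], MetricType.mape);

  print('MAPE on the test set: ${error.toStringAsFixed(2)}');
}
```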
