BART-Int

Category: Artificial Intelligence / Neural Networks / Deep Learning
Development tool: R
Upload date: 2020-11-06 10:16:22
Uploader: sh-1993
Description: Bayesian Probabilistic Numerical Integration with Tree-Based Models (using Bayesian Additive Regression Trees)

File list:
BART-Int.Rproj (205, 2020-11-06)
Figures/ (0, 2020-11-06)
Figures/genz/ (0, 2020-11-06)
Figures/genz/1/ (0, 2020-11-06)
Figures/genz/2/ (0, 2020-11-06)
Figures/genz/3/ (0, 2020-11-06)
Figures/genz/4/ (0, 2020-11-06)
Figures/genz/5/ (0, 2020-11-06)
Figures/genz/6/ (0, 2020-11-06)
Figures/genz/7/ (0, 2020-11-06)
Figures/genz/9/ (0, 2020-11-06)
Figures/rares/ (0, 2020-11-06)
Figures/survey_design/ (0, 2020-11-06)
Figures/thumbnail.png (74769, 2020-11-06)
LICENSE (1080, 2020-11-06)
bart_compute_groundtruth.R (2653, 2020-11-06)
bcf/ (0, 2020-11-06)
bcf/plot.r (3739, 2020-11-06)
bcf/results/ (0, 2020-11-06)
bcf/results/results/ (0, 2020-11-06)
data/ (0, 2020-11-06)
data/extract.R (149, 2020-11-06)
data/sample_full.R (851, 2020-11-06)
figures_code/ (0, 2020-11-06)
figures_code/.draw_step.R.swp (16384, 2020-11-06)
figures_code/compute_CV.R (4881, 2020-11-06)
... ...

# Bayesian Probabilistic Numerical Integration with Tree-Based Models

[Bayesian Probabilistic Numerical Integration with Tree-Based models](https://arxiv.org/abs/2006.05371) (to appear in NeurIPS 2020)

Authors: Harrison Zhu, François-Xavier Briol, Xing Liu, Ruya Kang, Zhichao Shen, and Seth Flaxman

![](Figures/thumbnail.png)

## Code directory

The `causal-inference` branch contains code relating to "Some Links Between Causal Inference and Bayesian Probabilistic Numerical Integration".

- `README.md`
- `data`: The scripts used to scrape the survey design data, and where the data is stored.
  - `extract.R`: Extracts the PUMS survey dataset
  - `sample_full.R`: Creates 2 CSVs, `train2.csv` and `candidate2.csv`, that contain the needed features for the design and candidate points.
- `Figures`: Folder to hold results plots.
  - `genz`
    - `1`: Genz function results
    - `2`: Genz function results
    - `3`: Genz function results
    - `4`: Genz function results
    - `5`: Genz function results
    - `6`: Genz function results
    - `7`: Genz function results
    - `9`: Genz function results
  - `survey_design`
- `results`: Folder to hold results plots. This is where we store the results generated by our integral approximation functions, as well as the analytical integrals of the benchmark testing functions.
  - `genz`
    - `1`: Genz function results
    - `2`: Genz function results
    - `3`: Genz function results
    - `4`: Genz function results
    - `5`: Genz function results
    - `6`: Genz function results
    - `7`: Genz function results
    - `9`: Genz function results
  - `survey_design`
- `figures_code`
  - `draw_step.R`: Draws Figure 1 in the paper
  - `plot_binary_response.R`: Draws Figure 3 in the paper
  - `plot_computational_complexity.R`: Draws Figure 2 in the paper
  - `plot_high_dimensionality.r`: Computes the results in Table 2 in the paper
  - `plot_posterior_example.R`: Draws Figure 4 in the paper/appendix
  - `step_design.R`: Draws Figure 5 in the paper
  - `compute_CV.R`: Computes the results in Table 1 in the paper
- `python`: Python code for hyperparameter tuning for the GP
  - `gp_tune.py`: Class for GP regression with marginal likelihood maximisation
- `src`
  - `genz`: Genz functions and their integrals
    - `analyticalIntegrals.R`
    - `genz.R`
  - `BARTInt.R`: Implementation of BART-Int
  - `GPBQ.R`: Implementation of Bayesian Quadrature with Gaussian processes (GP-BQ)
  - `monteCarloIntegration.R`: Main class of Monte Carlo integration
  - `optimise_gp.R`: Source file used to optimise the lengthscale using PyTorch with reticulate
  - `meanPopulationStudy`: Source files used for Bayesian survey design
    - `bartMean.R`
    - `gpMean.R`
- `integrationMain.R`: Main class to do BART-Int, GPBQ and Monte Carlo integrations. Tweak your Genz functions and parameters here
- `poptMean_trained_bin.R`: Computes the ground truth proportions for the survey design problem
- `saveComputeIntegrals.R`: Computes the exact integrals for the Genz functions
- `bart_compute_groundtruth.R`: Computes the ground truth for the survey design

## Dependencies

The experiments are tested under Ubuntu 18.04 and OSX. Docker images will be published in due course to ensure wider reproducibility.

`R` dependencies:
```r
MASS
cubature
lhs
data.tree
dbarts
matrixStats
mvtnorm
doParallel
kernlab
msm
MCMCglmm
dbarts_0.9-8
caret
reticulate
rdist
```

`Python` dependencies:

```
torch
gpytorch
```

## Numerical Experiments: Genz Functions

1) Install all the necessary packages

```r
install.packages(c("MASS", "cubature", "lhs", "data.tree", "matrixStats", "mvtnorm",
                   "doParallel", "kernlab", "msm", "MCMCglmm", "caret", "reticulate", "rdist"))

# an old version of dbarts
packageurl <- "https://cran.r-project.org/src/contrib/Archive/dbarts/dbarts_0.9-8.tar.gz"
install.packages(packageurl, repos=NULL, type="source")
```

*Note that the older version of `dbarts` is needed, as there have been significant changes in the class files for the data structures.*

For the Python dependencies, we use the following:

```
gpytorch
torch
```

These are installed in `src/optimise_gp.R`, which creates a virtualenv with the function `install_python_env()` using `reticulate`.

2) Save the computed integrals

```
Rscript saveComputeIntegrals.R
```

This will store the ground truth in `results/genz/integrals.csv`.

3) To reproduce the benchmark tests, run `integrationMain.R` with customized inputs. There are 8 arguments in total, of which the last three are optional. The penultimate argument should only be specified when the step function is used (`genz_function_number = 7`), and is set to `1` if not specified. The general form is:

```
Rscript integrationMain.R dimension num_iterations genz_function_number kernel_name sequential_flag (measure) (number_of_jumps_for_step_function) (save_posterior)
```

where `genz_function_number` follows the indexing in this [documentation](https://www.sfu.ca/~ssurjano/integration.html) for the Genz families. The results will be stored in `results`, where you can find the `.csv` and `.RData` files containing the numerical values and the automatically generated graphs.
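Steps 2 and 3 compare integral estimates against an analytical ground truth. The idea can be sketched in a few lines of Python for the Genz "continuous" family under the uniform measure on the unit cube. This is only an illustrative sketch, not the repository's R implementation: the function names, the parameters `a` and `u`, and the sample size are all assumptions chosen for the example.

```python
import math
import random


def genz_continuous(x, a, u):
    """Genz 'continuous' integrand: exp(-sum_i a_i * |x_i - u_i|)."""
    return math.exp(-sum(ai * abs(xi - ui) for xi, ai, ui in zip(x, a, u)))


def exact_integral(a, u):
    """Closed-form integral over [0, 1]^d; it factorises across dimensions."""
    result = 1.0
    for ai, ui in zip(a, u):
        result *= (2.0 - math.exp(-ai * ui) - math.exp(-ai * (1.0 - ui))) / ai
    return result


def monte_carlo_integral(f, dim, n_samples, seed=0):
    """Plain Monte Carlo estimate under the uniform measure on [0, 1]^d."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        total += f([rng.random() for _ in range(dim)])
    return total / n_samples


a = [1.5, 1.5]  # hypothetical sharpness parameters (not taken from the repo)
u = [0.5, 0.5]  # hypothetical kink locations (not taken from the repo)

truth = exact_integral(a, u)
estimate = monte_carlo_integral(
    lambda x: genz_continuous(x, a, u), dim=2, n_samples=20000
)
print(f"exact={truth:.4f}  monte_carlo={estimate:.4f}")
```

The absolute difference between `estimate` and `truth` is the kind of error that the benchmark reports for each method. The parameter values used by the repository's `src/genz/genz.R` may differ from the ones assumed here.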
As a concrete invocation, one could run

```
Rscript integrationMain.R 1 2 1 matern32 1 uniform 1 1
```

The arguments, in order, mean: 1 dimension, 2 iterations of sequential design, Genz function 1 (continuous), the matern32 kernel, `1` to enable sequential design, the uniform measure, `1` as a placeholder for the number of jumps of the step function, and `1` to save the BART posterior samples at each iteration. As another example, with a Gaussian prior and a step function as defined in the appendix of the paper, one could run

```
Rscript integrationMain.R 1 2 7 matern32 1 gaussian 1 1
```

For more information about each input, check the first few lines of `integrationMain.R`. Results will also be stored in `results/genz` and `Figures/genz`.

We ran the following to generate the results for Table 1

```
for dim in 1 10
do
  for genz in 1 2 3 4 5 6 7
  do
    Rscript integrationMain.R $dim 20 $genz matern32 uniform 1 1 1
  done
done
```

You can also rewrite `integrationMain.R` to parallelise over the seeds. For Table 2

```
for dim in 1 10 20 100
do
  Rscript test_additive.R $dim 1 9 matern32 uniform 0
done
```

As for the graphs, we provide the scripts in `figures_code`.

**Navigating the results**

Check the [Genz Functions](https://www.sfu.ca/~ssurjano/integration.html) and see the preprint for more information.

- 1: Continuous
- 2: Corner Peak
- 3: Discontinuous
- 4: Gaussian Peak
- 5: Oscillatory
- 6: Product Peak
- 7: Step function
- 9: Additive Gaussian

## Numerical Experiments: Bayesian Survey Design

1) Install the dependencies in `R`. Make sure you are using **R 3.5.0** or higher.

2) Download and process the dataset

```
Rscript data/extract.R
cd data; unzip 2016csv_pil.zip; cd ..
Rscript data/sample_full.R
```

This will create `train2.csv` and `candidate2.csv`, which store the possible initial design points and the candidate points.
3) [optional] To compute the BART ground truth, run

```
Rscript bart_compute_groundtruth.R num_cv_start num_cv_end num_data num_design 1
```

where `num_cv_start` and `num_cv_end` define a loop over random seeds `num_cv_start, num_cv_start+1, ..., num_cv_end`, `num_data` is the number of candidate points and `num_design` is the number of initial design points. This can also be computed as the user wishes. The ground truths will be stored in `results/survey_design`. Alternatively, you can just take the mean of the entire dataset of 454,816 points, which yields very similar results.

4) To run the experiments, first navigate to `src/survey_design/gpMean.R` and change the jitter/nugget term to a value you deem appropriate. We set it to the value obtained from the output of the maximum marginal likelihood estimator using line 84 in `poptMean_trained_bin.R`. Note that some small jitter is always needed for numerical stability during the kernel matrix inversion for GP-BQ. Then run

```
Rscript poptMean_trained_bin.R num_new_surveys num_cv_start num_cv_end num_data num_design
```

where `num_new_surveys` is the number of new surveys to query. For example, we used `num_new_surveys=200, num_data=10000, num_design=20`. This will generate and store the results in `results/survey_design` and `Figures/survey_design`, where you can find the `.csv` and `.RData` files containing the numerical values and the automatically generated graphs.
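The remark in step 4 about jitter can be illustrated with a small, self-contained Python sketch. It builds a Matérn 3/2 kernel matrix for a design containing a repeated point, which makes the matrix singular, and shows that a Cholesky factorisation only succeeds once a small value is added to the diagonal. Everything here is an assumption made for illustration (the scalar inputs, the lengthscale, the jitter value `1e-6`, and the helper names); it is not the code in `gpMean.R`.

```python
import math


def matern32(x, y, lengthscale=1.0):
    """Matern 3/2 kernel for scalar inputs."""
    r = math.sqrt(3.0) * abs(x - y) / lengthscale
    return (1.0 + r) * math.exp(-r)


def kernel_matrix(points, jitter=0.0):
    """Gram matrix with `jitter` added to the diagonal."""
    n = len(points)
    return [
        [matern32(points[i], points[j]) + (jitter if i == j else 0.0) for j in range(n)]
        for i in range(n)
    ]


def cholesky(matrix):
    """Lower-triangular Cholesky factor; raises ValueError on a non-positive pivot."""
    n = len(matrix)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                pivot = matrix[i][i] - s
                if pivot <= 0.0:
                    raise ValueError("matrix is not positive definite")
                lower[i][j] = math.sqrt(pivot)
            else:
                lower[i][j] = (matrix[i][j] - s) / lower[j][j]
    return lower


def is_factorisable(points, jitter):
    """True if the jittered kernel matrix admits a Cholesky factorisation."""
    try:
        cholesky(kernel_matrix(points, jitter))
        return True
    except ValueError:
        return False


points = [0.2, 0.2, 0.7]  # hypothetical design with a repeated point
print(is_factorisable(points, jitter=0.0))   # singular matrix, factorisation fails
print(is_factorisable(points, jitter=1e-6))  # small nugget restores positive definiteness
```

In practice near-duplicate points cause the same problem more subtly through rounding error, which is why an appropriate nugget depends on the data and the kernel hyperparameters.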
*Note that the experiments can also be easily run using other BART packages such as `BART` or `bartMachine`, provided that `src/survey_design/bartMean.R` is edited so that `dbarts::bart` is replaced.*

To get the results for Table 3

```
for num_cv in $(seq 1 20)
do
  echo $num_cv
  Rscript src/meanPopulationStudy/poptMean_trained_bin.R 200 $num_cv $num_cv 10000 20
done
```

## License

MIT License

Copyright (c) 2020 Imperial College London

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
