簡體   English   中英

在 R 中繪製圖形

[英]Plotting a Graph in R

我參考了這個站點:https://www.r-bloggers.com/2021/02/how-to-build-a-handwritten-digit-classifier-with-r-and-random-forests/

使用帶有隨機森林的 R 編寫手寫數字分類器。

是否可以將代碼末尾得到的 colMeans 結果繪製成圖(plot)? MNIST 訓練和測試數據集(您可以在上面的鏈接中找到)沒有任何列標題。 我是 R 的新手,還在學習。 任何形式的幫助將不勝感激。

這是代碼:

library(readr)

# Read the MNIST training and test CSV files. Neither file has a header row,
# so the columns are addressed by the names X1, X2, ... below.
train_set <- read_csv("mnist_train.csv", col_names = FALSE)
test_set <- read_csv("mnist_test.csv", col_names = FALSE)

# Column X1 stores the digit of each record; turn it into a factor so it is
# treated as a class label rather than a number.
train_labels <- factor(train_set[["X1"]])
test_labels <- factor(test_set[["X1"]])

# Show the labels of the first 10 training records
head(train_labels, n = 10)

# Count the training records for each digit (0 to 9)
summary(train_labels)

# Load the random forest implementation
library(randomForest)

# Train the model.
# Column 1 (X1) holds the digit label itself, so it is dropped from the
# predictors of both sets; leaving it in would hand the answer to the forest
# as a feature. The test labels are passed as ytest so the fit also reports
# the error on the test set.
rf <- randomForest(
  x = train_set[, -1],
  y = train_labels,
  xtest = test_set[, -1],
  ytest = test_labels,
  ntree = 50
)
rf

# Accuracy = 1 - out-of-bag (OOB) error rate of the complete forest.
# rf$err.rate holds one row per number of trees grown, with an "OOB" column
# next to the per-class columns. Averaging the whole matrix would mix the
# per-class rates and the early, small forests into the figure, so only the
# last row of the "OOB" column is used.
1 - unname(rf$err.rate[nrow(rf$err.rate), "OOB"])

# Load dplyr for select()
library(dplyr)

# Mean error rate of every digit: drop the overall "OOB" column from the
# error-rate table and average each remaining column.
err_df <- as.data.frame(rf$err.rate)
colMeans(select(err_df, -"OOB"))

colMeans 的輸出

我通過對訓練集和測試集進行相當多的子集化來稍微修改您的代碼,以加快分析速度。 您可以自由評論/刪除相關行。 請查看下面的代碼,並告訴我這是否是您要查找的內容。

library(readr)
#importing dplyr
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#importing random forest
library(randomForest)
#> randomForest 4.6-14
#> Type rfNews() to see new features/changes/bug fixes.
#> 
#> Attaching package: 'randomForest'
#> The following object is masked from 'package:dplyr':
#> 
#>     combine
library(ggplot2)
#> 
#> Attaching package: 'ggplot2'
#> The following object is masked from 'package:randomForest':
#> 
#>     margin


# Read the MNIST training and test CSV files (neither has a header row)
train_set <- read_csv("~/Downloads/mnist_train.csv", col_names = FALSE)
#> 
#> ── Column specification ────────────────────────────────────────────────────────
#> cols(
#>   .default = col_double()
#> )
#> ℹ Use `spec()` for the full column specifications.
test_set <- read_csv("~/Downloads/mnist_test.csv", col_names = FALSE)
#> 
#> ── Column specification ────────────────────────────────────────────────────────
#> cols(
#>   .default = col_double()
#> )
#> ℹ Use `spec()` for the full column specifications.

# Column X1 stores the digit of each record; turn it into a factor so it is
# treated as a class label rather than a number.
train_labels <- factor(train_set[["X1"]])
test_labels <- factor(test_set[["X1"]])

# Show the labels of the first 10 training records
head(train_labels, n = 10)
#>  [1] 5 0 4 1 9 2 1 3 1 4
#> Levels: 0 1 2 3 4 5 6 7 8 9

# Count the training records for each digit (0 to 9)
summary(train_labels)
#>    0    1    2    3    4    5    6    7    8    9 
#> 5923 6742 5958 6131 5842 5421 5918 6265 5851 5949

# Shrink the data to speed up the analysis: keep the first 1000 training
# records and the first 100 test records, together with their labels.
train_set <- train_set[seq_len(1000), ]
train_labels <- train_labels[seq_len(1000)]
test_set <- test_set[seq_len(100), ]
test_labels <- test_labels[seq_len(100)]

# Train a 50-tree random forest on the reduced training set, passing the
# reduced test set as xtest, then print the fitted model.
# NOTE(review): train_set and test_set still contain column X1, the column the
# labels were taken from, so the label itself is passed in as a predictor.
# Consider x = train_set[, -1] and xtest = test_set[, -1]. The output recorded
# below comes from the call exactly as written here.
rf <- randomForest(x = train_set, y = train_labels, xtest = test_set, ntree = 50)
rf
#> 
#> Call:
#>  randomForest(x = train_set, y = train_labels, xtest = test_set,      ntree = 50) 
#>                Type of random forest: classification
#>                      Number of trees: 50
#> No. of variables tried at each split: 28
#> 
#>         OOB estimate of  error rate: 11.6%
#> Confusion matrix:
#>    0   1  2  3  4  5  6   7  8  9 class.error
#> 0 96   0  0  0  0  0  1   0  0  0  0.01030928
#> 1  0 112  1  1  0  1  0   0  0  1  0.03448276
#> 2  2   6 82  0  2  0  1   4  2  0  0.17171717
#> 3  0   1  2 78  2  5  1   1  2  1  0.16129032
#> 4  0   0  1  0 94  1  2   1  1  5  0.10476190
#> 5  0   0  1  8  3 77  1   0  0  2  0.16304348
#> 6  1   0  1  0  2  2 86   1  1  0  0.08510638
#> 7  0   3  3  2  4  0  0 102  0  3  0.12820513
#> 8  0   1  1  3  1  6  1   1 71  2  0.18390805
#> 9  1   0  0  1  4  1  1   5  1 86  0.14000000

# 1 - error rate, used here as the accuracy.
# NOTE(review): mean(rf$err.rate) averages every entry of the error-rate
# table, the "OOB" column and the per-digit columns alike, so the value below
# is not 1 minus the 11.6% OOB estimate printed above. If the OOB accuracy is
# what is wanted, something like
# 1 - rf$err.rate[nrow(rf$err.rate), "OOB"] looks closer — confirm against the
# randomForest documentation.
1 - mean(rf$err.rate)
#> [1] 0.8012579


# Mean error rate of every digit: drop the overall "OOB" column from the
# error-rate table and average each remaining column.
err_df <- as.data.frame(rf$err.rate)
mymeans <- colMeans(select(err_df, -"OOB"))

# Table for plotting: one row per entry of mymeans, pairing a zero-based
# position (0, 1, 2, ...) with the corresponding mean.
toplot <- data.frame(
  index = seq_along(mymeans) - 1,
  col_means = mymeans
)

# Line plot with a marker on every value and one x-axis tick per position
ggplot(data = toplot, mapping = aes(x = index, y = col_means)) +
  geom_line() +
  geom_point() +
  scale_x_continuous(breaks = toplot$index)

由 reprex package (v0.3.0) 於 2021 年 2 月 16 日創建

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM