Support Vector Machine (SVM) is a widely used machine learning technique for classification and regression problems. Its fundamental principle is to find the hyperplane that separates the data points of two classes with the greatest margin, where the margin is the distance between the hyperplane and the nearest data points of either class. A hyperplane that passes very close to the data points is more prone to overfitting, whereas a larger margin tends to generalize better.
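To make the margin concrete, the base-R sketch below uses four made-up 2-D points and two hypothetical candidate hyperplanes of the form w · x + b = 0. The signed distance of a point x to such a hyperplane is (w · x + b) / ||w||. Multiplying it by the label (+1 or -1) makes it positive only for points on their correct side, and the margin is the smallest of these values. Both hyperplanes separate the points, but the second leaves more room on each side (for these four points it is the maximum-margin hyperplane), so it is the one SVM would prefer.

```r
# Illustrative 2-D training points (one per row) with labels +1 / -1
X = rbind(c(3, 1), c(2, 2), c(0, 1), c(1, 0))
y = c(1, 1, -1, -1)

# Signed distance of each row of X to the hyperplane w . x + b = 0
signed_distance = function(X, w, b) as.vector(X %*% w + b) / sqrt(sum(w^2))

# Margin: smallest label-weighted distance; positive only if every point is correctly classified
margin = function(X, y, w, b) min(y * signed_distance(X, w, b))

# Two hypothetical candidate hyperplanes that both separate the classes
margin_a = margin(X, y, w = c(2, 1), b = -4)    # 2 / sqrt(5), about 0.894
margin_b = margin(X, y, w = c(1, 1), b = -2.5)  # 1.5 / sqrt(2), about 1.061
c(a = margin_a, b = margin_b)
```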
In its basic form, SVM is a linear classifier: it works directly when the classes can be separated by a straight line (in two dimensions) or, more generally, by a hyperplane. If the data points cannot be separated linearly, SVM can use a kernel function to implicitly map them into a higher-dimensional feature space, where they may become separable. Common kernel choices include linear, polynomial, and radial basis function (RBF) kernels.
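The three kernels can be written out directly. The sketch below follows the parameterization given in the e1071 documentation for `svm()`: linear u'v, polynomial (gamma u'v + coef0)^degree, and radial exp(-gamma |u - v|^2). The helper names are hypothetical. It also checks the kernel trick for a homogeneous degree-2 polynomial kernel. Evaluating the kernel on the original 2-D points gives the same number as an explicit dot product after mapping to the 3-D features (x1^2, sqrt(2) x1 x2, x2^2). In that mapped space, XOR-style data that no straight line can separate becomes separable.

```r
# Kernel functions using the parameterization documented for e1071::svm()
linear_kernel = function(u, v) sum(u * v)
polynomial_kernel = function(u, v, gamma = 1, coef0 = 0, degree = 3) {
  (gamma * sum(u * v) + coef0)^degree
}
rbf_kernel = function(u, v, gamma = 1) exp(-gamma * sum((u - v)^2))

# Explicit degree-2 feature map for 2-D inputs: (x1^2, sqrt(2) * x1 * x2, x2^2)
phi = function(x) c(x[1]^2, sqrt(2) * x[1] * x[2], x[2]^2)

u = c(1, 2)
v = c(3, -1)

# Kernel trick: the quadratic kernel equals a dot product in the mapped space
k_implicit = polynomial_kernel(u, v, gamma = 1, coef0 = 0, degree = 2)
k_explicit = linear_kernel(phi(u), phi(v))
c(implicit = k_implicit, explicit = k_explicit)

# XOR-style data: no straight line in 2-D separates the two classes
X = rbind(c(1, 1), c(-1, -1), c(1, -1), c(-1, 1))
y = c(1, 1, -1, -1)

# After mapping, the second coordinate sqrt(2) * x1 * x2 alone separates them
Z = t(apply(X, 1, phi))
sign(Z[, 2])
```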
The SVM method can be broken down into the following steps:
- Identify the class labels of the training data points.
- Find the hyperplane that separates the two classes with the greatest margin.
- If the data points cannot be separated linearly, use a kernel function to map them into a higher-dimensional feature space.
- Train the model by finding the hyperplane parameters (weights and bias) that maximize the margin while minimizing classification error.
- Predict the class of new data points according to which side of the hyperplane they fall on.
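The training and prediction steps above can be sketched for the linear case in base R. This is a simplified illustration, not what `e1071::svm()` does internally. That function wraps libsvm, which solves the dual problem with an SMO-type decomposition method. Here, full-batch subgradient descent instead minimizes the soft-margin objective (lambda / 2) ||w||^2 + mean(max(0, 1 - y_i (w · x_i + b))). The hinge term penalizes points that fall inside the margin or on the wrong side. The function names, the data, and the values of `lambda`, `lr`, and `epochs` are illustrative assumptions.

```r
# Hypothetical helper: linear soft-margin SVM trained by full-batch subgradient descent
# on (lambda / 2) * ||w||^2 + mean(hinge loss); labels y must be +1 / -1
train_linear_svm = function(X, y, lambda = 0.01, lr = 0.1, epochs = 1000) {
  w = rep(0, ncol(X))
  b = 0
  n = nrow(X)
  for (epoch in seq_len(epochs)) {
    # Points with y * f(x) < 1 lie inside the margin or are misclassified
    active = y * as.vector(X %*% w + b) < 1
    # Subgradient of the objective with respect to w and b
    grad_w = lambda * w - colSums(X[active, , drop = FALSE] * y[active]) / n
    grad_b = -sum(y[active]) / n
    w = w - lr * grad_w
    b = b - lr * grad_b
  }
  list(w = w, b = b)
}

# Prediction step: classify by which side of the hyperplane a point falls on
predict_linear_svm = function(model, X) {
  ifelse(as.vector(X %*% model$w + model$b) >= 0, 1, -1)
}

# Illustrative, linearly separable data: two 3 x 3 grids of points
pos = as.matrix(expand.grid(x1 = 2:4, x2 = 2:4))
neg = as.matrix(expand.grid(x1 = -4:-2, x2 = -4:-2))
X = rbind(pos, neg)
y = c(rep(1, nrow(pos)), rep(-1, nrow(neg)))

model = train_linear_svm(X, y)
pred = predict_linear_svm(model, X)
table(predicted = pred, actual = y)
```

Using a nonlinear kernel such as RBF requires replacing the dot products with kernel evaluations, which is done in the dual formulation that libsvm solves.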
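```r
# Load the required packages
library(e1071)    # svm()
library(caTools)  # sample.split()
library(ggplot2)
library(dplyr)
library(pROC)     # roc()
library(mlbench)  # BreastCancer data set

# Wisconsin breast cancer data: 699 samples, 9 cytological features, and a benign/malignant Class
data(BreastCancer)
summary(BreastCancer)

# Drop the sample Id, convert the factor-coded features (levels "1" to "10") to numbers,
# and remove rows with missing values (Bare.nuclei contains NAs)
cancer = BreastCancer %>%
  select(-Id) %>%
  mutate(across(-Class, ~ as.numeric(as.character(.x)))) %>%
  na.omit()
head(cancer)

# Split 70/30 into training and test sets, preserving the class proportions
set.seed(123)
split = sample.split(cancer$Class, SplitRatio = 0.7)
train = subset(cancer, split == TRUE)
test = subset(cancer, split == FALSE)

# Fit an SVM with a radial basis function (RBF) kernel
svm_model = svm(Class ~ ., data = train, kernel = "radial")

# Predict the class of each test sample
pred = predict(svm_model, newdata = test)
head(pred)

# Confusion matrix: rows are predicted classes, columns are actual classes
cm = table(Predicted = pred, Actual = test$Class)
cm

# Accuracy, plus per-class precision, recall, and F1 (one row each for benign and malignant)
accuracy = sum(diag(cm)) / sum(cm)
precision = diag(cm) / rowSums(cm)  # correct predictions of a class / all predictions of that class
recall = diag(cm) / colSums(cm)     # correct predictions of a class / all actual cases of that class
f1 = 2 * (precision * recall) / (precision + recall)
metrics = data.frame(Accuracy = accuracy, Precision = precision, Recall = recall, F1 = f1)
print(metrics)
```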
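Because `table(pred, test$Class)` puts predictions in the rows and true classes in the columns, precision divides the diagonal by the row sums and recall divides it by the column sums. Precision measures how many predictions of a class were correct, and recall measures how many members of a class were found. The hand-checkable sketch below uses eight made-up labels and a hypothetical helper `class_metrics()`. For the malignant class, 2 of the 4 malignant predictions are correct (precision 0.5), and 2 of the 3 actual malignant cases are found (recall 2/3).

```r
# Hypothetical helper: per-class precision, recall, and F1 from two factors
class_metrics = function(predicted, actual) {
  cm = table(Predicted = predicted, Actual = actual)  # rows = predicted, columns = actual
  precision = diag(cm) / rowSums(cm)
  recall = diag(cm) / colSums(cm)
  f1 = 2 * precision * recall / (precision + recall)
  data.frame(Precision = precision, Recall = recall, F1 = f1)
}

classes = c("benign", "malignant")
# Eight made-up cases for illustration only
actual = factor(c("benign", "benign", "benign", "benign", "benign",
                  "malignant", "malignant", "malignant"), levels = classes)
predicted = factor(c("benign", "benign", "benign", "malignant", "malignant",
                     "malignant", "malignant", "benign"), levels = classes)

m = class_metrics(predicted, actual)
m
```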
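```r
# ROC analysis needs a continuous score rather than hard class labels,
# so refit the RBF model with probability estimates enabled
svm_prob = svm(Class ~ ., data = train, kernel = "radial", probability = TRUE)
prob_pred = predict(svm_prob, newdata = test, probability = TRUE)
malignant_prob = attr(prob_pred, "probabilities")[, "malignant"]
# Calculate ROC curve, treating "malignant" as the positive class
roc_curve = roc(test$Class, malignant_prob,
                levels = c("benign", "malignant"), direction = "<")
# Plot ROC curve with the AUC and selected probability thresholds
plot(roc_curve, print.auc = TRUE, legacy.axes = TRUE, grid = c(0.1, 0.2),
     grid.col = "lightgray",
     auc.polygon = TRUE, max.auc.polygon = TRUE, auc.polygon.col = "skyblue",
     print.thres = c(0.1, 0.2, 0.5), print.thres.col = "black",
     print.thres.cex = 1.2, print.thres.adj = c(1.6, -0.6))
```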
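The AUC printed by `print.auc = TRUE` has a useful interpretation. It is the probability that a randomly chosen positive (here, malignant) case receives a higher score than a randomly chosen negative case, with ties counted as one half. The base-R sketch below computes it that way for six made-up scores. It also computes the sensitivity and specificity at one threshold, which is the kind of point `print.thres` marks on the curve. The helper names are hypothetical, and counting a score at or above the threshold as a positive prediction is an assumption of this sketch.

```r
# Hypothetical helper: AUC as P(score of a positive > score of a negative), ties count 1/2
auc_rank = function(scores, is_positive) {
  pos = scores[is_positive]
  neg = scores[!is_positive]
  # Compare every positive score with every negative score
  wins = outer(pos, neg, function(p, n) (p > n) + 0.5 * (p == n))
  mean(wins)
}

# Hypothetical helper: one ROC point, treating score >= threshold as a positive prediction
roc_point = function(scores, is_positive, threshold) {
  predicted_positive = scores >= threshold
  c(sensitivity = mean(predicted_positive[is_positive]),
    specificity = mean(!predicted_positive[!is_positive]))
}

# Six made-up probability scores and their true labels (TRUE = positive class)
scores = c(0.9, 0.8, 0.35, 0.6, 0.2, 0.1)
is_positive = c(TRUE, TRUE, TRUE, FALSE, FALSE, FALSE)

auc_rank(scores, is_positive)  # 8 of 9 positive/negative pairs ranked correctly
roc_point(scores, is_positive, threshold = 0.5)
```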