Machine learning #3-1 Decision Tree practice
Decision Tree에 대해 실제로 실습해보자.
Python에서는 sklearn library가 이를 제공하며, 실습을 위한 datasets도 제공하기 때문에 이를 이용해보자.
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
cancer = load_breast_cancer()
datasets은 sklearn에서 제공하는 Bunch 형태로 받으며 구성은 다음과 같이 dictionary와 비슷하다.
dir(cancer) :
['DESCR', 'data', 'data_module', 'feature_names', 'filename', 'frame', 'target', 'target_names']
cancer.data : (569, 30)의 2차원 데이터로 569개의 data가 30개의 feature들로 나뉘어 있으며, 값은 모두 실수이다.
이 features들은 다음과 같다.
cancer.feature_names
array(['mean radius', 'mean texture', 'mean perimeter', 'mean area',...,'worst symmetry', 'worst fractal dimension'], dtype='<U23')
cancer.target : (569, ) 로 1차원 데이터이며 마찬가지로 569개의 데이터에 대해 maliganant vs benign 인지 0, 1로 coding 되어 들어 있다.
다음으로 sklearn의 train_test_split module을 이용하여 데이터를 분리하자.
X_train, X_test, y_train, y_test = train_test_split(
cancer.data, cancer.target, stratify=cancer.target, random_state=42)
stratify 는 계층적 데이터 추출 옵션으로 cancer.target을 기준으로 층을 분류한 후
각 층에서 랜덤 데이터를 추출하는 옵션으로, 원래의 데이터 분포와 비슷하게 추출할 수 있게 해 준다.
random_state=42
세트를 섞을 때 해당 int 값(42)을 보고 섞으며, 하이퍼 파라미터를 튜닝시 이 값을 고정해두고 튜닝해야 매번 데이터셋이 변경되는 것을 방지할 수 있다.
이후 학습을 시작해보면
tree = DecisionTreeClassifier(random_state=0)
tree.fit(X_train, y_train)
DecisionTreeClassifier 메소드로 tree 객체를 생성했는데, 이때 random_state는 난수 초기값이라고 생각하자.
이후 tree를 fit 하고 나면
tree.get_n_leaves() = 18
tree.get_depth() = 7
tree.get_params로 파라미터들의 설정 상태를 확인할 수 있다.
{'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'random_state': 0, 'splitter': 'best'}
의 구조를 갖는 tree가 생성되었음을 알 수 있다.
실제로 예측해보면
print("train set accuracy: {:.3f}".format(tree.score(X_train, y_train)))
print("test set accuracy: {:.3f}".format(tree.score(X_test, y_test)))
각각
train set accuracy: 1.000
test set accuracy: 0.937
으로 확인할 수 있다.
시각화
tree의 구조를 시각화하기 위해
from sklearn.tree import plot_tree
plot_tree(tree,
class_names=cancer.target_names,
feature_names=cancer.feature_names,
impurity=True, filled=True, rounded=True)
이때 그냥 하면 feature가 많아서 시각화가 잘 되지 않기 때문에
plt.figure(figsize=)를 충분히 키워서 진행한다.
이렇게 전체 트리 구조를 눈으로 확인할 수 있다.
인자 별 중요도
tree.feature_importances_ 메소드로 인자 별 중요도를 확인할 수 있다.
array([0. , 0.00752597, 0. , 0. , 0.00903116,
0. , 0.00752597, 0. , 0. , 0. ,
0.00975731, 0.04630969, 0. , 0.00238745, 0.00231135,
0. , 0. , 0. , 0. , 0.00668975,
0.69546322, 0.05383211, 0. , 0.01354675, 0. ,
0. , 0.01740312, 0.11684357, 0.01137258, 0. ])
이를 시각화하면
plt.figure(figsize=(20, 15))
plt.bar(cancer.feature_names, tree.feature_importances_)
'코딩 > Machine learning' 카테고리의 다른 글
Machine learning #4 Random forest (0) | 2022.03.08 |
---|---|
Machine learning #3 Decision tree (0) | 2022.03.08 |
Machine learning #2 - Image classification, CNN (0) | 2022.03.03 |
Machine learning #1 - MLP (0) | 2022.03.03 |