Machine learning #3-1 Decision Tree practice

코딩/Machine learning|2022. 3. 8. 15:05

Decision Tree에 대해 실제로 실습해보자.

Python에서는 sklearn library가 이를 제공하며, 실습을 위한 datasets도 제공하기 때문에 이를 이용해보자.

 

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

cancer = load_breast_cancer()

 

datasets은 sklearn에서 제공하는 Bunch 형태로 받으며 구성은 다음과 같이 dictionary와 비슷하다. 

dir(cancer) :
['DESCR', 'data', 'data_module', 'feature_names', 'filename', 'frame', 'target', 'target_names']

 

cancer.data : (569, 30)의 2차원 데이터로 569개의 data가 30개의 feature들로 나뉘어 있으며, 값은 모두 실수이다. 

이 features들은 다음과 같다.

 

cancer.feature_names

array(['mean radius', 'mean texture', 'mean perimeter', 'mean area',...,'worst symmetry', 'worst fractal dimension'], dtype='<U23')

 

cancer.target : (569, ) 로 1차원 데이터이며 마찬가지로 569개의 데이터에 대해 maliganant vs benign 인지 0, 1로 coding 되어 들어 있다. 

 

다음으로 sklearn의 train_test_split module을 이용하여 데이터를 분리하자.

X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=42)

stratify 는 계층적 데이터 추출 옵션으로 cancer.target을 기준으로 층을 분류한 후 

각 층에서 랜덤 데이터를 추출하는 옵션으로, 원래의 데이터 분포와 비슷하게 추출할 수 있게 해 준다. 

 

random_state=42

세트를 섞을 때 해당 int 값(42)을 보고 섞으며, 하이퍼 파라미터를 튜닝시 이 값을 고정해두고 튜닝해야 매번 데이터셋이 변경되는 것을 방지할 수 있다.

 

이후 학습을 시작해보면

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X_train, y_train)

DecisionTreeClassifier 메소드로 tree 객체를 생성했는데, 이때 random_state는 난수 초기값이라고 생각하자.

 

이후 tree를 fit 하고 나면

 

tree.get_n_leaves() = 18

tree.get_depth() = 7

tree.get_params로 파라미터들의 설정 상태를 확인할 수 있다. 

{'ccp_alpha': 0.0, 'class_weight': None, 'criterion': 'gini', 'max_depth': None, 'max_features': None, 'max_leaf_nodes': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 2, 'min_weight_fraction_leaf': 0.0, 'random_state': 0, 'splitter': 'best'}

의 구조를 갖는 tree가 생성되었음을 알 수 있다. 

 

실제로 예측해보면

print("train set accuracy: {:.3f}".format(tree.score(X_train, y_train)))
print("test set accuracy: {:.3f}".format(tree.score(X_test, y_test)))

각각

train set accuracy: 1.000
test set accuracy: 0.937

으로 확인할 수 있다. 

 

시각화

tree의 구조를 시각화하기 위해

from sklearn.tree import plot_tree
plot_tree(tree,
          class_names=cancer.target_names,
          feature_names=cancer.feature_names,
          impurity=True, filled=True, rounded=True)

이때 그냥 하면 feature가 많아서 시각화가 잘 되지 않기 때문에

plt.figure(figsize=)를 충분히 키워서 진행한다.

이렇게 전체 트리 구조를 눈으로 확인할 수 있다. 

 

인자 별 중요도

tree.feature_importances_ 메소드로 인자 별 중요도를 확인할 수 있다. 
array([0.        , 0.00752597, 0.        , 0.        , 0.00903116,
       0.        , 0.00752597, 0.        , 0.        , 0.        ,
       0.00975731, 0.04630969, 0.        , 0.00238745, 0.00231135,
       0.        , 0.        , 0.        , 0.        , 0.00668975,
       0.69546322, 0.05383211, 0.        , 0.01354675, 0.        ,
       0.        , 0.01740312, 0.11684357, 0.01137258, 0.        ])

 

이를 시각화하면

plt.figure(figsize=(20, 15))
plt.bar(cancer.feature_names, tree.feature_importances_)

 

'코딩 > Machine learning' 카테고리의 다른 글

Machine learning #4 Random forest  (0) 2022.03.08
Machine learning #3 Decision tree  (0) 2022.03.08
Machine learning #2 - Image classification, CNN  (0) 2022.03.03
Machine learning #1 - MLP  (0) 2022.03.03

댓글()