[Text Analysis] Multiclass Classification
Hands-on Setup
Data
from sklearn.datasets import fetch_20newsgroups
news = fetch_20newsgroups()
news['target_names']
['alt.atheism',
'comp.graphics',
'comp.os.ms-windows.misc',
'comp.sys.ibm.pc.hardware',
'comp.sys.mac.hardware',
'comp.windows.x',
'misc.forsale',
'rec.autos',
'rec.motorcycles',
'rec.sport.baseball',
'rec.sport.hockey',
'sci.crypt',
'sci.electronics',
'sci.med',
'sci.space',
'soc.religion.christian',
'talk.politics.guns',
'talk.politics.mideast',
'talk.politics.misc',
'talk.religion.misc']
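fetch_20newsgroups() loads the training subset by default: 11,314 posts across 20 categories. news['target'] holds the matching integer label (0-19) for each document; a quick check (a minimal sketch):
len(news['target_names'])   # 20
news['target'][:5]          # integer labels of the first five documents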
Document-Term Matrix
from sklearn.feature_extraction.text import CountVectorizer
cv = CountVectorizer(min_df=0.01, max_df=0.5, stop_words='english')
dtm = cv.fit_transform(news['data'])
dtm.shape
(11314, 1916)
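min_df=0.01 keeps only terms that appear in at least 1% of the documents, max_df=0.5 drops terms that appear in more than half of them, and stop_words='english' removes scikit-learn's built-in English stop words, leaving 1,916 columns. To peek at the surviving vocabulary (a minimal sketch):
cv.get_feature_names_out()[:10]   # a sample of the 1,916 remaining terms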
Specifying x and y
x = dtm
y = news['target']
Train/Test Split
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42)
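This sets aside 20% of the 11,314 documents for testing. Confirming the shapes (a minimal sketch):
x_train.shape, x_test.shape   # expected: (9051, 1916) and (2263, 1916)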
Logistic Regression
Scikit-learn's logistic regression supports OvR (one-vs-rest)
from sklearn.linear_model import LogisticRegression
model = LogisticRegression(multi_class='ovr', max_iter=200)
model.fit(x_train, y_train)
Set multi_class='ovr' (in versions where the default is already OvR, it can be omitted)
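Under OvR, one binary classifier is trained per class, which is visible in the fitted parameters (a minimal sketch):
model.coef_.shape        # (20, 1916): one weight vector per class
model.intercept_.shape   # (20,): one intercept per class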
prob = model.predict_proba(x_test)
prob[0]
array([2.06979569e-01, 4.98778872e-04, 1.51227474e-01, 2.37753022e-01,
5.87049790e-02, 5.38452734e-03, 2.00523865e-03, 2.88793490e-03,
6.22238404e-03, 2.05132109e-03, 3.60580256e-03, 4.32783621e-04,
5.88802289e-02, 1.48888125e-03, 1.95599123e-03, 9.33936922e-05,
3.85273142e-02, 2.82426090e-04, 1.56440433e-03, 2.19453545e-01])
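Each row of predict_proba holds one probability per class; under OvR the per-class sigmoid outputs are normalized so every row sums to 1 (a minimal sketch):
prob.shape          # (number of test documents, 20)
prob.sum(axis=1)    # every entry should be ~1.0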
Viewing the class with the highest probability
import numpy as np
np.argmax(prob, axis=1)
array([ 3, 2, 6, ..., 14, 16, 11], dtype=int64)
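Taking the argmax of each row reproduces model.predict, since the row-wise normalization does not change which class is largest (a minimal sketch):
np.array_equal(np.argmax(prob, axis=1), model.predict(x_test))   # expected: True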
Viewing the maximum probability
np.max(prob, axis=1)
array([0.23775302, 0.95134806, 0.86915966, ..., 0.91493761, 0.83797877,
0.91793744])
Evaluating accuracy
model.score(x_test, y_test)
0.8020326999558108
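Overall accuracy is about 80%, but it can differ substantially across the 20 categories. For a per-class breakdown (a minimal sketch, output omitted):
from sklearn.metrics import classification_report
print(classification_report(y_test, model.predict(x_test),
                            target_names=news['target_names']))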
Building a weight table
import pandas as pd
wd = pd.DataFrame(model.coef_.T,
                  index=cv.get_feature_names_out(),
                  columns=news['target_names'])
Words with the highest weights for the rec.autos board
wd['rec.autos'].sort_values().tail(10)
models 1.103112
early 1.103447
89 1.117528
engine 1.170994
owners 1.263317
dealer 1.273244
auto 1.439482
warning 1.603697
cars 1.895346
car 2.098276
Name: rec.autos, dtype: float64
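The table can be read in the other direction as well: head(10) gives the most negative weights, i.e. words whose presence argues against rec.autos, and any other column sorts the same way (a minimal sketch):
wd['rec.autos'].sort_values().head(10)   # words weighted most against rec.autos
wd['sci.space'].sort_values().tail(10)   # top-weighted words for sci.space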
Naive Bayes
from sklearn.naive_bayes import BernoulliNB
model = BernoulliNB()
model.fit(x_train, y_train)
model.score(x_test, y_test)
0.7061422889969068
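BernoulliNB only looks at whether each word is present or absent, discarding the counts. MultinomialNB models the counts themselves and is the more common choice for document-term matrices; comparing it is straightforward (a minimal sketch):
from sklearn.naive_bayes import MultinomialNB
model = MultinomialNB()
model.fit(x_train, y_train)
model.score(x_test, y_test)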
TensorFlow
import tensorflow as tf
tf.nn.softmax([-1.0, 0.5, 2.0])
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.03911257, 0.17529039, 0.785597 ], dtype=float32)>
x = 2.0
tf.nn.sigmoid(x)
<tf.Tensor: shape=(), dtype=float32, numpy=0.8807971>
tf.nn.softmax([0, x])
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.11920291, 0.880797 ], dtype=float32)>
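The two results agree because softmax over the pair [0, x] reduces to the sigmoid: softmax([0, x])[1] = e^x / (e^0 + e^x) = 1 / (1 + e^(-x)) = sigmoid(x). Binary classification with a sigmoid is thus the two-class special case of softmax.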
from tensorflow.keras.layers import Dense
model = tf.keras.Sequential([
    Dense(20, activation='softmax')
])
model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train.A, y_train, epochs=10)
Epoch 1/10
283/283 [==============================] - 0s 887us/step - loss: 2.0657 - accuracy: 0.5386
Epoch 2/10
283/283 [==============================] - 0s 910us/step - loss: 1.1081 - accuracy: 0.8181
Epoch 3/10
283/283 [==============================] - 0s 933us/step - loss: 0.7881 - accuracy: 0.8624
Epoch 4/10
283/283 [==============================] - 0s 899us/step - loss: 0.6315 - accuracy: 0.8847
Epoch 5/10
283/283 [==============================] - 0s 903us/step - loss: 0.5391 - accuracy: 0.8971
Epoch 6/10
283/283 [==============================] - 0s 915us/step - loss: 0.4741 - accuracy: 0.9055
Epoch 7/10
283/283 [==============================] - 0s 906us/step - loss: 0.4241 - accuracy: 0.9157
Epoch 8/10
283/283 [==============================] - 0s 919us/step - loss: 0.3871 - accuracy: 0.9213
Epoch 9/10
283/283 [==============================] - 0s 900us/step - loss: 0.3579 - accuracy: 0.9266
Epoch 10/10
283/283 [==============================] - 0s 981us/step - loss: 0.3330 - accuracy: 0.9317
<keras.callbacks.History at 0x2e6fd331c60>
model.evaluate(x_test.A, y_test)
71/71 [==============================] - 0s 879us/step - loss: 0.6993 - accuracy: 0.8188
[0.6993224620819092, 0.8188245892524719]
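Test accuracy lands around 0.82, close to scikit-learn's logistic regression above, which is expected: a single Dense layer with a softmax output is essentially multinomial logistic regression (the .A attribute densifies the sparse matrix before it is handed to Keras). To turn the predicted probabilities into class labels (a minimal sketch):
pred = np.argmax(model.predict(x_test.A), axis=1)   # predicted class index per document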