Keras는 다중 레이블 분류를 어떻게 처리합니까?
다음 상황에서 Keras의 기본 동작을 해석하는 방법을 잘 모르겠습니다.
내 Y (실측)는 scikit-learn의 MultilabelBinarizer
()를 사용하여 설정되었습니다 .
따라서 임의의 예를 제공하기 위해 내 y
열의 한 행은 다음과 같이 원-핫 인코딩됩니다 [0,0,0,1,0,1,0,0,0,0,1]
.
그래서 저는 예측할 수있는 11 개의 클래스를 가지고 있고, 하나 이상이 사실 일 수 있습니다. 따라서 문제의 다중 레이블 특성입니다. 이 특정 샘플에 대한 세 가지 레이블이 있습니다.
다중 레이블이 아닌 문제 (평소와 같은 비즈니스)에 대해 하듯이 모델을 교육하고 오류가 발생하지 않습니다.
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD
model = Sequential()
model.add(Dense(5000, activation='relu', input_dim=X_train.shape[1]))
model.add(Dropout(0.1))
model.add(Dense(600, activation='relu'))
model.add(Dropout(0.1))
model.add(Dense(y_train.shape[1], activation='softmax'))
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
optimizer=sgd,
metrics=['accuracy',])
model.fit(X_train, y_train,epochs=5,batch_size=2000)
score = model.evaluate(X_test, y_test, batch_size=2000)
score
Keras는 my를 만나고 y_train
"멀티"원-핫 인코딩 된 것을 볼 때 무엇 을 y_train
합니까? 즉 , 각 행에 하나 이상의 '하나'가 있음을 의미합니다 . 기본적으로 Keras는 다중 레이블 분류를 자동으로 수행합니까? 스코어링 메트릭의 해석에 차이가 있습니까?
간단히 말해서
사용하지 마십시오 softmax
.
sigmoid
출력 레이어 활성화에 사용 합니다.
binary_crossentropy
손실 기능에 사용 합니다.
predict
평가에 사용 합니다.
왜
에서 softmax
하나 개의 레이블에 대한 점수를 증가 할 때, 다른 모든 (이 확률 분포의) 인하된다. 레이블이 여러 개인 경우 원하지 않습니다.
완전한 코드
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD
model = Sequential()
model.add(Dense(5000, activation='relu', input_dim=X_train.shape[1]))
model.add(Dropout(0.1))
model.add(Dense(600, activation='relu'))
model.add(Dropout(0.1))
model.add(Dense(y_train.shape[1], activation='sigmoid'))
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='binary_crossentropy',
optimizer=sgd)
model.fit(X_train, y_train, epochs=5, batch_size=2000)
preds = model.predict(X_test)
preds[preds>=0.5] = 1
preds[preds<0.5] = 0
# score = compare preds and y_test
참조 URL : https://stackoverflow.com/questions/44164749/how-does-keras-handle-multilabel-classification
'Nice programing' 카테고리의 다른 글
Android Studio 여러 라이브러리 프로젝트에서 단일 AAR을 패키징하는 방법은 무엇입니까? (0) | 2021.01.08 |
---|---|
memcpy ()의 속도가 4KB마다 급격하게 떨어지는 이유는 무엇입니까? (0) | 2021.01.08 |
i = i + n이 i + = n과 정말 동일합니까? (0) | 2021.01.07 |
자바에서 프로그래밍 방식 교착 상태 감지 (0) | 2021.01.07 |
언제 쿠키 대신 세션 변수를 사용해야합니까? (0) | 2021.01.07 |