본문 바로가기
Library

sklearn - template

by wycho 2021. 7. 1.
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_squared_log_error, log_loss
from sklearn.metrics import accuracy_score, roc_curve, auc, r2_score
from sklearn.metrics import classification_report, confusion_matrix

from sklearn.preprocessing import Normalizer, RobustScaler, MinMaxScaler, StandardScaler, MaxAbsScaler
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor

import sys


def _to_positional(data):
    """Return pandas objects as NumPy arrays; pass anything else through unchanged."""
    if isinstance(data, (pd.DataFrame, pd.Series)):
        return data.to_numpy()
    return data


def k_fold_CV(model, X_train, X_test, y_train, y_test, n_folds=3):
    """Fit ``model`` on each K-fold training split and score the averaged test predictions.

    The training data is split into ``n_folds`` consecutive (unshuffled)
    folds. For every fold the model is refit on that fold's training part
    and used to predict ``X_test``; the per-fold test predictions are then
    averaged row-wise.

    Args:
        model: Estimator exposing ``fit(X, y)`` and ``predict(X)``. It is
            refit in place once per fold.
        X_train: Training features, indexable by an integer index array
            (pandas objects are converted to NumPy first).
        X_test: Held-out features; must expose ``shape``.
        y_train: Training targets aligned with ``X_train``.
        y_test: Held-out targets used only for scoring.
        n_folds: Number of folds passed to ``KFold``.

    Returns:
        Tuple ``(acc, mse)``: accuracy of the rounded averaged prediction
        and mean squared error of the raw averaged prediction, both
        against ``y_test``.

    Note:
        Predictions are stored in a float array and averaged, so this
        assumes numeric labels. NOTE(review): rounding the mean looks
        intended as a vote over binary 0/1 labels -- confirm for other
        label sets.
    """
    # KFold yields positional indices; label-indexed pandas objects would
    # either raise or silently select the wrong rows, so convert them up front.
    X_train = _to_positional(X_train)
    X_test = _to_positional(X_test)
    y_train = _to_positional(y_train)

    kf = KFold(n_splits=n_folds, shuffle=False)

    # One column of test-set predictions per fold.
    test_pred = np.zeros((X_test.shape[0], n_folds))

    for fold, (train_index, _) in enumerate(kf.split(X_train)):
        model.fit(X_train[train_index], y_train[train_index])
        test_pred[:, fold] = model.predict(X_test)

    # Average the fold predictions for every test row.
    y_pred = test_pred.mean(axis=1)

    acc = accuracy_score(y_test, y_pred.round(0))
    mse = mean_squared_error(y_test, y_pred)

    return acc, mse



def preprocessing(X,y):
    """Scale the features with ``MaxAbsScaler`` and return array forms of X and y.

    Args:
        X: Feature table (e.g. a pandas DataFrame) accepted by
            ``MaxAbsScaler.fit_transform``.
        y: Target values (e.g. a pandas Series).

    Returns:
        Tuple ``(X, y)`` where ``X`` is the scaled feature data and ``y``
        is a NumPy array of the targets.

    Note:
        The scaler is fit on all rows passed in, i.e. before any
        train/test split performed by the caller.
    """
    scaler = MaxAbsScaler()
    X = scaler.fit_transform(X)

    # fit_transform returns a NumPy array by default, which has no
    # ``.values`` attribute; only unwrap when a DataFrame actually comes
    # back (e.g. when pandas output is enabled for the transformer).
    if isinstance(X, pd.DataFrame):
        X = X.to_numpy()
    y = np.asarray(y)

    return X, y



def main(fname):
    """Train and evaluate a random forest on a tab-separated data file.

    The first column of the file is used as the target and every other
    column as a feature. Features are run through ``preprocessing``, 20%
    of the rows are held out as a test set, and a
    ``RandomForestClassifier`` is evaluated with 4-fold ``k_fold_CV``.

    Args:
        fname: Path to the tab-separated input file.

    Returns:
        Tuple ``(acc, mse)`` as returned by ``k_fold_CV``. The values are
        also printed so a command-line run reports its result.
    """
    df = pd.read_table(fname, sep='\t')

    # Column 0 is the target; the remaining columns are the features.
    X = df.iloc[:, 1:]
    y = df.iloc[:, 0]

    X, y = preprocessing(X, y)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    model = RandomForestClassifier(n_jobs=6)

    acc, mse = k_fold_CV(model, X_train, X_test, y_train, y_test, 4)

    # Previously the metrics were computed and then discarded.
    print('accuracy: {:.4f}'.format(acc))
    print('mse: {:.4f}'.format(mse))

    return acc, mse



if __name__ == '__main__':
    # The input file path is the only required command-line argument.
    try:
        fname = sys.argv[1]
    except IndexError:
        # Exit with a non-zero status and a usage hint instead of quitting
        # silently; catching only IndexError keeps other errors visible.
        sys.exit('usage: ' + sys.argv[0] + ' <tab-separated file>')

    main(fname)

template.py
0.00MB

 

 

'Library' 카테고리의 다른 글

numpy - ravel_multi_index  (0) 2021.12.21
sklearn - Scaler  (0) 2021.06.23
Scikit-allel  (0) 2020.11.06
sklearn - Standardization  (0) 2020.11.05
Scikit-learn, sklearn  (0) 2020.11.04

댓글