機械学習のモデルを作成する際、手元のデータをすべて学習に使ってしまうと、未知のデータに対してどの程度正確に予測できるのかを正しく評価できません。
この記事では、Pythonの代表的な機械学習ライブラリであるScikit-learn(サイキットラーン)に用意されている関数、train_test_splitの使い方を徹底的に解説します。
この記事を読むことで、以下のことが解決できます。
- データを学習用(訓練用)とテスト用に分割する理由がわかる
train_test_splitの基本的な書き方とコピペで動くコードが手に入るtest_sizeやrandom_state、stratifyなどの重要パラメータの意味が理解できる- 実践的なPandasデータフレームでの分割方法がマスターできる
機械学習の第一歩として絶対に避けては通れない関数ですので、この記事を通してしっかりと基礎を固めましょう。本記事で解説している内容の公式ドキュメントはこちらです。
結論:train_test_splitはモデル評価に必須の関数です
結論から言うと、train_test_splitは、手元のデータを「モデルに学習させるためのデータ」と「モデルの性能をテストするためのデータ」に分割するために使用する関数です。
機械学習において、このデータ分割はモデルの汎化性能(未知のデータに対する対応力)を測るために絶対に欠かせない工程となります。
なぜデータを学習用とテスト用に分けるのか?
手元のデータを学習用とテスト用に分ける最大の理由は、「過学習(オーバーフィッティング)」を防ぎ、モデルの真の実力を評価するためです。
過学習とは、モデルが学習データにだけ過剰に適合してしまい、新しく入力されたデータに対しては全く予測が当たらない状態のことを指します。これは、過去問の答えを丸暗記しただけで、本番の試験で応用問題が解けない学生に似ています。
過学習が起きていないかを確認するためには、モデルが一度も見たことのない「テストデータ」を使ってテストを行う必要があります。そのために、最初の段階でデータを切り分けておくのです。
Scikit-learnのバージョンについて
本記事で紹介するtrain_test_splitの機能は、Scikit-learnのバージョンに関わらず広く一般的に利用できるものです。
本記事のコードは、Scikit-learn 1.0以降の現行バージョンであれば問題なく動作します。もしエラーが出る場合は、ライブラリのバージョンアップも検討してみてください。
train_test_splitの基本的な使い方とコード例
結論から言うと、train_test_splitはsklearn.model_selectionからインポートし、特徴量と目的変数を引数に渡すだけで簡単にデータを分割できます。
ここでは、最もシンプルで基本的な使い方をコード付きで解説します。
インポート方法
まずは関数をインポートする必要があります。以下のコードを実行して準備しましょう。
# train_test_split関数のインポート
from sklearn.model_selection import train_test_split
# 数値計算用とデータ操作用のライブラリもインポート
import numpy as np
import pandas as pdこのインポート文は、機械学習のプログラムを書く際のお決まりのフレーズのようなものですので、そのまま覚えてしまいましょう。
最もシンプルな分割コード
次に、ダミーのデータを用意して、実際にデータを分割してみましょう。ここでは10個のサンプルデータを作成します。
# ダミーデータの作成
# X: 特徴量(0から9までの配列を2次元に変換)
X = np.arange(10).reshape((5, 2))
# y: 目的変数(正解ラベル。ここでは0か1)
y = np.array([0, 1, 0, 1, 0])
# データを学習用とテスト用に分割
X_train, X_test, y_train, y_test = train_test_split(X, y)
print("X_train (学習用特徴量):\n", X_train)
print("X_test (テスト用特徴量):\n", X_test)
print("y_train (学習用目的変数):", y_train)
print("y_test (テスト用目的変数):", y_test)このコードを実行すると、元のデータXとyがランダムにシャッフルされ、デフォルトの設定(通常は学習用75%、テスト用25%)で分割されます。
戻り値の順番に注意しよう
train_test_splitを使用する上で、初心者が最もつまずきやすいのが「戻り値を受け取る変数の順番」です。
結論として、必ず X_train, X_test, y_train, y_test の順番で変数を受け取るようにしてください。
関数は、引数に渡されたデータ(今回はXとy)をそれぞれ分割し、リスト形式で返します。順番を間違えると、特徴量と目的変数が入り混じってしまい、後のモデル学習で重大なエラーを引き起こす原因となります。スマホの予測変換などに頼らず、この順番はしっかりと指に覚えさせましょう。
絶対に覚えておきたい主要パラメータ5選
結論から言うと、実務でtrain_test_splitを使う場合、引数(パラメータ)を適切に設定することで、分割の割合やランダム性をコントロールする必要があります。
ここでは、非常によく使う5つの重要なパラメータについて詳しく解説します。
test_size / train_size(分割割合の指定)
test_sizeはテストデータの割合を、train_sizeは学習データの割合を指定するパラメータです。
通常はtest_sizeのみを指定することが多いです。実数(0.0〜1.0)を指定すると全体の割合となり、整数を指定するとデータの件数となります。
# 全体の20%をテストデータ、80%を学習データにする場合
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.2 # 0.2は20%という意味
)機械学習では、学習データが多い方がモデルの精度が上がりやすいため、test_sizeは0.2〜0.3(20%〜30%)に設定するのが一般的です。
random_state(乱数シードの固定)
結論から言うと、random_stateはデータの分割結果を固定するために必ず設定すべきパラメータです。
デフォルトでは、コードを実行するたびにデータの分割(シャッフル)がランダムに行われます。しかし、これでは「モデルの精度が上がったのは、アルゴリズムを改善したからなのか、たまたま分割されたデータが良かっただけなのか」が判断できなくなります。
# random_stateに任意の整数(例:42)を設定して結果を固定する
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.2,
random_state=42 # シード値を固定
)数字自体(42や0など)に深い意味はありませんが、同じ数字を設定すれば、誰がいつ実行しても全く同じようにデータが分割されるようになります。再現性を保つために非常に重要です。
shuffle(データのシャッフル)
shuffleは、分割する前にデータをランダムに並び替えるかどうかを指定するパラメータです。
デフォルトではshuffle=True(シャッフルする)になっています。多くの場合はシャッフルした方が良い結果を得られますが、株価などの「時系列データ」を扱う場合は注意が必要です。
# 時系列データなど、順番に意味がある場合はシャッフルをオフにする
X_train, X_test, y_train, y_test = train_test_split(
X, y,
shuffle=False # シャッフルせずにそのままの順番で分割
)時系列データで未来のデータを使って過去を予測するのは不正解であるため、過去から順番にデータを分割しなければなりません。そのようなケースでは必ずshuffle=Falseを設定しましょう。
stratify(層化抽出によるクラス割合の保持)
結論として、分類問題を解く場合は、このstratifyパラメータを必ず設定する癖をつけてください。
データの中に「犬の画像が90枚、猫の画像が10枚」のような偏り(不均衡データ)がある場合、ランダムに分割すると、テストデータに猫の画像が1枚も含まれないといった事態が起こり得ます。
stratify=yと設定することで、元のデータのクラス割合(正解ラベルの比率)を保ったまま分割してくれます。
# 正解ラベル(y)の割合を保ったまま分割する
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.2,
random_state=42,
stratify=y # y(目的変数)の比率を維持
)これについては、後ほど「分類問題で必須!stratify(層化抽出)の重要性」の章でさらに詳しく解説します。
実践編!Pandas DataFrameを使ったデータ分割
結論から言うと、実務ではNumPy配列よりも、PandasのDataFrame(データフレーム)を使ってデータを分割するケースが圧倒的に多いです。
ここでは、Kaggleなどのデータ分析コンペでもよく使われる、実践的なPandasを用いたデータの分割方法を解説します。
特徴量(X)と目的変数(y)の分離
CSVファイルなどから読み込んだデータは、通常、特徴量(予測の手がかりとなるデータ)と目的変数(予測したい答え)が一つのテーブルにまとまっています。
まずはこれを、Pandasの機能を使ってXとyに切り離します。
# 架空の住宅価格データセット(DataFrame)を作成
data = {
'広さ': [50, 60, 70, 80, 90, 100],
'築年数': [5, 10, 3, 20, 15, 2],
'価格': [3000, 2500, 4000, 1500, 2000, 5000] # これが予測したい目的変数
}
df = pd.DataFrame(data)
# 目的変数(価格)を y に格納
y = df['価格']
# 特徴量(価格以外のすべての列)を X に格納
# dropメソッドを使って、'価格'列を削除したデータフレームを取得
X = df.drop('価格', axis=1)このように、dropメソッドを使うと簡単に特徴量だけのデータフレームを作成できます。
分割後のデータ確認方法
Xとyを準備できたら、先ほどと同じようにtrain_test_splitを実行します。分割後のデータのサイズ(行数と列数)を確認する癖をつけておきましょう。
# データの分割(テストデータを約33%に設定)
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.33,
random_state=123
)
# shape属性を使って、分割後の行数と列数を確認
print("学習用 特徴量のサイズ:", X_train.shape)
print("テスト用 特徴量のサイズ:", X_test.shape)shape属性を使うことで、意図した通りにデータが分割されているか(例えば6行のデータが4行と2行に分かれているか)をサッと確認できます。エラーを未然に防ぐための重要なテクニックです。
分類問題で必須!stratify(層化抽出)の重要性
結論として、データに偏りがある分類問題においてstratifyを使わないと、モデルの評価が全く当てにならなくなる危険性があります。
ここでは、前述したstratifyパラメータについて、その重要性をさらに深掘りして解説します。
不均衡データにおける問題点
現実社会のデータは、綺麗なバランスになっていることの方が稀です。例えば、「工場で不良品を検知するAI」を作る場合、正常な製品が99%で、不良品はわずか1%しかないといった状況がよくあります。
このようなデータを単にランダムに分割(stratifyなし)してしまうとどうなるでしょうか。
最悪の場合、テストデータの中に「不良品のデータが1件も含まれない」という事態が発生します。そのテストデータでモデルを評価しても、「すべて正常品だと予測するだけの役に立たないAI」を「精度100%の素晴らしいモデル」だと勘違いしてしまうことになります。
stratifyパラメータの効果的な使い方
このような事故を防ぐのが「層化抽出(Stratified Sampling)」という手法であり、Scikit-learnではstratify=yと記述するだけで簡単に実現できます。
# 極端に偏った目的変数(例:0が90個、1が10個)を想定
y_imbalanced = np.array([0]*90 + [1]*10)
X_dummy = np.zeros((100, 2)) # 形を合わせるためのダミーデータ
# stratifyを指定して分割
X_tr, X_te, y_tr, y_te = train_test_split(
X_dummy, y_imbalanced,
test_size=0.2,
random_state=0,
stratify=y_imbalanced # ★ここで目的変数を指定する!
)
# 分割後のテストデータ内のラベルの割合を確認
unique, counts = np.unique(y_te, return_counts=True)
print("テストデータのラベル内訳:", dict(zip(unique, counts)))
# 出力結果は {0: 18, 1: 2} となり、元の 9:1 の割合が綺麗に維持されます。このように、分類タスク(カテゴリを予測するタスク)に取り組む際は、データフレームを扱う際も必ずstratify=yを含めるように習慣づけましょう。※ただし、数値を予測する回帰問題(株価や気温の予測など)ではstratifyは使用できないので注意してください。
初心者が陥りやすい!よくあるエラーと解決策
結論から言うと、train_test_splitで発生するエラーの9割は、入力するデータのサイズ(行数)が合っていないことが原因です。
ここでは、よく遭遇する代表的なエラーとその対処法をまとめました。
ValueError: Found input variables with inconsistent numbers of samples
このエラーメッセージは、「入力されたX(特徴量)とy(目的変数)のサンプル数(行数)が一致していませんよ」という意味です。
原因と対策: データを前処理する段階で、欠損値(NaN)を含む行をdropna()で削除した際などに、X側だけ削除してy側を削除し忘れたりすると、このエラーが発生します。
必ず、train_test_splitに渡す直前のX.shape[0](Xの行数)とy.shape[0](yの行数)が完全に一致しているかをprint関数などで確認してください。
ModuleNotFoundError: No module named ‘sklearn’
このエラーは、そもそもScikit-learnのライブラリがPython環境にインストールされていない場合に発生します。
原因と対策: ターミナルやコマンドプロンプト(Jupyter Notebookの場合はセル内)で、以下のコマンドを実行してインストールを行ってください。
# pipを使用する場合
pip install scikit-learn
# Anacondaを使用する場合
conda install scikit-learn注意点として、インストール時のパッケージ名はscikit-learnですが、Pythonのコード内でインポートする際はsklearnと記述します。この違いも初心者が戸惑いやすいポイントです。
応用編:train_test_splitと交差検証(Cross-Validation)の違い
結論として、train_test_splitは素早くモデルを評価したい時に使い、交差検証(Cross-Validation)はより厳密で信頼性の高い評価を行いたい時に使います。
機械学習を少し深く学び始めると、「交差検証(クロスバリデーション)」という言葉を耳にするようになります。ここでは両者の使い分けについて解説します。
いつtrain_test_splitを使うべきか?
train_test_splitは、データを1回だけ分割する手法です(ホールドアウト法とも呼ばれます)。
- メリット: 処理が非常に高速で、コードがシンプル。
- デメリット: 分割のされ方(乱数のシード値など)によって、モデルの評価スコアが上振れしたり下振れしたりする可能性がある。
- 適したケース: データ量が膨大で学習に時間がかかる場合や、まずはざっくりとモデルの性能を確認したいプロトタイピングの段階。
交差検証が必要なケースとは?
交差検証(特にk分割交差検証、K-Fold)は、データを複数個に分割し、「学習とテスト」のプロセスをデータの組み合わせを変えながら複数回繰り返す手法です。
- メリット: データの偏りによる評価のブレが少なくなり、モデルの汎化性能をより正確に、客観的に評価できる。
- デメリット: 学習とテストを複数回(例えば5分割なら5回)繰り返すため、計算に時間がかかる。
- 適したケース: データ量が少ない場合や、最終的なモデルの精度を厳密にコンペティション等で競う場合。
実務では、まずtrain_test_splitで手早くベースライン(基準となる精度のモデル)を作成し、その後、ハイパーパラメータのチューニングを行う段階で交差検証(GridSearchCVなど)を導入する、というハイブリッドなアプローチがよく取られます。
まとめ:train_test_splitをマスターして機械学習の基礎を固めよう
この記事では、Scikit-learnのtrain_test_split関数について、基本的な使い方から実務で役立つパラメータ設定まで詳しく解説しました。
重要なポイントを最後にもう一度まとめます。
- 目的: 過学習を防ぎ、モデルの未知のデータに対する予測性能(汎化性能)を正しく評価するためにデータを分割する。
- 戻り値: 必ず
X_train, X_test, y_train, y_testの順番で受け取る。 - 再現性:
random_stateを設定して、結果を固定する。 - 分類問題の鉄則: データの偏りを防ぐために
stratify=yを必ず指定する。
train_test_splitは、PythonでAIや機械学習のプログラムを書く際、まるで挨拶のように毎回登場する必須の関数です。使い方を暗記するくらい、何度も手を動かしてコードを書いてみてください。
機械学習のデータ前処理やモデル評価についてさらに深く学びたい方は、ぜひScikit-learnの公式ドキュメントも合わせて読んでみることをおすすめします。


コメント