Page cover

Self-Training

Eine Technik im halbüberwachten Lernen ist das sogenannte Self-Training. Hierbei wird das Modell zuerst mit gelabelten Daten trainiert und dann auf ungelabelte Daten Vorhersagen getroffen. Diese Vorhersagen werden als zusätzliche "Pseudo gelabelte Daten" betrachtet, werden zu den Trainingsdaten hinzugefügt, und der Trainingsprozess wird iterativ wiederholt. Hier ist der Trainingsprozess nochmals schrittweise dargestellt:

  1. Initialisierung: Training eines Basismodells auf einem kleinen, gelabelten, Datensatz.

  2. Vorhersage: Verwendung des Basismodells, um Vorhersagen für ungelabelte Daten zu machen.

  3. Auswahl: Auswahl der ungelabelten Datenpunkte mit der höchsten Vorhersagekonfidenz.

  4. Labelzuweisung: Zuweisung des Labels (Pseudolabel) zu diesen ausgewählten Datenpunkten basierend auf den Vorhersagen des Modells.

  5. Aktualisierung: Erweiterung des ursprünglichen Trainingsdatensatzes um die neu gelabelten Daten.

  6. Wiederholung: Erneutes Training des Modells mit dem erweiterten Datensatz.

  7. Iteration: Wiederholung der Schritte 2 bis 6 für eine festgelegte Anzahl von Iterationen oder bis keine signifikante Verbesserung der Modellleistung mehr erreicht wird.

Halbüberwachtes Lernen ermöglicht es, die Vorteile sowohl von gelabelten als auch von ungelabelten Daten zu nutzen, was insbesondere in Szenarien mit begrenzten gelabelten Daten von großem Nutzen ist. Im Weiteren werden die Vor- und Nachteile des Self-Trainings erläutert:

Vorteile:

  • Effiziente Nutzung von ungelabelten Daten: Self-Training ermöglicht es, Informationen aus ungelabelten Daten zu nutzen, um die Leistung des Modells zu verbessern, ohne dass zusätzliche manuelle Annotationen erforderlich sind.

  • Erweiterung des Trainingsdatensatzes: Durch die Verwendung von Pseudolabels werden die Trainingsdaten sukzessive erweitert, was zu einer besseren Generalisierungsfähigkeit des Modells führen kann.

  • Anpassung an sich ändernde Datenverteilungen: Self-Training kann dazu beitragen, dass das Modell sich an sich ändernde Datenverteilungen anpasst, da es kontinuierlich mit neuen Daten trainiert wird.

Nachteile:

  • Risiko von Fehlerakkumulation: Da die Pseudolabels auf den Vorhersagen des Modells basieren, besteht das Risiko, dass Fehler in den Vorhersagen des Modells akkumulieren und die Qualität der erweiterten Trainingsdaten beeinträchtigen.

  • Sensitivität gegenüber Unsicherheit der Vorhersagen: Self-Training ist empfindlich gegenüber der Unsicherheit der Vorhersagen des Modells. Bei ungenauen bzw. falschen Vorhersagen nimmt die Qualität der Pseudolabels ab, was sich negativ auf die Leistung des Modells auswirkt.

In diesem Beispiel wird das Self-Training iterativ durchgeführt. Dabei wird ein Klassifikationsmodell zunächst mit gelabelten Daten trainiert. Anschließend werden Vorhersagen auf ungelabelten Daten getroffen und zuverlässige Vorhersagen für das nächste Training verwendet. Dadurch wird die Leistung des Modells schrittweise verbessert.

# Aufteilen der Daten in gelabelte und ungelabelte
X_labeled, X_unlabeled, y_labeled, _ = train_test_split(X, y)

# Definition eines einfachen Klassifikationsmodells
model = tf.keras.Sequential([...])

# Kompilieren des Modells
model.compile(...)

# Durchführung von 5 Iteratrion des Self-Trainings mit ungelabelten Daten
for self_iter in range(5):
    # Training mit gelabelten Daten
    model.fit(X_labeled, y_labeled)
    # Vorhersagen auf ungelabelten Daten
    pseudo_labels = model.predict(X_unlabeled)
    # Auswahl der zuverlässigsten Vorhersagen
    confident_predictions = np.max(pseudo_labels, axis=1) > 0.8
    # Hinzufügen der zuverlässigen Vorhersagen zu den gelabelten Daten
    X_labeled = np.concatenate([X_labeled, X_unlabeled[confident_predictions]])
    y_labeled = np.concatenate([y_labeled, np.argmax(pseudo_labels[confident_predictions], axis=1)])

    # Entfernen von Daten aus X_unlabeled, die zu X_labeled hinzugefügt werden
    X_unlabeled = X_unlabeled[~confident_predictions]

    # Evaluierung auf Testdaten
    test_accuracy = accuracy_score(y_test, np.argmax(model.predict(X_test), axis=1))

Mehr erfahren: Medium/A Gentle Introduction to Self-Training and Semi-Supervised Learning, scikit-learn/Semi-supervised learning

Last updated