Archived Documentation

Welcome to the developer documentation for SigOpt. If you have a question you can’t answer, feel free to contact us!
You are currently viewing archived SigOpt documentation. The newest documentation can be found here.
This feature is currently in alpha. Please contact us if you would like more information.

Java Training Monitor Example

This example creates a Training Monitor experiment in Java and uses a convergence stopping criterion to end each training run early once its metric stops improving.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Observable;
import java.util.Random;
import com.sigopt.Sigopt;
import com.sigopt.exception.APIException;
import com.sigopt.exception.SigoptException;
import com.sigopt.model.*;

/**
 * Example program: creates a SigOpt Training Monitor experiment and runs its observation
 * budget. Each training run reports checkpoints until it reaches the experiment's
 * maxCheckpoints or the API flags the run to stop via its convergence criterion.
 *
 * <p>Usage: {@code java App --client_token <TOKEN>}
 */
public class App
{
  // Fixed seed so the simulated noise added by evaluateModel is reproducible across runs.
  static Random r = new Random(1);

  /**
   * Creates the experiment, then runs its observation budget. Each iteration creates one
   * suggestion, runs one training run, and reports one observation.
   *
   * @param args command-line arguments; {@code --client_token <TOKEN>} sets the API token
   * @throws APIException propagated from SigOpt client calls
   * @throws InterruptedException as declared by the original signature
   * @throws SigoptException propagated from SigOpt client calls
   */
  public static void main(String[] args) throws APIException, InterruptedException, SigoptException
  {
    parseArgs(args);

    Experiment experiment = createExperiment();
    System.out.println("Created experiment: https://app.sigopt.com/experiment/" + experiment.getId());

    // Read both limits once from the created experiment instead of on every iteration.
    int observationBudget = experiment.getObservationBudget();
    int maxCheckpoints = experiment.getTrainingMonitor().getMaxCheckpoints();

    // Main loop: one suggestion -> one training run -> one observation.
    for (int i = 0; i < observationBudget; i++) {
      Suggestion suggestion = experiment.suggestions().create().call();
      TrainingRun trainingRun = runTraining(experiment, suggestion, maxCheckpoints);

      // Report the finished training run as an observation.
      experiment.observations().create().data(
        new Observation.Builder().trainingRun(trainingRun.getId()).build()
      )
      .call();
    }
  }

  // Applies command-line options. Only "--client_token <TOKEN>" is recognized. The token value
  // is consumed so it is never re-examined as a flag.
  private static void parseArgs(String[] args)
  {
    for (int i = 0; i < args.length - 1; i++) {
      if ("--client_token".equals(args[i])) {
        Sigopt.clientToken = args[++i];
      }
    }
  }

  // Creates the Training Monitor experiment with:
  //   - two double parameters in [0, 1]
  //   - a single "accuracy" metric
  //   - at most 10 checkpoints per run
  //   - a convergence stopping criterion
  //   - an observation budget of 20
  private static Experiment createExperiment() throws APIException, InterruptedException, SigoptException
  {
    return Experiment.create(
      new Experiment.Builder()
        .name("Java Sigopt Training Monitor Example")
        .parameters(Arrays.asList(
          new Parameter.Builder()
            .name("x")
            .type("double")
            .bounds(new Bounds(0.0, 1.0))
            .build(),
          new Parameter.Builder()
            .name("y")
            .type("double")
            .bounds(new Bounds(0.0, 1.0))
            .build()
        ))
        .metrics(Arrays.asList(
          new Metric.Builder()
            .name("accuracy")
            .build()
        ))
        .trainingMonitor(
          new TrainingMonitor.Builder()
            .maxCheckpoints(10)
            .earlyStoppingCriteria(Arrays.asList(
              new TrainingMonitorStoppingCriteria.Builder()
                .type("convergence")
                .name("convergence criteria")
                .metric("accuracy")
                .lookbackCheckpoints(1)
                .minCheckpoints(3)
                .build()
            ))
            .build()
        )
        .observationBudget(20)
        .build())
      .call();
  }

  // Opens a training run for the suggestion and reports up to maxCheckpoints "accuracy"
  // checkpoints. It stops as soon as a checkpoint response says the run should stop, then
  // returns the training run.
  private static TrainingRun runTraining(Experiment experiment, Suggestion suggestion, int maxCheckpoints)
      throws APIException, InterruptedException, SigoptException
  {
    TrainingRun trainingRun = experiment.trainingRuns().create().data(new TrainingRun.Builder()
      .suggestion(suggestion.getId())
      .build()
    )
    .call();

    for (int j = 0; j < maxCheckpoints; j++) {
      // Checkpoint numbers passed to evaluateModel are 1-based because it takes their log.
      double value = evaluateModel(suggestion.getAssignments(), j + 1);
      Checkpoint checkpoint = trainingRun.checkpoints().create().data(new Checkpoint.Builder()
        .values(Arrays.asList(
          new MetricEvaluation.Builder()
            .name("accuracy")
            .value(value)
            .build()
        ))
        .build()
      )
      .call();

      // Early stopping based on the convergence criterion. A missing flag is treated as
      // "keep going" rather than unboxing null.
      if (Boolean.TRUE.equals(checkpoint.getShouldStop())) {
        break;
      }
    }
    return trainingRun;
  }

  /**
   * Produces a fake, quasi-realistic accuracy value for one checkpoint of a training run.
   *
   * <p>The noiseless value is {@code 0.1 - 0.1*x + 0.2*y + log(checkpoint)/log(100)}. It grows
   * logarithmically with the checkpoint number, so later checkpoints improve less. That value
   * is multiplied by {@code 1 + 0.05*g}, where {@code g} is a Gaussian draw from {@link #r}.
   *
   * @param assignments parameter assignments; read via {@code getDouble("x")} and
   *     {@code getDouble("y")}
   * @param checkpoint 1-based checkpoint number; values below 1 produce a non-finite result
   * @return the noisy synthetic accuracy
   */
  public static double evaluateModel(Assignments assignments, int checkpoint){
    double x = assignments.getDouble("x");
    double y = assignments.getDouble("y");
    return (0.1 - 0.1 * x + 0.2 * y + Math.log(checkpoint) / Math.log(100)) * (1 + 0.05 * r.nextGaussian());
  }
}