Testing from MLflow checkpointed runs

src/bnode_core/ode/trainer_utils/test_from_mlflow.py

bnode_core.ode.trainer_utils.test_from_mlflow

Test trained models from MLflow runs on new datasets.

This module provides functionality to load trained neural ODE models from MLflow runs and test them on different datasets. It downloads artifacts directly from the MLflow server (no local file access required) and executes validation runs.

You can use all configuration options from trainer.py to override parameters for testing.

Typical Usage Example

Test a single run on a new dataset using experiment_id (recommended):

    # Select the run via experiment_id + run_name (recommended), or pass
    # run_id=8c2c32b9407a4e20946f72cd1c714776 instead. Config overrides are
    # appended as 'override <key>=<value>' pairs.
    python test_from_mlflow.py \
        experiment_id=123456789 \
        run_name=bemused-hen-59 \
        dataset_name=myTestData \
        mlflow_experiment_name=validation_results \
        mlflow_tracking_uri=http://localhost:5000 \
        n_processes=1 \
        nn_model_base=latent_ode_base \
        override nn_model.training.batch_size_test=128 \
        override nn_model.training.test_save_internal_variables=false \
        override n_workers_train_loader=1 \
        override n_workers_other_loaders=1 \
        override use_cuda=true

Command Line Arguments

Required:

dataset_name : str or list[str]
    Name(s) of dataset(s) to test on. Comma-separated for multiple datasets.
mlflow_experiment_name : str
    Name for the new MLflow experiment where test results will be logged.
nn_model_base : str
    Base configuration for the neural network model (e.g., 'latent_ode_base').

Run Selection (one of):

run_id : str or list[str]
    MLflow run ID(s) to test. Comma-separated for multiple runs.
experiment_id + run_name : str
    Experiment ID and specific run name(s) within that experiment (recommended).
experiment_id : str
    Experiment ID to test all runs from that experiment.
experiment : str (deprecated)
    Experiment name - triggers warning as multiple experiments can share names.

Optional:

mlflow_tracking_uri : str
    URI of MLflow tracking server. Defaults to 'http://localhost:5000'.
n_processes : int
    Number of parallel processes for testing. Defaults to 1 (sequential).
override <key>=<value> : 
    Override specific config parameters (e.g., 'override use_cuda=false').

Notes

  • Artifacts are downloaded from MLflow server to Hydra output directory.
  • Downloaded artifacts are stored in {hydra_output}/mlflow_test_artifacts/run_{run_id}/.
  • Artifacts persist after testing for inspection and debugging.
  • Model checkpoints are automatically retrieved from MLflow artifact storage.
  • Results are logged to a new MLflow experiment specified by mlflow_experiment_name.
  • When testing multiple runs/datasets, a Cartesian product is created (all combinations); see the sketch after this list.
  • Use experiment_id instead of experiment name to avoid ambiguity.
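
For intuition on that Cartesian product, here is a minimal sketch with hypothetical run IDs and dataset names (main() below builds the same combinations with nested loops):

    from itertools import product

    run_ids = ["abc123", "def456"]                # hypothetical MLflow run IDs
    dataset_names = ["myTestData", "otherData"]   # hypothetical dataset names

    # One test job per (run, dataset) pair: 2 runs x 2 datasets = 4 jobs.
    for i, (run_id, dataset) in enumerate(product(run_ids, dataset_names)):
        print(f"job {i}: run={run_id}, dataset={dataset}")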

See Also

bnode_core.ode.trainer : Main training module with train_all_phases function.
mlflow : MLflow documentation for run and artifact management.

parse_overrides(override_list)

Parse list of override strings into a dict of key-value pairs. Each override should be in the form key=value.

Source code in src/bnode_core/ode/trainer_utils/test_from_mlflow.py
def parse_overrides(override_list):
    """
    Parse list of override strings into a dict of key-value pairs.
    Each override should be in the form key=value.
    """
    overrides = {}
    for item in override_list:
        if '=' not in item:
            raise ValueError(f"Override argument '{item}' is not in key=value format.")
        key, value = item.split('=', 1)
        overrides[key.strip()] = value.strip()
    return overrides
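
A quick illustrative call (keys chosen arbitrarily); note that values are kept as plain strings at this stage:

>>> parse_overrides(["use_cuda=false", "nn_model.training.batch_size_test=128"])
{'use_cuda': 'false', 'nn_model.training.batch_size_test': '128'}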

get_run_ids(args)

Retrieve MLflow run IDs based on provided selection criteria.

Resolves run IDs from either direct run_id specification, run_name lookup within an experiment, or all runs from an experiment.

Parameters:

args : argparse.Namespace (required)
    Parsed command line arguments containing one of:

      • 'run_id': Direct list of run IDs.
      • 'experiment_id' + 'run_name': Experiment ID and specific run names.
      • 'experiment_id': All runs from the experiment.
      • 'experiment': Experiment name (deprecated, triggers warning).

    Optional:

      • 'mlflow_tracking_uri': MLflow server URI.

Returns:

list[str]
    List of MLflow run IDs to test.

Raises:

ValueError
    If incompatible argument combinations are provided (e.g., both run_name and run_id, or both run_id and experiment_id).

Side Effects
  • Sets MLflow tracking URI via mlflow.set_tracking_uri().
  • Prints progress messages about run ID retrieval.
  • Issues warning if 'experiment' (name) is used instead of 'experiment_id'.

Examples:

>>> args = argparse.Namespace(run_id=None, experiment=None, experiment_id=['123456'],
...                           run_name=['run1', 'run2'], mlflow_tracking_uri='http://localhost:5000')
>>> get_run_ids(args)
['abc123', 'def456']
Source code in src/bnode_core/ode/trainer_utils/test_from_mlflow.py
def get_run_ids(args):
    """Retrieve MLflow run IDs based on provided selection criteria.

    Resolves run IDs from either direct run_id specification, run_name lookup
    within an experiment, or all runs from an experiment.

    Args:
        args (argparse.Namespace): Parsed command line arguments containing one of:

            - 'run_id': Direct list of run IDs.
            - 'experiment_id' + 'run_name': Experiment ID and specific run names.
            - 'experiment_id': All runs from the experiment.
            - 'experiment': Experiment name (deprecated, triggers warning).

            Optional:

            - 'mlflow_tracking_uri': MLflow server URI.

    Returns:
        list[str]: List of MLflow run IDs to test.

    Raises:
        ValueError: If incompatible argument combinations are provided
            (e.g., both run_name and run_id, or both run_id and experiment_id).

    Side Effects:
        - Sets MLflow tracking URI via mlflow.set_tracking_uri().
        - Prints progress messages about run ID retrieval.
        - Issues warning if 'experiment' (name) is used instead of 'experiment_id'.

    Examples:
        >>> args = argparse.Namespace(run_id=None, experiment=None, experiment_id=['123456'],
        ...                           run_name=['run1', 'run2'], mlflow_tracking_uri='http://localhost:5000')
        >>> get_run_ids(args)
        ['abc123', 'def456']
    """
    mlflow.set_tracking_uri(args.mlflow_tracking_uri)
    if args.run_id:
        return args.run_id
    if args.experiment:
        warnings.warn(
            "Using 'experiment' (name) is deprecated. Multiple experiments can have the same name. "
            "Please use 'experiment_id' instead for unambiguous experiment identification.",
            DeprecationWarning,
            stacklevel=2
        )
        experiment = mlflow.get_experiment_by_name(args.experiment)
        if experiment is None:
            raise ValueError(f"Experiment with name '{args.experiment}' not found.")
        # TODO: What happens if there are multiple experiments with same name?
        experiment_id = experiment.experiment_id
    else:
        experiment_id = args.experiment_id
    runs = mlflow.search_runs(experiment_id)
    if args.run_name:
        run_ids = []
        for run_name in args.run_name:
            matching_runs = runs[runs["tags.mlflow.runName"] == run_name]
            if len(matching_runs) == 0:
                raise ValueError(f"No run found with name '{run_name}' in experiment {experiment_id}")
            run_ids.append(matching_runs["run_id"].values[0])
        return run_ids
    return runs["run_id"].to_list()
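
As a complement to the docstring example, selecting runs directly by run_id bypasses the MLflow run search entirely; a minimal sketch with placeholder values (the run ID and tracking URI are assumptions for illustration):

    import argparse

    from bnode_core.ode.trainer_utils.test_from_mlflow import get_run_ids

    args = argparse.Namespace(
        run_id=["8c2c32b9407a4e20946f72cd1c714776"],  # placeholder run ID
        experiment=None,
        experiment_id=None,
        run_name=None,
        mlflow_tracking_uri="http://localhost:5000",
    )
    run_ids = get_run_ids(args)  # returns args.run_id unchanged after setting the tracking URI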

main()

Main execution function for testing models from MLflow runs.

Orchestrates the complete workflow:

  1. Parses and validates command line arguments.
  2. Retrieves run IDs from MLflow using experiment_id (or deprecated experiment name).
  3. For each run-dataset combination:
       • Downloads artifacts from MLflow server to Hydra output directory.
       • Loads the training configuration from downloaded artifacts.
       • Updates config for testing (new dataset, model path, test mode).
       • Applies any override parameters (see the key-path sketch before the source code below).
       • Saves modified config to temporary files in Hydra output.
  4. Executes testing jobs either sequentially or in parallel.

The function creates a Cartesian product of runs × datasets, generating one test job for each combination.

Raises:

ValueError
    If command line arguments are invalid or incompatible.
FileNotFoundError
    If MLflow artifacts (config, model) cannot be found.

Side Effects
  • Creates mlflow_test_artifacts directory in Hydra output for configs and artifacts.
  • Downloads artifacts from MLflow server (if not local).
  • Artifacts persist after testing for inspection.
  • Launches subprocess calls to trainer.py for each test job.
  • Logs results to MLflow under the specified experiment name.
  • Prints progress messages throughout execution.
Notes
  • Model checkpoints are retrieved from the final training phase.
  • Sequence length is set to match the last training phase.
  • Original training dataset name is preserved for reference.
  • Artifacts are organized as: {hydra_output}/mlflow_test_artifacts/run_{run_id}/
  • Use experiment_id instead of experiment name to avoid ambiguity warnings.

Examples:

Command line usage:

python test_from_mlflow.py \
    experiment_id=123456789 \
    run_name=final-model-123 \
    dataset_name=validation_set \
    mlflow_experiment_name=validation_results \
    nn_model_base=latent_ode_base \
    n_processes=1
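
The override handling inside main() walks each dotted key down into the nested config dictionary before the YAML is re-written. A standalone sketch of that key-path logic, equivalent to the loop in the source below (the config contents here are hypothetical):

    config = {"use_cuda": True, "nn_model": {"training": {"batch_size_test": 64}}}
    overrides = {"nn_model.training.batch_size_test": "128", "use_cuda": "false"}

    for key, value in overrides.items():
        path = key.split(".")
        ref = config
        for i, part in enumerate(path):
            if i == len(path) - 1:
                ref[part] = value               # leaf: overwrite (value stays a string)
            else:
                ref = ref.setdefault(part, {})  # descend, creating missing levels

    # config["nn_model"]["training"]["batch_size_test"] is now the string "128"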
Source code in src/bnode_core/ode/trainer_utils/test_from_mlflow.py
def main():
    """Main execution function for testing models from MLflow runs.

    Orchestrates the complete workflow:
    1. Parses and validates command line arguments.
    2. Retrieves run IDs from MLflow using experiment_id (or deprecated experiment name).
    3. For each run-dataset combination:
       - Downloads artifacts from MLflow server to Hydra output directory.
       - Loads the training configuration from downloaded artifacts.
       - Updates config for testing (new dataset, model path, test mode).
       - Applies any override parameters.
       - Saves modified config to temporary files in Hydra output.
    4. Executes testing jobs either sequentially or in parallel.

    The function creates a Cartesian product of runs × datasets, generating
    one test job for each combination.

    Raises:
        ValueError: If command line arguments are invalid or incompatible.
        FileNotFoundError: If MLflow artifacts (config, model) cannot be found.

    Side Effects:
        - Creates mlflow_test_artifacts directory in Hydra output for configs and artifacts.
        - Downloads artifacts from MLflow server (if not local).
        - Artifacts persist after testing for inspection.
        - Launches subprocess calls to trainer.py for each test job.
        - Logs results to MLflow under the specified experiment name.
        - Prints progress messages throughout execution.

    Notes:
        - Model checkpoints are retrieved from the final training phase.
        - Sequence length is set to match the last training phase.
        - Original training dataset name is preserved for reference.
        - Artifacts are organized as: {hydra_output}/mlflow_test_artifacts/run_{run_id}/
        - Use experiment_id instead of experiment name to avoid ambiguity warnings.

    Examples:
        Command line usage::

            python test_from_mlflow.py \\
                experiment_id=123456789 \\
                run_name=final-model-123 \\
                dataset_name=validation_set \\
                mlflow_experiment_name=validation_results \\
                nn_model_base=latent_ode_base \\
                n_processes=1
    """
    args = parse_args()
    overrides = parse_overrides(args.override)
    run_ids = get_run_ids(args)
    dataset_names = args.dataset_name

    # get temporary directory in Hydra output folder for saving config files and artifacts
    now = datetime.now()
    date_str = now.strftime("%Y-%m-%d")
    time_str = now.strftime("%H-%M-%S")
    hydra_output_dir = Path.cwd() / "outputs" / date_str / time_str
    hydra_output_dir.mkdir(parents=True, exist_ok=True)
    temp_dir_path = hydra_output_dir / "mlflow_test_artifacts"
    temp_dir_path.mkdir(parents=True, exist_ok=True)
    print(f"using directory for artifacts and config files: {temp_dir_path}")
    # create nn_model directory
    (temp_dir_path / "nn_model").mkdir(parents=True, exist_ok=True)

    jobs = 0
    training_dataset_list = []
    artifact_dirs = {}  # Store artifact directories per run_id

    for run_id in run_ids:
        for dataset in dataset_names:
            # Download artifacts from MLflow server
            mlflow_run = mlflow.get_run(run_id)
            artifact_uri = mlflow_run.info.artifact_uri

            # Check if we need to download artifacts (not already local)
            if run_id not in artifact_dirs:
                if not artifact_uri.startswith('file://'):
                    # Download artifacts from remote MLflow server
                    run_artifact_dir = temp_dir_path / f"run_{run_id}"
                    run_artifact_dir.mkdir(parents=True, exist_ok=True)
                    print(f"Downloading artifacts for run {run_id} from MLflow server to {run_artifact_dir}")

                    mlflow.artifacts.download_artifacts(
                        run_id=run_id,
                        dst_path=str(run_artifact_dir)
                    )
                    artifact_dirs[run_id] = run_artifact_dir
                    print(f"Successfully downloaded artifacts to {run_artifact_dir}")
                else:
                    # Local artifacts - use direct path
                    artifact_dirs[run_id] = Path(artifact_uri.replace('file://',''))
                    print(f"Using local artifacts from {artifact_dirs[run_id]}")

            # Get config path from downloaded/local artifacts
            _config_path = artifact_dirs[run_id] / ".hydra" / "config_validated.yaml"

            if not _config_path.exists():
                raise FileNotFoundError(f"Config file not found at {_config_path}")

            # copy config to temporary directory
            temp_config = temp_dir_path / f"config_{jobs}.yaml"
            print(f"copying config from {_config_path} to {temp_config}")
            shutil.copy(_config_path, temp_config)

            # update config with dataset_name, sequence_length, model_path, test_mode
            # open yaml file
            with open(temp_config) as file:
                config = yaml.load(file, Loader=yaml.FullLoader)

            # update config
            training_dataset_list.append(config["dataset_name"])
            config["dataset_name"] = dataset
            config["mlflow_experiment_name"] = args.mlflow_experiment_name
            # TODO: is this important
            config["nn_model"]["training"]["pre_trained_model_seq_len"] = config["nn_model"]["training"]["main_training"][-1]["seq_len_train"]

            # Use local path to downloaded model checkpoint
            # TODO: could also look for latest checkpoint in dir
            _model_filename = f"model_phase_{len(config['nn_model']['training']['main_training'])}.pt"
            _model_path = artifact_dirs[run_id] / _model_filename

            if not _model_path.exists():
                raise FileNotFoundError(f"Model checkpoint not found at {_model_path}")

            config["nn_model"]["training"]["path_trained_model"] = str(_model_path)
            print(f"\tmodel path: {_model_path}")
            config["nn_model"]["training"]["load_trained_model_for_test"] = True

            # set overrides (propagate to config using key-path logic)
            for key, value in overrides.items():
                print(f"setting override: {key}={value}")
                path = key.split('.')
                ref = config
                for i, part in enumerate(path):
                    if i == len(path) - 1:
                        ref[part] = value
                    else:
                        if part not in ref:
                            ref[part] = {}
                        ref = ref[part]

            config_nn_model = config["nn_model"]
            # delete nn_model from config
            del config["nn_model"]

            # add defaults to the very top
            config["defaults"]= ["base_train_test", {"nn_model": f"model{jobs}"}, "_self_"]
            # also to nn_model
            config_nn_model["defaults"] = [args.nn_model_base, "_self_"]

            # save updated config to temporary directory
            with open(temp_config, 'w') as file:
                yaml.dump(config, file)

            # save nn_model to temporary directory
            with open(temp_dir_path / "nn_model" / f"model{jobs}.yaml", 'w') as file:
                yaml.dump(config_nn_model, file)
            jobs += 1
    print(f"Successfully created {jobs} jobs.")
    print(f"Artifacts and configs saved in: {temp_dir_path}")

    print("starting jobs...")
    # remove all arguments from sys.argv

    def wrap_train_all_phases(temp_dir, temp_config_name, training_dataset):
        """Execute trainer.py in a subprocess with specified configuration.

        Args:
            temp_dir (str): Path to temporary directory containing config files.
            temp_config_name (str): Name of the config file to use.
            training_dataset (str): Name of the original training dataset (for logging).

        Returns:
            subprocess.CompletedProcess: Result of the subprocess execution.
        """
        # Split the command into executable and arguments instead of one string with spaces.
        cmd = [
            "uv",
            "run",
            "trainer",
            f"-cp={temp_dir}",
            f"-cn={temp_config_name}",
            f"+nn_model.training.training_dataset_name={training_dataset}"
        ]
        print(f"Running command: {' '.join(cmd)}")
        result = subprocess.run(cmd)
        return result

    if args.n_processes == 1:
        print("running jobs sequentially")
        for i in range(jobs):
            result = wrap_train_all_phases(str(temp_dir_path), f"config_{i}.yaml", training_dataset_list[i])
            print(result)
    else:
        print("running jobs in parallel")
        warnings.warn("Parallel execution is not fully tested yet.")
        with Pool(processes=args.n_processes) as pool:
            results = [pool.apply_async(wrap_train_all_phases, (str(temp_dir_path), f"config_{i}.yaml", training_dataset_list[i])) for i in range(jobs)]
            for result in results:
                result.get()

    print(f"All jobs completed.")