
ML Model

The ML Model entity represents trained machine learning models across various ML platforms and frameworks. ML Models can be trained using different algorithms and frameworks (TensorFlow, PyTorch, Scikit-learn, etc.) and deployed to various platforms (MLflow, SageMaker, Vertex AI, etc.).

Identity

ML Models are identified by three pieces of information:

  • The platform where the model is registered or deployed: the specific ML platform that hosts or manages the model. Examples are mlflow, sagemaker, vertexai, and databricks. See the dataPlatform entity for more details.
  • The name of the model: this is the unique identifier for the model within the platform. The naming convention varies by platform:
    • MLflow: typically uses the registered model name (e.g., recommendation-model)
    • SageMaker: uses the model name or model package group name (e.g., product-recommendation-v1)
    • Vertex AI: uses the model resource name (e.g., projects/123/locations/us-central1/models/456)
  • The environment or origin where the model was trained: this is similar to the fabric concept for datasets, allowing you to distinguish between models in different environments (PROD, DEV, QA, etc.). The full list of supported environments is available in FabricType.pdl.

An example of an ML Model identifier is urn:li:mlModel:(urn:li:dataPlatform:mlflow,my-recommendation-model,PROD).
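
In the Python SDK, these identifiers are represented by the MlModelUrn class. The sketch below parses the example URN above and builds the same identifier from its parts; it assumes the environment defaults to PROD when not supplied.

Python SDK: Work with ML Model URNs
from datahub.metadata.urns import MlModelUrn

# Parse the example identifier shown above
urn = MlModelUrn.from_string(
    "urn:li:mlModel:(urn:li:dataPlatform:mlflow,my-recommendation-model,PROD)"
)

# Build the same identifier from its parts (environment assumed to default to PROD)
same_urn = MlModelUrn(platform="mlflow", name="my-recommendation-model")

print(urn)
print(same_urn)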

Important Capabilities

Basic Model Information

The core information about an ML Model is captured in the mlModelProperties aspect. This includes:

  • Name and Description: Human-readable name and description of what the model does
  • Model Type: The algorithm or architecture used (e.g., "Convolutional Neural Network", "Random Forest", "BERT")
  • Version: Version information using the versionProperties aspect
  • Timestamps: Created and last modified timestamps
  • Custom Properties: Flexible key-value pairs for platform-specific metadata (e.g., framework version, model format)

The following code snippet shows you how to create a basic ML Model:

Python SDK: Create an ML Model
from datahub.metadata.urns import MlModelGroupUrn
from datahub.sdk import DataHubClient
from datahub.sdk.mlmodel import MLModel

client = DataHubClient.from_env()

mlmodel = MLModel(
    id="customer-churn-predictor",
    name="Customer Churn Prediction Model",
    platform="mlflow",
    description="A gradient boosting model that predicts customer churn based on usage patterns and engagement metrics",
    custom_properties={
        "framework": "xgboost",
        "framework_version": "1.7.0",
        "model_format": "pickle",
    },
    model_group=MlModelGroupUrn(platform="mlflow", name="customer-churn-models"),
)

client.entities.upsert(mlmodel)

Hyperparameters and Metrics

ML Models can capture both the hyperparameters used during training and various metrics from training and production:

  • Hyperparameters: Configuration values that control the training process (learning rate, batch size, number of epochs, etc.)
  • Training Metrics: Performance metrics from the training process (accuracy, loss, F1 score, etc.)
  • Online Metrics: Performance metrics from production deployment (latency, throughput, drift, etc.)

These are stored in the mlModelProperties aspect as structured lists of parameters and metrics.

Python SDK: Add hyperparameters and metrics to an ML Model
from datahub.metadata.urns import CorpUserUrn, DomainUrn, MlModelUrn, TagUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()

mlmodel = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

mlmodel.set_hyper_params(
    {
        "learning_rate": "0.1",
        "max_depth": "6",
        "n_estimators": "100",
        "subsample": "0.8",
        "colsample_bytree": "0.8",
    }
)

mlmodel.set_training_metrics(
    {
        "accuracy": "0.87",
        "precision": "0.84",
        "recall": "0.82",
        "f1_score": "0.83",
        "auc_roc": "0.91",
    }
)

mlmodel.add_owner(CorpUserUrn("data_science_team"))

mlmodel.add_tag(TagUrn("production"))
mlmodel.add_tag(TagUrn("classification"))

mlmodel.set_domain(DomainUrn("urn:li:domain:customer-analytics"))

client.entities.update(mlmodel)

Intended Use and Ethical Considerations

DataHub supports comprehensive model documentation following ML model card best practices. These aspects help stakeholders understand the appropriate use cases and ethical implications of using the model:

  • Intended Use (intendedUse aspect): Documents primary use cases, intended users, and out-of-scope applications
  • Ethical Considerations (mlModelEthicalConsiderations aspect): Documents use of sensitive data, risks and harms, mitigation strategies
  • Caveats and Recommendations (mlModelCaveatsAndRecommendations aspect): Additional considerations, ideal dataset characteristics, and usage recommendations

These aspects align with responsible AI practices and help ensure models are used appropriately.

Training and Evaluation Data

ML Models can document their training and evaluation datasets in two complementary ways:

Direct Dataset References

  • Training Data (mlModelTrainingData aspect): Datasets used to train the model, including preprocessing information and motivation for dataset selection
  • Evaluation Data (mlModelEvaluationData aspect): Datasets used for model evaluation and testing

Each dataset reference includes the dataset URN, motivation for using that dataset, and any preprocessing steps applied. This creates direct lineage relationships between models and their training data.

Lineage via Training Runs

Training runs (dataProcessInstance entities) provide an alternative and often more detailed way to capture training lineage:

  • Training runs declare their input datasets via dataProcessInstanceInput aspect
  • Training runs declare their output datasets via dataProcessInstanceOutput aspect
  • Models reference training runs via the trainingJobs field

This creates indirect lineage: Dataset → Training Run → Model

When to use each approach:

  • Use direct dataset references for simple documentation of what data was used
  • Use training runs for complete lineage tracking including:
    • Multiple training/validation/test datasets
    • Metrics and hyperparameters from the training process
    • Temporal tracking (when the training occurred)
    • Connection to experiments for comparing multiple training attempts

Most production ML systems should use training runs for comprehensive lineage tracking.
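
As a minimal sketch of the training-run approach, the snippet below links an existing model to the run that produced it, using the add_training_job call shown in the larger example later on this page; the model and run identifiers are hypothetical.

Python SDK: Link a model to a training run
from datahub.metadata.urns import DataProcessInstanceUrn, MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()

# Hypothetical model and training run identifiers
mlmodel = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

# Populates the trainingJobs field in mlModelProperties
mlmodel.add_training_job(DataProcessInstanceUrn("churn_training_run_2024_06"))

client.entities.update(mlmodel)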

Factor Prompts and Quantitative Analysis

For detailed model analysis and performance reporting:

  • Factor Prompts (mlModelFactorPrompts aspect): Factors that may affect model performance (demographic groups, environmental conditions, etc.)
  • Quantitative Analyses (mlModelQuantitativeAnalyses aspect): Links to dashboards or reports showing disaggregated performance metrics across different factors
  • Metrics (mlModelMetrics aspect): Detailed metrics with descriptions beyond simple training/online metrics

Source Code and Cost

  • Source Code (sourceCode aspect): Links to model training code, notebooks, or repositories (GitHub, GitLab, etc.)
  • Cost (cost aspect): Cost attribution information for tracking model training and inference expenses

Training Runs and Experiments

ML Models in DataHub can be linked to their training runs and experiments, providing complete lineage from raw data through training to deployed models.

Training Runs

Training runs represent specific executions of model training jobs. In DataHub, training runs are modeled as dataProcessInstance entities with a specialized subtype:

  • Entity Type: dataProcessInstance
  • Subtype: MLAssetSubTypes.MLFLOW_TRAINING_RUN
  • Key Aspects:
    • dataProcessInstanceProperties: Basic properties like name, timestamps, and custom properties
    • mlTrainingRunProperties: ML-specific properties including:
      • Training metrics (accuracy, loss, F1 score, etc.)
      • Hyperparameters (learning rate, batch size, epochs, etc.)
      • Output URLs (model artifacts, checkpoints)
      • External URLs (links to training dashboards)
    • dataProcessInstanceInput: Input datasets used for training
    • dataProcessInstanceOutput: Output datasets (predictions, feature importance, etc.)
    • dataProcessInstanceRunEvent: Start, completion, and failure events

Training runs create lineage relationships showing:

  • Upstream: Which datasets were used for training
  • Downstream: Which models were produced by the training run

Models reference their training runs through the trainingJobs field in mlModelProperties, and model groups can also reference training runs to track all training activity for a model family.

Experiments

Experiments organize related training runs into logical groups, typically representing a series of attempts to optimize a model or compare different approaches. In DataHub, experiments are modeled as container entities:

  • Entity Type: container
  • Subtype: MLAssetSubTypes.MLFLOW_EXPERIMENT
  • Purpose: Group related training runs for organization and comparison

Training runs belong to experiments through the container aspect, creating a hierarchy:

Experiment: "Customer Churn Prediction"
├── Training Run 1: baseline model
├── Training Run 2: with feature engineering
├── Training Run 3: hyperparameter tuning
└── Training Run 4: final production model

This structure mirrors common ML platform patterns (like MLflow's experiment/run hierarchy) and enables:

  • Comparing metrics across multiple training attempts
  • Tracking the evolution of a model through iterations
  • Understanding which approaches were tried and their results
  • Organizing training work by project or objective
Python SDK: Create training runs and experiments
import argparse
from datetime import datetime

from dh_ai_client import DatahubAIClient

from datahub.emitter.mcp_builder import (
    ContainerKey,
)
from datahub.ingestion.source.common.subtypes import MLAssetSubTypes
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType
from datahub.metadata.schema_classes import (
    AuditStampClass,
    DataProcessInstancePropertiesClass,
    MLHyperParamClass,
    MLMetricClass,
    MLTrainingRunPropertiesClass,
)
from datahub.metadata.urns import (
    CorpUserUrn,
    DataProcessInstanceUrn,
    GlossaryTermUrn,
    TagUrn,
)
from datahub.sdk.container import Container
from datahub.sdk.dataset import Dataset
from datahub.sdk.mlmodel import MLModel
from datahub.sdk.mlmodelgroup import MLModelGroup

parser = argparse.ArgumentParser()
parser.add_argument("--token", required=False, help="DataHub access token")
parser.add_argument(
    "--server_url",
    required=False,
    default="http://localhost:8080",
    help="DataHub server URL (defaults to http://localhost:8080)",
)
args = parser.parse_args()

# Initialize client
client = DatahubAIClient(token=args.token, server_url=args.server_url)

# Use a unique prefix for all IDs to avoid conflicts
prefix = "test"

# Define all entity IDs upfront
# Basic entity IDs
basic_model_group_id = f"{prefix}_basic_group"
basic_model_id = f"{prefix}_basic_model"
basic_experiment_id = f"{prefix}_basic_experiment"
basic_run_id = f"{prefix}_basic_run"
basic_dataset_id = f"{prefix}_basic_dataset"

# Advanced entity IDs
advanced_model_group_id = f"{prefix}_airline_forecast_models_group"
advanced_model_id = f"{prefix}_arima_model"
advanced_experiment_id = f"{prefix}_airline_forecast_experiment"
advanced_run_id = f"{prefix}_simple_training_run"
advanced_input_dataset_id = f"{prefix}_iris_input"
advanced_output_dataset_id = f"{prefix}_iris_output"

# Display names with prefix
basic_model_group_name = f"{prefix} Basic Group"
basic_model_name = f"{prefix} Basic Model"
basic_experiment_name = f"{prefix} Basic Experiment"
basic_run_name = f"{prefix} Basic Run"
basic_dataset_name = f"{prefix} Basic Dataset"

advanced_model_group_name = f"{prefix} Airline Forecast Models Group"
advanced_model_name = f"{prefix} ARIMA Model"
advanced_experiment_name = f"{prefix} Airline Forecast Experiment"
advanced_run_name = f"{prefix} Simple Training Run"
advanced_input_dataset_name = f"{prefix} Iris Training Input Data"
advanced_output_dataset_name = f"{prefix} Iris Model Output Data"


def create_basic_model_group():
    """Create a basic model group."""
    print("Creating basic model group...")
    basic_model_group = MLModelGroup(
        id=basic_model_group_id,
        platform="mlflow",
        name=basic_model_group_name,
    )
    client._emit_mcps(basic_model_group.as_mcps())
    return basic_model_group


def create_advanced_model_group():
    """Create an advanced model group."""
    print("Creating advanced model group...")
    advanced_model_group = MLModelGroup(
        id=advanced_model_group_id,
        platform="mlflow",
        name=advanced_model_group_name,
        description="Group of models for airline passenger forecasting",
        created=datetime.now(),
        last_modified=datetime.now(),
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        external_url="https://www.linkedin.com/in/datahub",
        tags=["urn:li:tag:forecasting", "urn:li:tag:arima"],
        terms=["urn:li:glossaryTerm:forecasting"],
        custom_properties={"team": "forecasting"},
    )
    client._emit_mcps(advanced_model_group.as_mcps())
    return advanced_model_group


def create_basic_model():
    """Create a basic model."""
    print("Creating basic model...")
    basic_model = MLModel(
        id=basic_model_id,
        platform="mlflow",
        name=basic_model_name,
    )
    client._emit_mcps(basic_model.as_mcps())
    return basic_model


def create_advanced_model():
    """Create an advanced model."""
    print("Creating advanced model...")
    advanced_model = MLModel(
        id=advanced_model_id,
        platform="mlflow",
        name=advanced_model_name,
        description="ARIMA model for airline passenger forecasting",
        created=datetime.now(),
        last_modified=datetime.now(),
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        external_url="https://www.linkedin.com/in/datahub",
        tags=["urn:li:tag:forecasting", "urn:li:tag:arima"],
        terms=["urn:li:glossaryTerm:forecasting"],
        custom_properties={"team": "forecasting"},
        version="1",
        aliases=["champion"],
        hyper_params={"learning_rate": "0.01"},
        training_metrics={"accuracy": "0.9"},
    )
    client._emit_mcps(advanced_model.as_mcps())
    return advanced_model


def create_basic_experiment():
    """Create a basic experiment."""
    print("Creating basic experiment...")
    basic_experiment = Container(
        container_key=ContainerKey(platform="mlflow", name=basic_experiment_id),
        display_name=basic_experiment_name,
    )
    client._emit_mcps(basic_experiment.as_mcps())
    return basic_experiment


def create_advanced_experiment():
    """Create an advanced experiment."""
    print("Creating advanced experiment...")
    advanced_experiment = Container(
        container_key=ContainerKey(platform="mlflow", name=advanced_experiment_id),
        display_name=advanced_experiment_name,
        description="Experiment to forecast airline passenger numbers",
        extra_properties={"team": "forecasting"},
        created=datetime(2025, 4, 9, 22, 30),
        last_modified=datetime(2025, 4, 9, 22, 30),
        subtype=MLAssetSubTypes.MLFLOW_EXPERIMENT,
    )
    client._emit_mcps(advanced_experiment.as_mcps())
    return advanced_experiment


def create_basic_training_run():
    """Create a basic training run."""
    print("Creating basic training run...")
    basic_run_urn = client.create_training_run(
        run_id=basic_run_id,
        run_name=basic_run_name,
    )
    return basic_run_urn


def create_advanced_training_run():
    """Create an advanced training run."""
    print("Creating advanced training run...")
    advanced_run_urn = client.create_training_run(
        run_id=advanced_run_id,
        properties=DataProcessInstancePropertiesClass(
            name=advanced_run_name,
            created=AuditStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
            customProperties={"team": "forecasting"},
        ),
        training_run_properties=MLTrainingRunPropertiesClass(
            id=advanced_run_id,
            outputUrls=["s3://my-bucket/output"],
            trainingMetrics=[MLMetricClass(name="accuracy", value="0.9")],
            hyperParams=[MLHyperParamClass(name="learning_rate", value="0.01")],
            externalUrl="https://localhost:5000",
        ),
        run_result=RunResultType.FAILURE,
        start_timestamp=1628580000000,
        end_timestamp=1628580001000,
    )
    return advanced_run_urn


def create_basic_dataset():
    """Create a basic dataset."""
    print("Creating basic dataset...")
    basic_input_dataset = Dataset(
        platform="snowflake",
        name=basic_dataset_id,
        display_name=basic_dataset_name,
    )
    client._emit_mcps(basic_input_dataset.as_mcps())
    return basic_input_dataset


def create_advanced_datasets():
    """Create advanced datasets."""
    print("Creating advanced datasets...")
    advanced_input_dataset = Dataset(
        platform="snowflake",
        name=advanced_input_dataset_id,
        description="Raw Iris dataset used for training ML models",
        schema=[("id", "number"), ("name", "string"), ("species", "string")],
        display_name=advanced_input_dataset_name,
        tags=["urn:li:tag:ml_data", "urn:li:tag:iris"],
        terms=["urn:li:glossaryTerm:raw_data"],
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        custom_properties={
            "data_source": "UCI Repository",
            "records": "150",
            "features": "4",
        },
    )
    client._emit_mcps(advanced_input_dataset.as_mcps())

    advanced_output_dataset = Dataset(
        platform="snowflake",
        name=advanced_output_dataset_id,
        description="Processed Iris dataset with model predictions",
        schema=[("id", "number"), ("name", "string"), ("species", "string")],
        display_name=advanced_output_dataset_name,
        tags=["urn:li:tag:ml_data", "urn:li:tag:predictions"],
        terms=["urn:li:glossaryTerm:model_output"],
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        custom_properties={
            "model_version": "1.0",
            "records": "150",
            "accuracy": "0.95",
        },
    )
    client._emit_mcps(advanced_output_dataset.as_mcps())
    return advanced_input_dataset, advanced_output_dataset


# Split relationship functions into individual top-level functions
def add_model_to_model_group(model, model_group):
    """Add model to model group relationship."""
    print("Adding model to model group...")
    model.set_model_group(model_group.urn)
    client._emit_mcps(model.as_mcps())


def add_run_to_experiment(run_urn, experiment):
    """Add run to experiment relationship."""
    print("Adding run to experiment...")
    client.add_run_to_experiment(run_urn=run_urn, experiment_urn=str(experiment.urn))


def add_run_to_model(model, run_id):
    """Add run to model relationship."""
    print("Adding run to model...")
    model.add_training_job(DataProcessInstanceUrn(run_id))
    client._emit_mcps(model.as_mcps())


def add_run_to_model_group(model_group, run_id):
    """Add run to model group relationship."""
    print("Adding run to model group...")
    model_group.add_training_job(DataProcessInstanceUrn(run_id))
    client._emit_mcps(model_group.as_mcps())


def add_input_dataset_to_run(run_urn, input_dataset):
    """Add input dataset to run relationship."""
    print("Adding input dataset to run...")
    client.add_input_datasets_to_run(
        run_urn=run_urn, dataset_urns=[str(input_dataset.urn)]
    )


def add_output_dataset_to_run(run_urn, output_dataset):
    """Add output dataset to run relationship."""
    print("Adding output dataset to run...")
    client.add_output_datasets_to_run(
        run_urn=run_urn, dataset_urns=[str(output_dataset.urn)]
    )


def update_model_properties(model):
    """Update model properties."""
    print("Updating model properties...")

    # Update model version
    model.set_version("2")

    # Add tags and terms
    model.add_tag(TagUrn("marketing"))
    model.add_term(GlossaryTermUrn("marketing"))

    # Add version alias
    model.add_version_alias("challenger")

    # Save the changes
    client._emit_mcps(model.as_mcps())


def update_model_group_properties(model_group):
    """Update model group properties."""
    print("Updating model group properties...")

    # Update description
    model_group.set_description("Updated description for airline forecast models")

    # Add tags and terms
    model_group.add_tag(TagUrn("production"))
    model_group.add_term(GlossaryTermUrn("time-series"))

    # Update custom properties
    model_group.set_custom_properties(
        {"team": "forecasting", "business_unit": "operations", "status": "active"}
    )

    # Save the changes
    client._emit_mcps(model_group.as_mcps())


def update_experiment_properties():
    """Update experiment properties."""
    print("Updating experiment properties...")

    # Create a container object for the existing experiment
    existing_experiment = Container(
        container_key=ContainerKey(platform="mlflow", name=advanced_experiment_id),
        display_name=advanced_experiment_name,
    )

    # Update properties
    existing_experiment.set_description(
        "Updated experiment for forecasting passenger numbers"
    )
    existing_experiment.add_tag(TagUrn("time-series"))
    existing_experiment.add_term(GlossaryTermUrn("forecasting"))
    existing_experiment.set_custom_properties(
        {"team": "forecasting", "priority": "high", "status": "active"}
    )

    # Save the changes
    client._emit_mcps(existing_experiment.as_mcps())


def main():
    print("Creating AI assets...")

    # Comment in/out the functions you want to run
    # Create basic entities
    create_basic_model_group()
    create_basic_model()
    create_basic_experiment()
    create_basic_training_run()
    create_basic_dataset()

    # Create advanced entities
    advanced_model_group = create_advanced_model_group()
    advanced_model = create_advanced_model()
    advanced_experiment = create_advanced_experiment()
    advanced_run_urn = create_advanced_training_run()
    advanced_input_dataset, advanced_output_dataset = create_advanced_datasets()

    # Create relationships - each can be commented out independently
    add_model_to_model_group(advanced_model, advanced_model_group)
    add_run_to_experiment(advanced_run_urn, advanced_experiment)
    add_run_to_model(advanced_model, advanced_run_id)
    add_run_to_model_group(advanced_model_group, advanced_run_id)
    add_input_dataset_to_run(advanced_run_urn, advanced_input_dataset)
    add_output_dataset_to_run(advanced_run_urn, advanced_output_dataset)

    # Update properties - each can be commented out independently
    update_model_properties(advanced_model)
    update_model_group_properties(advanced_model_group)
    update_experiment_properties()

    print("All done! AI entities created successfully.")


if __name__ == "__main__":
    main()

Relationships and Lineage

ML Models support rich relationship modeling through various aspects and fields:

Core Relationships

  • Model Groups (via groups field in mlModelProperties): Models can belong to mlModelGroup entities, creating a MemberOf relationship. This organizes related models into logical families or collections.

  • Training Runs (via trainingJobs field in mlModelProperties): Models reference dataProcessInstance entities with MLFLOW_TRAINING_RUN subtype that produced them. This creates upstream lineage showing:

    • Which training run created this model
    • What datasets were used for training (via the training run's input datasets)
    • What hyperparameters and metrics were recorded
    • Which experiment the training run belonged to
  • Features (via mlFeatures field in mlModelProperties): Models can consume mlFeature entities, creating a Consumes relationship. This documents:

    • Which features are required for model inference
    • The complete feature set used during training
    • Dependencies on feature stores or feature tables
  • Deployments (via deployments field in mlModelProperties): Models can be deployed to mlModelDeployment entities, representing running model endpoints in various environments (production, staging, etc.)

  • Training Datasets (via mlModelTrainingData aspect): Direct references to datasets used for training, including preprocessing information and motivation for dataset selection

  • Evaluation Datasets (via mlModelEvaluationData aspect): References to datasets used for model evaluation and testing

Lineage Graph Structure

These relationships create a comprehensive lineage graph:

Training Datasets → Training Run → ML Model → ML Model Deployment
                        │
                   Experiment (groups related training runs)

Feature Tables → ML Features → ML Model

ML Model Group ← ML Model (MemberOf)

This enables powerful queries such as:

  • "Show me all datasets that influenced this model's predictions"
  • "Which models will be affected if this dataset schema changes?"
  • "What's the full history of training runs that created versions of this model?"
  • "Which production endpoints are serving this model?"
Python SDK: Update model-specific aspects
import datahub.metadata.schema_classes as models
from datahub.metadata.urns import DatasetUrn, MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()

model_urn = MlModelUrn(platform="mlflow", name="customer-churn-predictor")

mlmodel = client.entities.get(model_urn)

intended_use = models.IntendedUseClass(
    primaryUses=[
        "Predict customer churn to enable proactive retention campaigns",
        "Identify high-risk customers for targeted interventions",
    ],
    primaryUsers=[models.IntendedUserTypeClass.ENTERPRISE],
    outOfScopeUses=[
        "Not suitable for real-time predictions (batch inference only)",
        "Not trained on international markets outside North America",
    ],
)

mlmodel._set_aspect(intended_use)

training_data = models.TrainingDataClass(
    trainingData=[
        models.BaseDataClass(
            dataset=str(
                DatasetUrn(
                    platform="snowflake", name="prod.analytics.customer_features"
                )
            ),
            motivation="Historical customer data with confirmed churn labels",
            preProcessing=[
                "Removed customers with less than 30 days of history",
                "Standardized numerical features using StandardScaler",
                "One-hot encoded categorical variables",
            ],
        )
    ]
)

mlmodel._set_aspect(training_data)

source_code = models.SourceCodeClass(
    sourceCode=[
        models.SourceCodeUrlClass(
            type=models.SourceCodeUrlTypeClass.ML_MODEL_SOURCE_CODE,
            sourceCodeUrl="https://github.com/example/ml-models/tree/main/churn-predictor",
        )
    ]
)

mlmodel._set_aspect(source_code)

ethical_considerations = models.EthicalConsiderationsClass(
    data=["Model uses demographic data (age, location) which may be sensitive"],
    risksAndHarms=[
        "Predictions may disproportionately affect certain customer segments",
        "False positives could lead to unnecessary retention spending",
    ],
    mitigations=[
        "Regular bias audits conducted quarterly",
        "Human review required for high-value customer interventions",
    ],
)

mlmodel._set_aspect(ethical_considerations)

client.entities.update(mlmodel)

print(f"Updated aspects for model: {model_urn}")

Tags, Terms, and Ownership

Like other DataHub entities, ML Models support:

  • Tags (globalTags aspect): Flexible categorization (e.g., "pii-model", "production-ready", "experimental")
  • Glossary Terms (glossaryTerms aspect): Business concepts (e.g., "Customer Churn", "Fraud Detection")
  • Ownership (ownership aspect): Individuals or teams responsible for the model (data scientists, ML engineers, etc.)
  • Domains (domains aspect): Organizational grouping (e.g., "Recommendations", "Risk Management")

Complete ML Workflow Example

The following example demonstrates a complete ML model lifecycle in DataHub, showing how all the pieces work together:

1. Create Model Group

2. Create Experiment (Container)

3. Create Training Run (DataProcessInstance)
   ├── Link input datasets
   ├── Link output datasets
   └── Add metrics and hyperparameters

4. Create Model
   ├── Set version and aliases
   ├── Link to model group
   ├── Link to training run
   ├── Add hyperparameters and metrics
   └── Add ownership and tags

5. Link Training Run to Experiment

6. Update Model properties as needed
   ├── Change version aliases (champion → challenger)
   ├── Add additional tags/terms
   └── Update metrics from production

This workflow creates rich lineage showing:

  • Which datasets trained the model
  • What experiments and training runs were involved
  • How the model evolved through versions
  • Which version is deployed (via aliases)
  • Who owns and maintains the model
Complete Python Example: Full ML Workflow

See the comprehensive example in /metadata-ingestion/examples/ai/dh_ai_docs_demo.py which demonstrates:

  • Creating model groups with metadata
  • Creating experiments to organize training runs
  • Creating training runs with metrics, hyperparameters, and dataset lineage
  • Creating models with versions and aliases
  • Linking all entities together to form complete lineage
  • Updating properties and managing the model lifecycle

The example shows both basic patterns for getting started and advanced patterns for production ML systems.

Code Examples

Querying ML Model Information

The standard REST APIs can be used to retrieve ML Model entities and their aspects:

Python: Query an ML Model via REST API
import urllib.parse

import requests

gms_server = "http://localhost:8080"

model_urn = "urn:li:mlModel:(urn:li:dataPlatform:mlflow,customer-churn-predictor,PROD)"
encoded_urn = urllib.parse.quote(model_urn, safe="")

response = requests.get(f"{gms_server}/entities/{encoded_urn}")

if response.status_code == 200:
    entity = response.json()

    print(f"Entity URN: {entity['urn']}")
    print("\nAspects:")

    if "mlModelProperties" in entity["aspects"]:
        props = entity["aspects"]["mlModelProperties"]
        print(f"  Name: {props.get('name')}")
        print(f"  Description: {props.get('description')}")
        print(f"  Type: {props.get('type')}")

        if props.get("hyperParams"):
            print("\n  Hyperparameters:")
            for param in props["hyperParams"]:
                print(f"    - {param['name']}: {param['value']}")

        if props.get("trainingMetrics"):
            print("\n  Training Metrics:")
            for metric in props["trainingMetrics"]:
                print(f"    - {metric['name']}: {metric['value']}")

    if "globalTags" in entity["aspects"]:
        tags = entity["aspects"]["globalTags"]["tags"]
        print(f"\n  Tags: {[tag['tag'] for tag in tags]}")

    if "ownership" in entity["aspects"]:
        owners = entity["aspects"]["ownership"]["owners"]
        print(f"\n  Owners: {[owner['owner'] for owner in owners]}")

    if "intendedUse" in entity["aspects"]:
        intended = entity["aspects"]["intendedUse"]
        print(f"\n  Primary Uses: {intended.get('primaryUses')}")
        print(f"  Out of Scope Uses: {intended.get('outOfScopeUses')}")

else:
    print(f"Failed to fetch entity: {response.status_code}")
    print(response.text)

Integration Points

ML Models integrate with several other entities in the DataHub metadata model:

  • mlModelGroup: Logical grouping of related model versions (e.g., all versions of a recommendation model)
  • mlModelDeployment: Running instances of deployed models with status, endpoint URLs, and deployment metadata
  • mlFeature: Individual features consumed by the model for inference
  • mlFeatureTable: Collections of features, often from feature stores
  • dataset: Training and evaluation datasets used by the model
  • dataProcessInstance (with MLFLOW_TRAINING_RUN subtype): Specific training runs that created model versions, including metrics, hyperparameters, and lineage to input/output datasets
  • container (with MLFLOW_EXPERIMENT subtype): Experiments that organize related training runs for a model or project
  • versionSet: Groups all versions of a model together for version management

GraphQL Resolvers

The GraphQL API provides rich querying capabilities for ML Models through resolvers in datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/mlmodel/. These resolvers support:

  • Fetching model details with all aspects
  • Navigating relationships to features, groups, and deployments
  • Searching and filtering models by tags, terms, platform, etc.
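
For example, models can be searched through the /api/graphql endpoint. The following is an illustrative sketch only: it uses DataHub's generic search query with the MLMODEL entity type, and the server URL and token values are placeholders.

Python: Search ML Models via the GraphQL API
import requests

gms_server = "http://localhost:8080"  # placeholder
token = "<your-access-token>"  # placeholder

# Generic entity search scoped to ML Models (illustrative query shape)
query = """
query searchMlModels($input: SearchInput!) {
  search(input: $input) {
    searchResults {
      entity {
        urn
        type
      }
    }
  }
}
"""
variables = {"input": {"type": "MLMODEL", "query": "churn", "start": 0, "count": 10}}

response = requests.post(
    f"{gms_server}/api/graphql",
    headers={"Authorization": f"Bearer {token}"},
    json={"query": query, "variables": variables},
)
print(response.json())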

Ingestion Sources

Several ingestion sources automatically extract ML Model metadata:

  • MLflow: Extracts registered models, versions, metrics, parameters, and lineage from MLflow tracking servers
  • SageMaker: Ingests models, model packages, and endpoints from AWS SageMaker
  • Vertex AI: Extracts models and endpoints from Google Cloud Vertex AI
  • Databricks: Ingests MLflow models from Databricks workspaces
  • Unity Catalog: Extracts ML models registered in Unity Catalog

These sources are located in /metadata-ingestion/src/datahub/ingestion/source/ and automatically populate model properties, relationships, and lineage.

Notable Exceptions

Model Versioning

ML Model versioning in DataHub uses the versionProperties aspect, which provides a robust framework for tracking model versions across their lifecycle. This is the standard approach demonstrated in production ML platforms.

Version Properties Aspect

Every ML Model should use the versionProperties aspect, which includes:

  • version: A VersionTagClass containing the version identifier (e.g., "1", "2", "v1.0.0")
  • versionSet: A URN that groups all versions of a model together (e.g., urn:li:versionSet:(mlModel,mlmodel_my-model_versions))
  • sortId: A string used for ordering versions (typically the version number zero-padded)
  • aliases: Optional array of VersionTagClass objects for named version references

Version Aliases for A/B Testing

Version aliases enable flexible model lifecycle management and A/B testing workflows. Common aliases include:

  • "champion": The currently deployed production model
  • "challenger": A candidate model being tested or evaluated
  • "baseline": A reference model for performance comparison
  • "latest": The most recently trained version

These aliases allow you to reference models by their role rather than specific version numbers, enabling smooth model promotion workflows:

Model v1 (alias: "champion")    # Currently in production
Model v2 (alias: "challenger")  # Being tested in canary deployment
Model v3 (alias: "latest")      # Just completed training

When v2 proves superior, you can update aliases without changing infrastructure:

Model v1 (no alias)             # Retired
Model v2 (alias: "champion")    # Promoted to production
Model v3 (alias: "challenger")  # Now being tested
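
With the Python SDK, a promotion like this is just an update to the version and aliases on the relevant model entities. The sketch below uses the set_version and add_version_alias calls shown in the earlier example; the model name is hypothetical, and clearing the alias on the retired version is not shown here.

Python SDK: Promote a model version via aliases
from datahub.metadata.urns import MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()

# Hypothetical model entity representing the newly trained version
challenger = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

# Record the new version and mark it as the production model
challenger.set_version("2")
challenger.add_version_alias("champion")

client.entities.update(challenger)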

Model Groups and Versioning

Model groups (mlModelGroup entities) serve as logical containers for organizing related models. While model groups can contain multiple versions of the same model, versioning is handled through the versionProperties aspect on individual models, not through the group structure itself. Model groups are used for:

  • Organizing all versions of a model family
  • Grouping experimental variants or different architectures solving the same problem
  • Managing lineage and metadata common across multiple related models

The relationship between models and model groups is through the groups field in mlModelProperties, creating a MemberOf relationship.

Platform-Specific Naming

Different ML platforms have different naming conventions:

  • MLflow: Uses a two-level hierarchy (registered model name + version number). In DataHub, each version can be a separate entity, or versions can be tracked in a single entity.
  • SageMaker: Has multiple model concepts (model, model package, model package group). DataHub can model these as separate entities or consolidate them.
  • Vertex AI: Uses fully qualified resource names. These should be simplified to human-readable names when possible.

When ingesting from these platforms, connectors handle platform-specific naming and convert it to appropriate DataHub URNs.

Model Cards

The various aspects (intendedUse, mlModelFactorPrompts, mlModelEthicalConsiderations, etc.) follow the Model Cards for Model Reporting framework (Mitchell et al., 2019). While these aspects are optional, they are strongly recommended for production models to ensure responsible AI practices and transparent model documentation.

Technical Reference

For technical details about fields, searchability, and relationships, view the Columns tab in DataHub.