Skip to content

Python API

TrainCraft exposes a curated public API via `import traincraft`. Every stage that the CLI runs is also available as a standalone Python function.


Top-level functions

# Minimal tour of the public API exposed by ``import traincraft``.
import traincraft as tc

# Config
cfg  = tc.load_config("my_run.toml")   # → TrainCraftConfig
# NOTE(review): `toml_string` is not defined in this snippet; it stands for the
# TOML text of a run config (e.g. the contents of my_run.toml).
cfg2 = tc.loads_config(toml_string)    # → TrainCraftConfig

# Pipeline stages (standalone functions — the same ones the CLI calls)
structure  = tc.build_geometry(cfg.geometry)
calc       = tc.make_calculator(cfg.calculator)
# NOTE(review): `job` is not defined in this snippet. Presumably a Job such as
# Workspace("runs/my_run").job("sampling") from traincraft.core.workspace —
# confirm against run_sampling's signature.
frames     = tc.run_sampling(structure, calc, job, cfg.sampling)
selected   = tc.run_funnel(frames, cfg.selection)
summary    = tc.run_pipeline(cfg)      # the whole pipeline

# Dataset IO (extxyz files of Structure frames)
tc.write_frames("out.extxyz", frames)
frames = tc.read_frames("dataset.extxyz")  # → list[Structure]

Structure

Structure dataclass

Source code in src/traincraft/core/structure.py
@dataclass
class Structure:
    """One atomic configuration with attached properties and provenance.

    Attributes:
        atoms: The ASE geometry.
        properties: Per-structure results keyed by name (``"forces"`` is
            special-cased by :meth:`to_ase`).
        provenance: Record of how this structure was produced.
    """

    atoms: Atoms
    properties: dict[str, Any] = field(default_factory=dict)
    provenance: Provenance = field(default_factory=Provenance)

    @property
    def hash(self) -> str:
        """Content hash from composition + geometry (rounded for stability).

        Positions and cell are rounded to 4 decimals before hashing so that
        numerically noisy copies of one geometry collide.  ``+ 0.0`` after
        rounding folds IEEE negative zero onto positive zero: a coordinate such
        as ``-1e-9`` rounds to ``-0.0``, which ``json.dumps`` would emit as
        ``"-0.0"`` and so hash differently from an otherwise identical ``0.0``.

        Returns:
            The first 16 hex characters of a SHA-1 digest of the payload.
        """
        a = self.atoms
        payload = {
            "numbers": a.get_atomic_numbers().tolist(),
            "positions": (np.round(a.get_positions(), 4) + 0.0).tolist(),
            "cell": (np.round(np.asarray(a.get_cell()), 4) + 0.0).tolist(),
            "pbc": a.get_pbc().tolist(),
        }
        blob = json.dumps(payload, sort_keys=True).encode()
        return hashlib.sha1(blob).hexdigest()[:16]

    @classmethod
    def from_ase(cls, atoms: Atoms, **kwargs: Any) -> Structure:
        """Build from a copy of ``atoms``; extra kwargs go to the constructor."""
        return cls(atoms=atoms.copy(), **kwargs)

    def to_ase(self, with_properties: bool = True) -> Atoms:
        """Return a copy of the atoms with properties/provenance in ``info``.

        Provenance (as a dict) and the content hash are always stored in
        ``info``.  With ``with_properties``, a non-None ``"forces"`` property
        becomes the per-atom array ``tc_forces``; every other property goes to
        ``info`` under a ``tc_`` prefix.
        """
        atoms = self.atoms.copy()
        atoms.info["tc_provenance"] = self.provenance.to_dict()
        atoms.info["tc_hash"] = self.hash
        if with_properties:
            for key, value in self.properties.items():
                if key == "forces" and value is not None:
                    atoms.arrays["tc_forces"] = np.asarray(value)
                else:
                    atoms.info[f"tc_{key}"] = value
        return atoms

    def copy(self) -> Structure:
        """Return a copy with new Atoms, a shallow-copied ``properties`` dict,
        and a provenance rebuilt from its dict form."""
        return Structure(
            atoms=self.atoms.copy(),
            properties=dict(self.properties),
            provenance=Provenance.from_dict(self.provenance.to_dict()),
        )

    # --- fragment identity helpers ----------------------------------------
    @property
    def fragments(self):
        """Per-atom fragment array, or None if unset."""
        from .fragments import get_fragments
        return get_fragments(self.atoms)

    def set_fragments(self, frag) -> None:
        """Attach/overwrite the per-atom fragment array."""
        from .fragments import set_fragments
        set_fragments(self.atoms, frag)

    @property
    def n_fragments(self) -> int:
        """Number of distinct mobile fragments (excludes framework atoms)."""
        from .fragments import fragment_ids
        return len(fragment_ids(self.atoms))

    # --- interop (see core.converter) -------------------------------------
    def to_pymatgen(self):
        """Return a pymatgen ``Structure`` (periodic) or ``Molecule``."""
        from .converter import ase_to_pymatgen
        return ase_to_pymatgen(self.atoms)

    def to_rdkit(self, charge: int = 0):
        """Return an RDKit ``Mol`` with bonds perceived (non-periodic only)."""
        from .converter import ase_to_rdkit
        return ase_to_rdkit(self.atoms, charge=charge)

    @classmethod
    def from_pymatgen(cls, obj, **kwargs: Any) -> Structure:
        """Build a :class:`Structure` from a pymatgen ``Structure``/``Molecule``."""
        from .converter import pymatgen_to_ase
        return cls.from_ase(pymatgen_to_ase(obj), **kwargs)

    @classmethod
    def from_rdkit(cls, mol, conf_id: int = 0, **kwargs: Any) -> Structure:
        """Build a :class:`Structure` from one conformer of an RDKit ``Mol``."""
        from .converter import rdkit_to_ase
        return cls.from_ase(rdkit_to_ase(mol, conf_id=conf_id), **kwargs)

hash property

hash: str

Content hash from composition + geometry (rounded for stability).

fragments property

fragments

Per-atom fragment array, or None if unset.

n_fragments property

n_fragments: int

Number of distinct mobile fragments (excludes framework atoms).

from_ase classmethod

from_ase(atoms: Atoms, **kwargs: Any) -> Structure
Source code in src/traincraft/core/structure.py
@classmethod
def from_ase(cls, atoms: Atoms, **kwargs: Any) -> Structure:
    """Wrap a *copy* of ``atoms`` so the caller's object is never aliased.

    Extra keyword arguments (e.g. ``properties``, ``provenance``) are passed
    straight through to the constructor.
    """
    detached = atoms.copy()
    return cls(atoms=detached, **kwargs)

to_ase

to_ase(with_properties: bool = True) -> Atoms

Return a copy of the atoms with properties/provenance in info.

Source code in src/traincraft/core/structure.py
def to_ase(self, with_properties: bool = True) -> Atoms:
    """Export an ASE copy of this structure tagged with TrainCraft metadata.

    Provenance (as a dict) and the content hash always go into ``info``.
    When ``with_properties`` is true, a non-None ``"forces"`` property is
    stored as the per-atom array ``tc_forces``; every other property
    (including ``forces=None``) is stored in ``info`` with a ``tc_`` prefix.
    """
    exported = self.atoms.copy()
    exported.info["tc_provenance"] = self.provenance.to_dict()
    exported.info["tc_hash"] = self.hash
    if not with_properties:
        return exported
    for prop_name, prop_value in self.properties.items():
        if prop_name == "forces" and prop_value is not None:
            exported.arrays["tc_forces"] = np.asarray(prop_value)
            continue
        exported.info[f"tc_{prop_name}"] = prop_value
    return exported

copy

copy() -> Structure
Source code in src/traincraft/core/structure.py
def copy(self) -> Structure:
    """Duplicate this structure.

    The Atoms are copied, ``properties`` is copied one level deep (values are
    shared), and provenance is rebuilt from its dict form.
    """
    atoms_dup = self.atoms.copy()
    props_dup = {**self.properties}
    prov_dup = Provenance.from_dict(self.provenance.to_dict())
    return Structure(atoms=atoms_dup, properties=props_dup, provenance=prov_dup)

set_fragments

set_fragments(frag) -> None

Attach/overwrite the per-atom fragment array.

Source code in src/traincraft/core/structure.py
def set_fragments(self, frag) -> None:
    """Store ``frag`` as this structure's per-atom fragment array, replacing any existing one."""
    from .fragments import set_fragments as _store_fragments

    _store_fragments(self.atoms, frag)

to_pymatgen

to_pymatgen()

Return a pymatgen Structure (periodic) or Molecule.

Source code in src/traincraft/core/structure.py
def to_pymatgen(self):
    """Convert to pymatgen: a ``Structure`` when fully periodic, otherwise a ``Molecule``."""
    from .converter import ase_to_pymatgen as _to_pmg

    return _to_pmg(self.atoms)

from_pymatgen classmethod

from_pymatgen(obj, **kwargs: Any) -> Structure

Build a `Structure` from a pymatgen `Structure`/`Molecule`.

Source code in src/traincraft/core/structure.py
@classmethod
def from_pymatgen(cls, obj, **kwargs: Any) -> Structure:
    """Create an instance from a pymatgen ``Structure`` or ``Molecule``.

    The object is converted to ASE ``Atoms`` first; extra keyword arguments
    are forwarded to :meth:`from_ase`.
    """
    from .converter import pymatgen_to_ase

    converted = pymatgen_to_ase(obj)
    return cls.from_ase(converted, **kwargs)

to_rdkit

to_rdkit(charge: int = 0)

Return an RDKit Mol with bonds perceived (non-periodic only).

Source code in src/traincraft/core/structure.py
def to_rdkit(self, charge: int = 0):
    """Convert to an RDKit ``Mol`` with perceived bonds (non-periodic structures only).

    ``charge`` is the total molecular charge used during bond perception.
    """
    from .converter import ase_to_rdkit as _to_mol

    return _to_mol(self.atoms, charge=charge)

from_rdkit classmethod

from_rdkit(mol, conf_id: int = 0, **kwargs: Any) -> Structure

Build a `Structure` from one conformer of an RDKit `Mol`.

Source code in src/traincraft/core/structure.py
@classmethod
def from_rdkit(cls, mol, conf_id: int = 0, **kwargs: Any) -> Structure:
    """Create an instance from conformer ``conf_id`` of an RDKit ``Mol``.

    Extra keyword arguments are forwarded to :meth:`from_ase`.
    """
    from .converter import rdkit_to_ase

    conformer_atoms = rdkit_to_ase(mol, conf_id=conf_id)
    return cls.from_ase(conformer_atoms, **kwargs)

Provenance

Provenance dataclass

Source code in src/traincraft/core/provenance.py
@dataclass
class Provenance:
    """Record of where a structure came from and what was done to it."""

    origin: str = "generated"
    source: str | None = None  # e.g. "builder:nanotube", "source:file"
    transforms: list[str] = field(default_factory=list)
    calculator: str | None = None  # method that produced ``properties``
    level_of_theory: dict[str, Any] = field(default_factory=dict)
    seed: int | None = None
    parents: list[str] = field(default_factory=list)  # parent structure hashes
    extra: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Validate eagerly so an unknown origin never gets attached to data.
        if self.origin in ORIGINS:
            return
        raise ValueError(f"origin must be one of {ORIGINS}, got {self.origin!r}")

    def to_dict(self) -> dict[str, Any]:
        """Return every field as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Provenance:
        """Rebuild from ``data``, silently ignoring keys that are not fields."""
        field_names = cls.__dataclass_fields__
        kept = {key: val for key, val in data.items() if key in field_names}
        return cls(**kept)

to_dict

to_dict() -> dict[str, Any]
Source code in src/traincraft/core/provenance.py
def to_dict(self) -> dict[str, Any]:
    """Serialise every dataclass field of ``self`` into a fresh plain dict."""
    as_plain = asdict(self)
    return as_plain

from_dict classmethod

from_dict(data: dict[str, Any]) -> Provenance
Source code in src/traincraft/core/provenance.py
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Provenance:
    """Construct ``cls`` from ``data``, dropping keys that are not dataclass fields."""
    accepted = cls.__dataclass_fields__
    return cls(**{name: value for name, value in data.items() if name in accepted})

Workspace and Job

Workspace

Owns an absolute run directory and hands out sub-directories/jobs.

Source code in src/traincraft/core/workspace.py
class Workspace:
    """Anchor a run at an absolute directory and hand out sub-directories/jobs."""

    def __init__(self, root: str | Path):
        # Resolve once so every path handed out is absolute.
        absolute = Path(root).resolve()
        self.root = absolute
        absolute.mkdir(parents=True, exist_ok=True)

    def subdir(self, *parts: str) -> Path:
        """Return ``root/<parts...>``, creating it (and any parents) if missing."""
        target = self.root.joinpath(*parts)
        target.mkdir(parents=True, exist_ok=True)
        return target

    def job(self, *parts: str) -> Job:
        """Return a :class:`Job` bound to the created sub-directory ``parts``."""
        job_dir = self.subdir(*parts)
        return Job(dir=job_dir)

subdir

subdir(*parts: str) -> Path
Source code in src/traincraft/core/workspace.py
def subdir(self, *parts: str) -> Path:
    """Create (if needed) and return the directory ``self.root / parts...``."""
    here = self.root
    for part in parts:
        here = here / part
    here.mkdir(parents=True, exist_ok=True)
    return here

job

job(*parts: str) -> Job
Source code in src/traincraft/core/workspace.py
def job(self, *parts: str) -> Job:
    """Create the sub-directory ``parts`` under the workspace root and wrap it as a Job."""
    target = self.subdir(*parts)
    new_job = Job(dir=target)
    return new_job

Job dataclass

Source code in src/traincraft/core/workspace.py
@dataclass
class Job:
    """A unit of work rooted at ``dir``; finished once its marker file exists."""

    dir: Path

    @property
    def marker(self) -> Path:
        """Sentinel file whose presence records completion."""
        return self.dir.joinpath(".tc_done")

    def done(self) -> bool:
        """Report whether the completion marker exists on disk."""
        sentinel = self.marker
        return sentinel.exists()

    def mark_done(self) -> None:
        """Write the completion marker (text ``ok`` plus a newline)."""
        sentinel = self.marker
        sentinel.write_text("ok\n")

    def path(self, *parts: str) -> Path:
        """Return ``dir/parts...`` without touching the filesystem."""
        result = self.dir
        for piece in parts:
            result = result / piece
        return result

path

path(*parts: str) -> Path
Source code in src/traincraft/core/workspace.py
def path(self, *parts: str) -> Path:
    """Resolve ``parts`` relative to the job directory (no filesystem access)."""
    joined = self.dir
    for piece in parts:
        joined = joined / piece
    return joined

Geometry

build_geometry

build_geometry(geom_cfg) -> Structure

Resolve a `GeometryConfig` into a single `Structure`.

Source code in src/traincraft/geometry/__init__.py
def build_geometry(geom_cfg) -> Structure:
    """Turn a :class:`GeometryConfig` into one :class:`Structure`.

    The base structure comes from ``geom_cfg.builder`` when it is set and from
    ``geom_cfg.source`` otherwise.  Each entry of ``geom_cfg.transforms`` is
    then applied in order, each one receiving the previous step's output.
    """
    if geom_cfg.builder is None:
        current = build_source(geom_cfg.source)
    else:
        current = build_builder(geom_cfg.builder)
    for step in geom_cfg.transforms:
        current = apply_transform(current, step)
    return current

Converter

ase_to_pymatgen

ase_to_pymatgen(atoms: Atoms) -> Structure | Molecule

Convert ASE Atoms to a pymatgen Structure (periodic) or Molecule.

The choice is driven by periodicity: an Atoms periodic in all three directions becomes a Structure; otherwise a Molecule (the cell is dropped, since a partially periodic slab/wire has no pymatgen analogue).

Source code in src/traincraft/core/converter.py
def ase_to_pymatgen(atoms: Atoms) -> Structure | Molecule:
    """Convert ASE ``Atoms`` to a pymatgen ``Structure`` (periodic) or ``Molecule``.

    Only an ``Atoms`` periodic along all three axes becomes a ``Structure``.
    Anything else, including partially periodic slabs or wires (which have no
    pymatgen analogue), becomes a ``Molecule`` and the cell is dropped.
    """
    adaptor = _require_pymatgen()
    fully_periodic = bool(np.all(atoms.get_pbc()))
    if not fully_periodic:
        return adaptor.get_molecule(atoms)
    return adaptor.get_structure(atoms)

pymatgen_to_ase

pymatgen_to_ase(obj: Structure | Molecule) -> Atoms

Convert a pymatgen Structure or Molecule to ASE Atoms.

Source code in src/traincraft/core/converter.py
def pymatgen_to_ase(obj: Structure | Molecule) -> Atoms:
    """Turn a pymatgen ``Structure`` or ``Molecule`` into ASE ``Atoms`` via the adaptor."""
    return _require_pymatgen().get_atoms(obj)

ase_to_rdkit

ase_to_rdkit(atoms: Atoms, charge: int = 0) -> Mol

Convert non-periodic ASE Atoms to an RDKit Mol with bonds perceived.

Bonds are inferred from the 3D geometry by RDKit's DetermineBonds (the xyz2mol algorithm). Raises if the structure is periodic in any direction.

Source code in src/traincraft/core/converter.py
def ase_to_rdkit(atoms: Atoms, charge: int = 0) -> Mol:
    """Convert non-periodic ASE ``Atoms`` to an RDKit ``Mol`` with bonds perceived.

    The geometry is round-tripped through an XYZ block, and bonds are then
    inferred by RDKit's ``DetermineBonds`` (the xyz2mol algorithm) using the
    given total ``charge``.

    Raises:
        ValueError: if the structure is periodic along any axis, if RDKit
            cannot parse the XYZ block, or if bond perception fails.
    """
    pbc = atoms.get_pbc()
    if bool(np.any(pbc)):
        raise ValueError(
            "ase_to_rdkit needs a non-periodic structure; got pbc="
            f"{pbc.tolist()}. RDKit molecules are not periodic."
        )
    Chem, rdDetermineBonds = _require_rdkit()

    xyz_stream = io.StringIO()
    write(xyz_stream, atoms, format="xyz")
    xyz_text = xyz_stream.getvalue()
    molecule = Chem.MolFromXYZBlock(xyz_text)
    if molecule is None:
        raise ValueError("RDKit could not parse the structure as an XYZ molecule")
    try:
        rdDetermineBonds.DetermineBonds(molecule, charge=charge)
    except ValueError as err:
        raise ValueError(
            f"RDKit could not perceive bonds (charge={charge}): {err}. "
            "Try passing the correct total charge."
        ) from err
    return molecule

rdkit_to_ase

rdkit_to_ase(mol: Mol, conf_id: int = 0) -> Atoms

Convert one conformer of an RDKit Mol to ASE Atoms.

conf_id selects which embedded conformer to read (default: the first).

Source code in src/traincraft/core/converter.py
def rdkit_to_ase(mol: Mol, conf_id: int = 0) -> Atoms:
    """Build non-periodic ASE ``Atoms`` from one embedded conformer of an RDKit ``Mol``.

    ``conf_id`` selects the conformer (default: the first).

    Raises:
        ValueError: if the molecule has no embedded conformers.
    """
    if not mol.GetNumConformers():
        raise ValueError(
            "RDKit molecule has no conformers; embed one first "
            "(e.g. AllChem.EmbedMolecule)."
        )
    conformer = mol.GetConformer(conf_id)
    coords = np.asarray(conformer.GetPositions())
    elements = [rd_atom.GetSymbol() for rd_atom in mol.GetAtoms()]
    return Atoms(symbols=elements, positions=coords)

Fragment helpers

get_fragments

get_fragments(atoms: Atoms) -> np.ndarray | None

Return the per-atom fragment array, or None if unset.

Source code in src/traincraft/core/fragments.py
def get_fragments(atoms: Atoms) -> np.ndarray | None:
    """Return an integer copy of the stored per-atom fragment array, or None if absent."""
    if FRAGMENT_KEY not in atoms.arrays:
        return None
    return atoms.arrays[FRAGMENT_KEY].astype(int)

set_fragments

set_fragments(atoms: Atoms, frag: ndarray | list[int]) -> None

Attach/overwrite the per-atom fragment array (length must equal len(atoms)).

Source code in src/traincraft/core/fragments.py
def set_fragments(atoms: Atoms, frag: np.ndarray | list[int]) -> None:
    """Store ``frag`` (coerced to int) as the per-atom fragment array of ``atoms``.

    Raises:
        ValueError: unless ``frag`` is one-dimensional with one entry per atom.
    """
    ids = np.asarray(frag, dtype=int)
    n_atoms = len(atoms)
    if ids.shape == (n_atoms,):
        atoms.set_array(FRAGMENT_KEY, ids)
        return
    raise ValueError(f"fragment array must have shape ({n_atoms},), got {ids.shape}")

infer_fragments

infer_fragments(atoms: Atoms, scale: float = 1.2, framework_mask: ndarray | None = None) -> np.ndarray

Assign fragment ids by connected components of a covalent-radius graph.

Two atoms bond if distance < scale * (r_cov[i] + r_cov[j]). framework_mask (optional bool array, length == len(atoms)): atoms marked True are forced to FRAMEWORK (-1) and excluded from the connectivity graph. Returns the array; does NOT mutate atoms.

Source code in src/traincraft/core/fragments.py
def infer_fragments(
    atoms: Atoms,
    scale: float = 1.2,
    framework_mask: np.ndarray | None = None,
) -> np.ndarray:
    """Assign fragment ids by connected components of a covalent-radius graph.

    Two atoms bond if distance < scale * (r_cov[i] + r_cov[j]), with covalent
    radii taken from ASE's ``natural_cutoffs``.

    Args:
        atoms: Structure to analyse.  It is not mutated.
        scale: Multiplier applied to every covalent radius.
        framework_mask: Optional bool array of length ``len(atoms)``.  Atoms
            marked True are forced to FRAMEWORK (-1) and excluded from the
            connectivity graph.

    Returns:
        Integer array of length ``len(atoms)``.  Masked atoms get FRAMEWORK;
        each mobile atom gets the label of its connected component.

    Raises:
        ValueError: if ``framework_mask`` does not have shape ``(len(atoms),)``.
    """
    from ase.neighborlist import NeighborList, natural_cutoffs
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components

    n = len(atoms)
    result = np.full(n, FRAMEWORK, dtype=int)

    # Determine which atoms are mobile (not in the framework mask).
    mobile = np.ones(n, dtype=bool)
    if framework_mask is not None:
        framework_mask = np.asarray(framework_mask, dtype=bool)
        if framework_mask.shape != (n,):
            raise ValueError(
                f"framework_mask must have shape ({n},), got {framework_mask.shape}"
            )
        mobile[framework_mask] = False

    mobile_idx = np.flatnonzero(mobile)
    m = mobile_idx.size
    if m == 0:
        return result

    # natural_cutoffs yields scale * r_cov per atom; the list then pairs i, j
    # when d_ij < cutoff_i + cutoff_j.  skin must be 0.0: ASE's default skin
    # (0.3 Å) widens the cutoffs, so get_neighbors would also return pairs
    # outside the documented bonding criterion and spuriously merge fragments.
    cutoffs = natural_cutoffs(atoms, mult=scale)
    nl = NeighborList(cutoffs, skin=0.0, self_interaction=False, bothways=True)
    nl.update(atoms)

    # Row/column of each mobile atom in the mobile-only adjacency matrix;
    # -1 marks framework atoms so their bonds are dropped.
    local_of = np.full(n, -1, dtype=int)
    local_of[mobile_idx] = np.arange(m)
    rows: list[int] = []
    cols: list[int] = []
    for local, global_i in enumerate(mobile_idx):
        neighbours, _ = nl.get_neighbors(global_i)
        targets = local_of[np.asarray(neighbours, dtype=int)]
        targets = targets[targets >= 0]
        rows.extend([local] * targets.size)
        cols.extend(targets.tolist())

    if rows:
        data = np.ones(len(rows), dtype=np.int8)
        adj = csr_matrix((data, (rows, cols)), shape=(m, m))
    else:
        adj = csr_matrix((m, m), dtype=np.int8)

    _, labels = connected_components(adj, directed=False)
    result[mobile_idx] = labels
    return result

fragment_ids

fragment_ids(atoms: Atoms) -> list[int]

Sorted list of mobile fragment ids (excludes FRAMEWORK == -1).

Source code in src/traincraft/core/fragments.py
def fragment_ids(atoms: Atoms) -> list[int]:
    """Return the distinct mobile fragment ids in ascending order.

    FRAMEWORK (-1) is excluded, and an empty list is returned when no fragment
    array is attached.
    """
    frag = get_fragments(atoms)
    if frag is None:
        return []
    # np.unique already returns its values sorted ascending.
    return [int(fid) for fid in np.unique(frag) if fid != FRAMEWORK]

fragment_mask

fragment_mask(atoms: Atoms, fid: int) -> np.ndarray

Boolean mask selecting atoms of fragment fid.

Source code in src/traincraft/core/fragments.py
def fragment_mask(atoms: Atoms, fid: int) -> np.ndarray:
    """Return a per-atom boolean mask that is True where the fragment id equals ``fid``.

    Raises:
        ValueError: if no fragment array is attached to ``atoms``.
    """
    ids = get_fragments(atoms)
    if ids is not None:
        return ids == fid
    raise ValueError("no fragment array set on these atoms")

Registry

register

register(kind: str, name: str, *, capabilities: Iterable[str] | None = None)

Decorator: register obj under (kind, name).

Source code in src/traincraft/core/registry.py
def register(kind: str, name: str, *, capabilities: Iterable[str] | None = None):
    """Decorator: register ``obj`` under ``(kind, name)``.

    The decorated object is returned unchanged.  Its capabilities are stored
    as a set.  Applying the decorator raises ``RegistryError`` if ``name`` is
    already taken within ``kind``.
    """

    def _install(obj):
        entries = _REGISTRY.setdefault(kind, {})
        if name not in entries:
            entries[name] = {"obj": obj, "capabilities": set(capabilities or ())}
            return obj
        raise RegistryError(f"{kind} {name!r} is already registered")

    return _install

get

get(kind: str, name: str)
Source code in src/traincraft/core/registry.py
def get(kind: str, name: str):
    """Return the object registered under ``(kind, name)``.

    Raises:
        RegistryError: if the lookup fails; the message lists the names
            available for ``kind``.
    """
    try:
        found = _REGISTRY[kind][name]["obj"]
    except KeyError:
        raise RegistryError(
            f"unknown {kind} {name!r}; available: {available(kind)}"
        ) from None
    return found

available

available(kind: str) -> list[str]
Source code in src/traincraft/core/registry.py
def available(kind: str) -> list[str]:
    """Return the names registered under ``kind``, sorted (empty for an unknown kind)."""
    names = list(_REGISTRY.get(kind, {}))
    names.sort()
    return names

Dataset

Dataset

Source code in src/traincraft/datasets/dataset.py
class Dataset:
    """Ordered, hash-deduplicated collection of structures tied to an ``.extxyz`` path."""

    def __init__(self, path: str | Path):
        target = Path(path)
        # Datasets always live in .extxyz files; any other suffix is replaced.
        self.path = target if target.suffix == ".extxyz" else target.with_suffix(".extxyz")
        self._frames: list[Structure] = []
        self._hashes: set[str] = set()

    def append(self, structures: list[Structure]) -> int:
        """Add new frames, skipping exact duplicates. Returns count added."""
        before = len(self._frames)
        for candidate in structures:
            key = candidate.hash
            if key not in self._hashes:
                self._hashes.add(key)
                self._frames.append(candidate)
        return len(self._frames) - before

    def filter(self, origin: str | None = None) -> list[Structure]:
        """Return frames whose ``provenance.origin`` equals ``origin`` (all frames when None)."""
        if origin is not None:
            return [frame for frame in self._frames if frame.provenance.origin == origin]
        return self._frames[:]

    @property
    def frames(self) -> list[Structure]:
        """A new list holding the stored frames; mutating it leaves the dataset intact."""
        return self._frames[:]

    def __len__(self) -> int:
        return len(self._frames)

    def write(self) -> Path:
        """Write every stored frame to :attr:`path` and return the path written."""
        return write_frames(self.path, self._frames)

    @classmethod
    def load(cls, path: str | Path) -> Dataset:
        """Read the (``.extxyz``-normalised) ``path`` into a new, deduplicated dataset."""
        dataset = cls(path)
        dataset.append(read_frames(dataset.path))
        return dataset

append

append(structures: list[Structure]) -> int

Add new frames, skipping exact duplicates. Returns count added.

Source code in src/traincraft/datasets/dataset.py
def append(self, structures: list[Structure]) -> int:
    """Add frames whose content hash is not yet present; return how many were added."""
    seen = self._hashes
    kept = self._frames
    start = len(kept)
    for candidate in structures:
        digest = candidate.hash
        if digest not in seen:
            seen.add(digest)
            kept.append(candidate)
    return len(kept) - start

filter

filter(origin: str | None = None) -> list[Structure]
Source code in src/traincraft/datasets/dataset.py
def filter(self, origin: str | None = None) -> list[Structure]:
    """Return frames whose ``provenance.origin`` matches ``origin``, or all frames when it is None."""
    stored = self._frames
    if origin is not None:
        return [frame for frame in stored if frame.provenance.origin == origin]
    return stored[:]

write

write() -> Path
Source code in src/traincraft/datasets/dataset.py
def write(self) -> Path:
    """Persist every stored frame to ``self.path`` and return the written path."""
    destination = self.path
    return write_frames(destination, self._frames)

write_frames

write_frames(path: str | Path, structures: list[Structure], append: bool = False) -> Path
Source code in src/traincraft/datasets/io.py
def write_frames(path: str | Path, structures: list[Structure], append: bool = False) -> Path:
    """Write ``structures`` to an extxyz file, creating parent directories as needed.

    With ``append`` true, frames are added to an existing file instead of
    overwriting it.  Returns the target path as a ``Path``.
    """
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    frames = list(map(_to_atoms, structures))
    write(str(out_path), frames, format="extxyz", append=append)
    return out_path

read_frames

read_frames(path: str | Path) -> list[Structure]
Source code in src/traincraft/datasets/io.py
def read_frames(path: str | Path) -> list[Structure]:
    """Read every frame of an extxyz file back into :class:`Structure` objects.

    - A ``tc_forces`` per-atom array becomes ``properties["forces"]``.
    - Other ``tc_*`` info keys, except those in ``_META_KEYS``, become
      properties with the ``tc_`` prefix stripped.
    - ``tc_provenance`` is decoded into :class:`Provenance`; a missing or empty
      value yields a default ``Provenance()``.  The value is accepted either as
      a JSON string or as an already-decoded dict.  ``Structure.to_ase``
      stores provenance as a dict, and ASE's extxyz reader may decode
      JSON-valued info entries itself; ``json.loads`` on a dict would raise
      ``TypeError``.
    """
    images = read(str(path), index=":", format="extxyz")
    if not isinstance(images, list):
        images = [images]
    out: list[Structure] = []
    for atoms in images:
        props: dict = {}
        if "tc_forces" in atoms.arrays:
            props["forces"] = np.asarray(atoms.arrays["tc_forces"])
        for key, value in atoms.info.items():
            if key.startswith("tc_") and key not in _META_KEYS:
                props[key.removeprefix("tc_")] = value
        raw = atoms.info.get("tc_provenance")
        if not raw:
            prov = Provenance()
        elif isinstance(raw, dict):
            prov = Provenance.from_dict(raw)
        else:
            prov = Provenance.from_dict(json.loads(raw))
        out.append(Structure(atoms=atoms, properties=props, provenance=prov))
    return out