diff --git a/README.md b/README.md index a052f83e559d15c353522f90f471e3e25c274bc1..5411221311771d8983eb2f17c1743bba037ff839 100644 --- a/README.md +++ b/README.md @@ -1 +1,7 @@ # Tensor Passing Network +## Install + +```shell +$ pip install git+ssh://git@git.nju.edu.cn/bigd4/tpn.git +``` +## Lammps diff --git a/interface/ase/ase.py b/interface/ase/ase.py index 92f2599fae15734551716f9ea3bd84c5900f4138..847b49fffbec132051e142becfea6c5b4e1cb35e 100644 --- a/interface/ase/ase.py +++ b/interface/ase/ase.py @@ -11,18 +11,19 @@ class MiaoCalculator(Calculator): "energies", "forces", "stress", + "dipole", + "polarizability", ] - def __init__(self, - cutoff : float, + def __init__(self, model_file : str="model.pt", device : str="cpu", **kwargs, ) -> None: Calculator.__init__(self, **kwargs) - self.cutoff = cutoff self.device = device - self.model = torch.load(model_file).to(device).double() + self.model = torch.load(model_file, map_location=device).double() + self.cutoff = float(self.model.cutoff.detach().cpu().numpy()) def calculate( self, @@ -35,22 +36,17 @@ class MiaoCalculator(Calculator): Calculator.calculate(self, atoms, properties, system_changes) idx_i, idx_j, offsets = neighbor_list("ijS", atoms, self.cutoff, self_interaction=False) - bonds = np.array([idx_i, idx_j]) - offsets = np.array(offsets) @ atoms.get_cell() - atomic_number = torch.tensor(atoms.numbers, dtype=torch.long, device=self.device) - edge_index = torch.tensor(bonds, dtype=torch.long, device=self.device) - offset = torch.tensor(offsets, dtype=torch.double, device=self.device) - coordinate = torch.tensor(atoms.positions, dtype=torch.double, device=self.device) - n_atoms = torch.tensor(len(atoms), dtype=torch.double, device=self.device) - batch = torch.zeros(len(atoms), dtype=torch.long, device=self.device) - + offset = np.array(offsets) @ atoms.get_cell() + data = { - "atomic_number" : atomic_number, - "edge_index" : edge_index, - "offset" : offset, - "coordinate" : coordinate, - "n_atoms" : n_atoms, - "batch" : batch, + "atomic_number": torch.tensor(atoms.numbers, dtype=torch.long, device=self.device), + "idx_i" : torch.tensor(idx_i, dtype=torch.long, device=self.device), + "idx_j" : torch.tensor(idx_j, dtype=torch.long, device=self.device), + "coordinate" : torch.tensor(atoms.positions, dtype=torch.double, device=self.device), + "n_atoms" : torch.tensor([len(atoms)], dtype=torch.long, device=self.device), + "offset" : torch.tensor(offset, dtype=torch.double, device=self.device), + "scaling" : torch.eye(3, dtype=torch.double, device=self.device).view(1, 3, 3), + "batch" : torch.zeros(len(atoms), dtype=torch.long, device=self.device), } self.model(data, properties, create_graph=False) @@ -61,12 +57,13 @@ class MiaoCalculator(Calculator): if "forces" in properties: self.results["forces"] = data["forces_p"].detach().cpu().numpy() if "stress" in properties: - raise Exception("ni shi gu yi zhao cha shi bu shi?") - # virial = np.sum( - # np.array(self.calc.getVirials()).reshape(9, -1), axis=1) - # if sum(atoms.get_pbc()) > 0: - # stress = -0.5 * (virial.copy() + - # virial.copy().T) / atoms.get_volume() - # self.results['stress'] = stress.flat[[0, 4, 8, 5, 2, 1]] - # else: - # raise PropertyNotImplementedError + virial = data["virial_p"].detach().cpu().numpy().reshape(-1) + if sum(atoms.get_pbc()) > 0: + stress = 0.5 * (virial.copy() + virial.copy().T) / atoms.get_volume() + self.results['stress'] = stress.flat[[0, 4, 8, 5, 2, 1]] + else: + raise PropertyNotImplementedError + if "dipole" in properties: + self.results["dipole"] = 
data["dipole_p"].detach().cpu().numpy() + if "polarizability" in properties: + self.results["polarizability"] = data["polarizability_p"].detach().cpu().numpy() diff --git a/interface/lammps/cmake/CMakeLists.txt b/interface/lammps/cmake/CMakeLists.txt index 61abcf6f2ce4e77489a9f6a56fbd094dcd853760..0b2f139b0ddafed7b59031b226b2127737c94089 100644 --- a/interface/lammps/cmake/CMakeLists.txt +++ b/interface/lammps/cmake/CMakeLists.txt @@ -50,6 +50,7 @@ endif() include(LAMMPSUtils) get_lammps_version(${LAMMPS_SOURCE_DIR}/version.h LAMMPS_VERSION_NUMBER) +add_compile_definitions(LAMMPS_VERSION_NUMBER=${LAMMPS_VERSION_NUMBER}) include(PreventInSourceBuilds) diff --git a/interface/lammps/src/.pair_miao.cpp.swp b/interface/lammps/src/.pair_miao.cpp.swp deleted file mode 100644 index 7c9fda6926489cbfa8b95ae2296e8a11da72299b..0000000000000000000000000000000000000000 Binary files a/interface/lammps/src/.pair_miao.cpp.swp and /dev/null differ diff --git a/interface/lammps/src/pair_miao.cpp b/interface/lammps/src/pair_miao.cpp index 7282572ab50badd59e04d27d408754e366dde920..f66d0938ec0a1b5a41ef4b604ebde2d58d81ef45 100644 --- a/interface/lammps/src/pair_miao.cpp +++ b/interface/lammps/src/pair_miao.cpp @@ -165,8 +165,9 @@ void PairMIAO::compute(int eflag, int vflag) int numneigh_atom = accumulate(numneigh, numneigh + inum , 0); std::vector<double> cart(nall * 3); - std::vector<long> atom_index(numneigh_atom * 2); - std::vector<long> ghost_neigh(numneigh_atom); + std::vector<long> idx_i(numneigh_atom); // atom i + std::vector<long> idx_j(numneigh_atom); // atom j + std::vector<long> ghost_neigh(numneigh_atom); // ghost atom j for calculate distance std::vector<long> local_species(inum); std::vector<long> batch(nall); double dx, dy, dz, d2; @@ -192,8 +193,8 @@ void PairMIAO::compute(int eflag, int vflag) d2 = dx * dx + dy * dy + dz * dz; if (d2 < cutoffsq) { - atom_index[totneigh * 2] = i; - atom_index[totneigh * 2 + 1] = atom->map(tag[j]); + idx_i[totneigh] = i; + idx_j[totneigh] = atom->map(tag[j]); ghost_neigh[totneigh] = j; ++totneigh; } @@ -233,17 +234,19 @@ void PairMIAO::compute(int eflag, int vflag) species_file.close(); */ auto cart_ = torch::from_blob(cart.data(), {nall, 3}, option1).to(device, true).to(tensor_type); - auto atom_index_ = torch::from_blob(atom_index.data(), {totneigh, 2}, option2).transpose(1, 0).to(device, true); + auto idx_i_ = torch::from_blob(idx_i.data(), {totneigh}, option2).to(device, true); + auto idx_j_ = torch::from_blob(idx_j.data(), {totneigh}, option2).to(device, true); auto ghost_neigh_ = torch::from_blob(ghost_neigh.data(), {totneigh}, option2).to(device, true); auto local_species_ = torch::from_blob(local_species.data(), {inum}, option2).to(device, true); auto batch_ = torch::from_blob(batch.data(), {inum}, option2).to(device, true); batch_data.insert("coordinate", cart_); batch_data.insert("atomic_number", local_species_); - batch_data.insert("edge_index", atom_index_); + batch_data.insert("idx_i", idx_i_); + batch_data.insert("idx_j", idx_j_); batch_data.insert("batch", batch_); batch_data.insert("ghost_neigh", ghost_neigh_); - torch::IValue output = module.forward({batch_data, properties}); + torch::IValue output = module.forward({batch_data, properties, false}); results = c10::impl::toTypedDict<std::string, torch::Tensor>(output.toGenericDict()); auto forces_tensor = results.at("forces_p").to(torch::kDouble).cpu().reshape({-1}); diff --git a/setup.py b/setup.py index b8f27133cb7f3f95df124c0f94e8849db5d97bc3..aca606ac03386191d87c3b63ef1a6d930d344584 100644 --- 
a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup( "ase", "pyyaml", "torch", - "torch-geometric", + "lightning", ], license="MIT", description="MiaoNet: Moment tensor InterAggregate Operation Net", diff --git a/tensornet/data/__init__.py b/tensornet/data/__init__.py index b26dfd0230b04c60d1319e9a939ddc17d16d8c82..a4faf1c54f074b114b318e062a999e305e6e3528 100644 --- a/tensornet/data/__init__.py +++ b/tensornet/data/__init__.py @@ -1,4 +1,5 @@ from .base import * -from .md17 import * -from .asedata import * -from .ptdata import * +# from .md17 import * +from .asedata import ASEData, ASEDBData +# from .ptdata import * +from .data_interface import LitAtomsDataset \ No newline at end of file diff --git a/tensornet/data/asedata.py b/tensornet/data/asedata.py index e133e2b22e5d3e557c6bb443538a4afa34f3d28d..591dd60823a262b22624997cbfe0c81239903df8 100644 --- a/tensornet/data/asedata.py +++ b/tensornet/data/asedata.py @@ -1,44 +1,70 @@ -from .base import AtomsData -from ..utils import progress_bar +from .base import AtomsDataset from typing import List, Optional from ase import Atoms from ase.io import read -import torch -import os +from ase.db import connect -class ASEData(AtomsData): +class ASEData(AtomsDataset): + def __init__(self, - frames : Optional[List[Atoms]]=None, - root : Optional[str]=None, - name : Optional[str]=None, - format : Optional[str]=None, - cutoff : float=4.0, - device : str="cpu", + frames : Optional[List[Atoms]]=None, + indices : Optional[List[int]]=None, + properties : Optional[List[str]]=['energy', 'forces'], + cutoff : float=4.0, ) -> None: - self.cutoff = cutoff - self.device = device - if frames is None: - frames = read(os.path.join(root, name), format=format, index=':') + super().__init__(indices=indices, cutoff=cutoff) self.frames = frames - self.name = name or "processed" - if root is None: - root = os.getcwd() - super().__init__(root) - self.data, self.slices = torch.load(self.processed_paths[0], map_location=device) - - @property - def processed_file_names(self) -> str: - return f"{self.name}.pt" - - def process(self): - n_data = len(self.frames) - data_list = [] - for i in range(n_data): - progress_bar(i, n_data) - data = self.atoms_to_graph(self.frames[i], self.cutoff, self.device) - data_list.append(data) - torch.save(self.collate(data_list), self.processed_paths[0]) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({len(self)}, name='{self.name}')" + self.properties = properties + + def __len__(self): + if self.indices is None: + return len(self.frames) + else: + return len(self.indices) + + def __getitem__(self, idx): + if self.indices is not None: + idx = self.indices[idx] + data = self.atoms_to_data(self.frames[idx], + properties=self.properties, + cutoff=self.cutoff) + return data + + def extend(self, frames): + self.frames.extend(frames) + + +class ASEDBData(AtomsDataset): + + def __init__(self, + datapath : Optional[List[Atoms]]=None, + indices : Optional[List[int]]=None, + properties : Optional[List[str]]=['energy', 'forces'], + cutoff : float=4.0, + ) -> None: + super().__init__(indices=indices, cutoff=cutoff) + self.datapath = datapath + self.conn = connect(self.datapath, use_lock_file=False) + self.properties = properties + + def __len__(self): + if self.indices is None: + return self.conn.count() + else: + return len(self.indices) + + def __getitem__(self, idx): + if self.indices is not None: + idx = int(self.indices[idx]) + row = self.conn.get(idx + 1) + atoms = Atoms(numbers=row['numbers'], + cell=row['cell'], + 
positions=row['positions'], + pbc=row['pbc'], + info=row.data + ) + data = self.atoms_to_data(atoms, + properties=self.properties, + cutoff=self.cutoff) + return data diff --git a/tensornet/data/base.py b/tensornet/data/base.py index 82e9a9a83105d0c942eafb4e2908557483dc3b51..6682c1eb232e7ca160fe91aae8628a567d29ea4e 100644 --- a/tensornet/data/base.py +++ b/tensornet/data/base.py @@ -1,87 +1,91 @@ import torch -import os +import copy +import abc import numpy as np from tensornet.utils import EnvPara -from torch_geometric.data import Data, InMemoryDataset +from torch.utils.data import Dataset from ase.neighborlist import neighbor_list +from typing import Optional, List +# TODO: offset and scaling for different condition +class AtomsDataset(Dataset, abc.ABC): -class AtomsData(InMemoryDataset): @staticmethod - def atoms_to_graph(atoms, cutoff, device='cpu'): + def atoms_to_data(atoms, cutoff, properties=['energy', 'forces']): dim = len(atoms.get_cell()) idx_i, idx_j, offset = neighbor_list("ijS", atoms, cutoff, self_interaction=False) - bonds = np.array([idx_i, idx_j]) offset = np.array(offset) @ atoms.get_cell() - index = torch.arange(len(atoms), dtype=torch.long, device=device) - atomic_number = torch.tensor(atoms.numbers, dtype=torch.long, device=device) - edge_index = torch.tensor(bonds, dtype=torch.long, device=device) - offset = torch.tensor(offset, dtype=EnvPara.FLOAT_PRECISION, device=device) - coordinate = torch.tensor(atoms.positions, dtype=EnvPara.FLOAT_PRECISION, device=device) - scaling = torch.eye(dim, dtype=EnvPara.FLOAT_PRECISION, device=device).view(1, dim, dim) - n_atoms = torch.tensor(len(atoms), dtype=EnvPara.FLOAT_PRECISION, device=device) - graph = Data(x=index, - atomic_number=atomic_number, - edge_index=edge_index, - offset=offset, - coordinate=coordinate, - n_atoms=n_atoms, - scaling=scaling, - ) + + data = { + "atomic_number": torch.tensor(atoms.numbers, dtype=torch.long), + "idx_i": torch.tensor(idx_i, dtype=torch.long), + "idx_j": torch.tensor(idx_j, dtype=torch.long), + "coordinate": torch.tensor(atoms.positions, dtype=EnvPara.FLOAT_PRECISION), + "n_atoms": torch.tensor([len(atoms)], dtype=torch.long), + "offset": torch.tensor(offset, dtype=EnvPara.FLOAT_PRECISION), + "scaling": torch.eye(dim, dtype=EnvPara.FLOAT_PRECISION).view(1, dim, dim) + } + padding_shape = { 'site_energy' : (len(atoms)), 'energy' : (1), 'forces' : (len(atoms), dim), 'virial' : (1, dim, dim), 'dipole' : (1, dim), + 'polarizability': (1, dim, dim) } - for key in ['site_energy', 'energy', 'forces', 'virial', 'dipole']: + for key in properties: if key in atoms.info: - graph[key + '_t'] = torch.tensor(atoms.info[key], dtype=EnvPara.FLOAT_PRECISION, device=device).reshape(padding_shape[key]) - graph['has_' + key] = True + data[key + '_t'] = torch.tensor(atoms.info[key], dtype=EnvPara.FLOAT_PRECISION).reshape(padding_shape[key]) + data[key + '_weight'] = torch.ones(padding_shape[key], dtype=EnvPara.FLOAT_PRECISION) else: - graph[key + '_t'] = torch.zeros(padding_shape[key], dtype=EnvPara.FLOAT_PRECISION, device=device) - graph['has_' + key] = False - return graph + data[key + '_t'] = torch.zeros(padding_shape[key], dtype=EnvPara.FLOAT_PRECISION) + data[key + '_weight'] = torch.zeros(padding_shape[key], dtype=EnvPara.FLOAT_PRECISION) + return data + + def __init__(self, + indices: Optional[List[int]]=None, + cutoff : float=4.0, + ) -> None: + self.indices = indices + self.cutoff = cutoff - @property - def processed_dir(self) -> str: - return os.path.join(self.root, 
f"processed_{self.cutoff:.2f}".replace(".", "_")) + def __len__(self): + if self.indices: + return len(self.indices) - @property - def per_energy_mean(self): - per_energy = self.data["energy_t"] / self.data["n_atoms"] - return torch.mean(per_energy) + @abc.abstractmethod + def __getitem__(self, idx: int): + pass - @property - def per_energy_std(self): - per_energy = self.data["energy_t"] / self.data["n_atoms"] - return torch.std(per_energy) + def subset(self, indices: List[int]): + ds = copy.copy(self) + if ds.indices: + ds.indices = [ds.indices[i] for i in indices] + else: + ds.indices = indices + return ds - @property - def forces_std(self): - return torch.std(self.data["forces_t"]) +def atoms_collate_fn(batch): - @property - def n_neighbor_mean(self): - n_neighbor = self.data['edge_index'].shape[1] / len(self.data['x']) - return n_neighbor + elem = batch[0] + coll_batch = {} - @property - def all_elements(self): - return torch.unique(self.data['atomic_number']) + for key in elem: + if key not in ["idx_i", "idx_j"]: + coll_batch[key] = torch.cat([d[key] for d in batch], dim=0) - def load(self, name: str, device: str="cpu") -> None: - self.data, self.slices = torch.load(name, map_location=device) + # idx_i and idx_j should to be converted like + # [0, 0, 1, 1] + [0, 0, 1, 2] -> [0, 0, 1, 1, 2, 2, 3, 4] + for key in ["idx_i", "idx_j"]: + coll_batch[key] = torch.cat( + [batch[i][key] + torch.sum(coll_batch["n_atoms"][:i]) for i in range(len(batch))], dim=0 + ) - def load_split(self, train_split: str, test_split: str): - train_idx = np.loadtxt(train_split, dtype=int) - test_idx = np.loadtxt(test_split, dtype=int) - return self.copy(train_idx), self.copy(test_idx) + coll_batch["batch"] = torch.repeat_interleave( + torch.arange(len(batch)), + repeats=coll_batch["n_atoms"].to(torch.long), + dim=0 + ) - def random_split(self, train_num: int, test_num: int): - assert train_num + test_num <= len(self) - idx = np.random.choice(len(self), train_num + test_num, replace=False) - train_idx = idx[:train_num] - test_idx = idx[train_num:] - return self.copy(train_idx), self.copy(test_idx) + return coll_batch diff --git a/tensornet/data/data_interface.py b/tensornet/data/data_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a28b1c142b9037ba4a408d17a16a20e97a0b63 --- /dev/null +++ b/tensornet/data/data_interface.py @@ -0,0 +1,182 @@ +import logging +import os +import numpy as np +from copy import copy +from typing import Optional, List, Dict, Tuple, Union +from . 
import * +from ase.io import read +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader + + +log = logging.getLogger(__name__) + + +class LitAtomsDataset(pl.LightningDataModule): + + def __init__(self, p_dict): + super().__init__() + self.p_dict = p_dict + self._train_dataloader = None + self._test_dataloader = None + self.stats = {} + + def setup(self, stage: Optional[str] = None): + self.dataset = self.get_dataset() + self.trainset, self.testset = self.split_dataset() + # self.calculate_stats() + + def get_dataset(self): + data_dict = self.p_dict['Data'] + # if data_dict['type'] == 'rmd17': + # dataset = RevisedMD17(data_dict['path'], + # data_dict['name'], + # cutoff=p_dict['cutoff'], + # device=p_dict['device']) + if data_dict['type'] == 'ase': + if 'name' in data_dict: + frames = read(os.path.join(data_dict['path'], data_dict['name']), index=':') + else: + frames = [] + dataset = ASEData(frames=frames, + properties=self.p_dict['Train']['targetProp'], + cutoff=self.p_dict['cutoff']) + elif data_dict['type'] == 'ase-db': + dataset = ASEDBData(datapath=os.path.join(data_dict['path'], data_dict['name']), + properties=self.p_dict['Train']['targetProp'], + cutoff=self.p_dict['cutoff']) + return dataset + + def split_dataset(self): + data_dict = self.p_dict['Data'] + if ("trainSplit" in data_dict) and ("testSplit" in data_dict): + log.info("Load split from {} and {}".format(data_dict["trainSplit"], data_dict["testSplit"])) + train_idx = np.loadtxt(data_dict["trainSplit"], dtype=int) + test_idx = np.loadtxt(data_dict["testSplit"], dtype=int) + return self.dataset.subset(train_idx), self.dataset.subset(test_idx) + + if ("trainNum" in data_dict) and (("testNum" in data_dict)): + log.info("Random split, train num: {}, test num: {}".format(data_dict["trainNum"], data_dict["testNum"])) + assert data_dict['trainNum'] + data_dict['testNum'] <= len(self.dataset) + idx = np.random.choice(len(self.dataset), data_dict['trainNum'] + data_dict['testNum'], replace=False) + train_idx = idx[:data_dict['trainNum']] + test_idx = idx[data_dict['trainNum']:] + return self.dataset.subset(train_idx), self.dataset.subset(test_idx) + + if ("trainSet" in data_dict) and ("testSet" in data_dict): + assert data_dict['type'] == 'ase', "trainSet and testSet must can be read by ase!" 
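As a side note on the batching convention introduced in `atoms_collate_fn` earlier in this patch (edge indices shifted by the cumulative atom count, plus a per-atom `batch` index), the docstring example there can be reproduced with a tiny hand-built batch; the structures and numbers below are hypothetical:

```python
import torch
from tensornet.data import atoms_collate_fn

# two fake structures: 2 atoms with 4 edges, and 3 atoms with 4 edges
a = {"n_atoms": torch.tensor([2]), "atomic_number": torch.tensor([1, 1]),
     "idx_i": torch.tensor([0, 0, 1, 1]), "idx_j": torch.tensor([1, 1, 0, 0])}
b = {"n_atoms": torch.tensor([3]), "atomic_number": torch.tensor([8, 1, 1]),
     "idx_i": torch.tensor([0, 0, 1, 2]), "idx_j": torch.tensor([1, 2, 0, 0])}

batch = atoms_collate_fn([a, b])
print(batch["idx_i"])   # tensor([0, 0, 1, 1, 2, 2, 3, 4])  -- second structure shifted by 2
print(batch["batch"])   # tensor([0, 0, 1, 1, 1])
```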
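For context on the `split_dataset` branches above (the `trainSet`/`testSet` branch continues just below), the three mutually exclusive ways of requesting a split all live in the `Data` block of the input dictionary. A hedged sketch of each variant, using the keys the code reads; the file names and counts are placeholders:

```python
# pre-computed index files loaded with np.loadtxt
data_cfg_from_index_files = {
    "type": "ase", "path": ".", "name": "data.traj",
    "trainSplit": "train_idx.txt", "testSplit": "test_idx.txt",
}
# random split drawn with np.random.choice
data_cfg_random = {
    "type": "ase", "path": ".", "name": "data.traj",
    "trainNum": 900, "testNum": 100,
}
# two separate trajectories appended to the dataset ('ase' type only)
data_cfg_two_files = {
    "type": "ase", "path": ".",
    "trainSet": "train.traj", "testSet": "test.traj",
}
```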
+ trainset = read(os.path.join(data_dict['path'], data_dict['trainSet']), index=':') + self.dataset.extend(trainset) + testset = read(os.path.join(data_dict['path'], data_dict['testSet']), index=':') + self.dataset.extend(testset) + train_idx = [i for i in range(len(trainset))] + test_idx = [i for i in range(len(trainset), len(trainset) + len(testset))] + return self.dataset.subset(train_idx), self.dataset.subset(test_idx) + + raise Exception("No splitting!") + + def train_dataloader(self): + if self._train_dataloader is None: + self._train_dataloader = DataLoader(self.trainset, + batch_size=self.p_dict["Data"]["trainBatch"], + shuffle=True, + collate_fn=atoms_collate_fn, + num_workers=self.p_dict["Data"]["numWorkers"], + pin_memory=self.p_dict["Data"]["pinMemory"]) + log.debug(f'numWorkers: {self.p_dict["Data"]["numWorkers"]}') + return self._train_dataloader + + def val_dataloader(self): + return self.test_dataloader() + + def test_dataloader(self): + if self._test_dataloader is None: + self._test_dataloader = DataLoader(self.testset, + batch_size=self.p_dict["Data"]["testBatch"], + shuffle=False, + collate_fn=atoms_collate_fn, + num_workers=self.p_dict["Data"]["numWorkers"], + pin_memory=self.p_dict["Data"]["pinMemory"]) + return self._test_dataloader + + def calculate_stats(self): + # To be noticed, we assume that the average force is always 0, + # so the final result may differ from the actual variance + + N_batch = 0 + N_forces = 0 + per_energy_mean = 0. + n_neighbor_mean = 0. + per_energy_std = 0. + forces_std = 0. + all_elements = torch.tensor([], dtype=torch.long)#, device=self.p_dict['device']) + + for batch_data in self.train_dataloader(): + # all elemetns + all_elements = torch.unique(torch.cat((all_elements, batch_data['atomic_number']))) + + # per_energy_mean + batch_size = batch_data["energy_t"].numel() + pe = batch_data["energy_t"] / batch_data["n_atoms"] + pe_mean = torch.mean(pe) + delta_pe_mean = pe_mean - per_energy_mean + per_energy_mean += delta_pe_mean * batch_size / (N_batch + batch_size) + + # per_energy_std + pe_m2 = torch.sum((pe - pe_mean) ** 2) + pe_corr = batch_size * N_batch / (N_batch + batch_size) + per_energy_std += pe_m2 + delta_pe_mean ** 2 * pe_corr + + # n_neighbor_mean + nn_mean = batch_data["idx_i"].shape[0] / torch.sum(batch_data["n_atoms"]) + delta_nn_mean = nn_mean - n_neighbor_mean + n_neighbor_mean += delta_nn_mean * batch_size / (N_batch + batch_size) + N_batch += batch_size + + # forces_std + if 'forces_t' in batch_data: + forces_size = batch_data["forces_t"].numel() + forces_m2 = torch.sum(batch_data["forces_t"] ** 2) + forces_std += forces_m2 + N_forces += forces_size + + per_energy_std = torch.sqrt(per_energy_std / N_batch) + if N_forces > 0: + forces_std = torch.sqrt(forces_std / N_forces) + + self.stats["per_energy_mean"] = per_energy_mean + self.stats["per_energy_std"] = per_energy_std + self.stats["n_neighbor_mean"] = n_neighbor_mean + self.stats["forces_std"] = forces_std + self.stats["all_elements"] = all_elements + + @property + def per_energy_mean(self): + if "per_energy_mean" not in self.stats: + self.calculate_stats() + return self.stats["per_energy_mean"] + + @property + def per_energy_std(self): + if "per_energy_std" not in self.stats: + self.calculate_stats() + return self.stats["per_energy_std"] + + @property + def forces_std(self): + if "forces_std" not in self.stats: + self.calculate_stats() + return self.stats["forces_std"] + + @property + def n_neighbor_mean(self): + if "forces_std" not in self.stats: + 
self.calculate_stats() + return self.stats["n_neighbor_mean"] + + @property + def all_elements(self): + if "all_elements" not in self.stats: + self.calculate_stats() + return self.stats["all_elements"] diff --git a/tensornet/entrypoints/eval.py b/tensornet/entrypoints/eval.py index 9c66d98758ef567ae067e73f1dfdbea8273b1401..5ce7b04ed1c18013a766a2b13f52b9c8562c72e1 100644 --- a/tensornet/entrypoints/eval.py +++ b/tensornet/entrypoints/eval.py @@ -1,8 +1,8 @@ -from ase.io import read +from ase.io import read import numpy as np import torch -from torch_geometric.loader import DataLoader -from tensornet.data import ASEData, PtData +from torch.utils.data import DataLoader +from tensornet.data import ASEData, ASEDBData, atoms_collate_fn def eval(model, data_loader, properties): @@ -23,18 +23,33 @@ def eval(model, data_loader, properties): return None -def main(*args, cutoff=None, model='model.pt', device='cpu', dataset='data.traj', format=None, - properties=["energy", "forces"], batchsize=32, **kwargs): - if '.pt' in dataset: - dataset = PtData(dataset, device=device) +def main(*args, modelfile='model.pt', indices=None, device='cpu', datafile='data.traj', + format=None, properties=["energy", "forces"], batchsize=32, num_workers=4, pin_memory=True, + **kwargs): + model = torch.load(modelfile, map_location=device) + cutoff = float(model.cutoff.detach().cpu().numpy()) + if indices is not None: + indices = np.loadtxt(indices, dtype=int) + + if '.db' in datafile: + dataset = ASEDBData(datapath=datafile, + indices=indices, + properties=properties, + cutoff=cutoff) else: - assert cutoff is not None, "Must have cutoff!!" - frames = read(dataset, index=':', format=format) - dataset = ASEData(frames, name="eval_process", cutoff=cutoff, device=device) - data_loader = DataLoader(dataset, batch_size=batchsize, shuffle=False) - model = torch.load(model) - eval(model, data_loader, properties) + frames = read(datafile, index=':', format=format) + dataset = ASEData(frames=frames, + indices=indices, + cutoff=cutoff, + properties=properties) + data_loader = DataLoader(dataset, + batch_size=batchsize, + shuffle=False, + collate_fn=atoms_collate_fn, + num_workers=num_workers, + pin_memory=pin_memory) + eval(model, data_loader, properties) if __name__ == "__main__": main() diff --git a/tensornet/entrypoints/main.py b/tensornet/entrypoints/main.py index 41887de2ad010a6c3941d1e531a80416d77dadc3..c75c3350f355af1f7d7c93775bd7ca8be4bf24ff 100644 --- a/tensornet/entrypoints/main.py +++ b/tensornet/entrypoints/main.py @@ -62,22 +62,22 @@ def parse_args(): # eval parser_eval = subparsers.add_parser( "eval", - help="train", + help="eval", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser_eval.add_argument( "-m", - "--model", - type=str, + "--modelfile", + type=str, default="model.pt", help="model" ) parser_eval.add_argument( - "-c", - "--cutoff", - type=float, + "-i", + "--indices", + type=str, default=None, - help="cutoff" + help="indices" ) parser_eval.add_argument( "--device", @@ -87,22 +87,22 @@ def parse_args(): ) parser_eval.add_argument( "-d", - "--dataset", - type=str, + "--datafile", + type=str, default="data.traj", help="dataset" ) parser_eval.add_argument( "-f", "--format", - type=str, + type=str, default=None, help="format" ) parser_eval.add_argument( "-b", "--batchsize", - type=int, + type=int, default=32, help="batchsize" ) @@ -114,6 +114,17 @@ def parse_args(): default=["energy", "forces"], help="target properties" ) + parser_eval.add_argument( + "--num_workers", + type=int, + default=4, + 
help="num workder" + ) + parser_eval.add_argument( + "--pin_memory", + action="store_true", + help="pin memory" + ) # clean parser_clean = subparsers.add_parser( "clean", diff --git a/tensornet/entrypoints/train.py b/tensornet/entrypoints/train.py index 356b8622587e2cc9655d6a042cb238394c52ed5c..d8f43bd608f1b9988194a0a83c186aceb623dff0 100644 --- a/tensornet/entrypoints/train.py +++ b/tensornet/entrypoints/train.py @@ -1,23 +1,42 @@ import logging, time, yaml, os import numpy as np -from torch_geometric.loader import DataLoader import torch from torch import nn +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor import torch.nn.functional as F from torch.optim.swa_utils import AveragedModel from ase.data import atomic_numbers from ..utils import setup_seed -from ..loss import Loss, MissingValueLoss, ForceScaledLoss -from ..model import MiaoNet +from ..model import MiaoNet, LitAtomicModule from ..layer.cutoff import * from ..layer.embedding import AtomicEmbedding from ..layer.radial import * -from ..data import * +from ..data import LitAtomsDataset +torch.set_float32_matmul_precision("high") log = logging.getLogger(__name__) +class SaveModelCheckpoint(ModelCheckpoint): + """ + Saves model.pt for eval + """ + def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) + modelpath = filepath[:-4] + "pt" + if trainer.is_global_zero: + torch.save(trainer.lightning_module.model, modelpath) + + def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + super()._remove_checkpoint(trainer, filepath) + modelpath = filepath[:-4] + "pt" + if trainer.is_global_zero: + if os.path.exists(modelpath): + os.remove(modelpath) + + def update_dict(d1, d2): for key in d2: if key in d1 and isinstance(d1[key], dict): @@ -27,6 +46,41 @@ def update_dict(d1, d2): return d1 +def get_stats(data_dict, dataset): + + if type(data_dict["mean"]) is float: + mean = data_dict["mean"] + else: + try: + mean = dataset.per_energy_mean.detach().cpu().numpy() + except: + mean = 0. + + if data_dict["std"] == "force": + std = dataset.forces_std.detach().cpu().numpy() + elif data_dict["std"] == "energy": + std = dataset.per_energy_std.detach().cpu().numpy() + else: + assert type(data_dict["std"]) is float, "std must be 'force', 'energy' or a float!" 
+ std = data_dict["std"] + + if type(data_dict["nNeighbor"]) is float: + n_neighbor = data_dict["nNeighbor"] + else: + n_neighbor = dataset.n_neighbor_mean.detach().cpu().numpy() + + if isinstance(data_dict["elements"], list): + elements = data_dict["elements"] + else: + elements = list(dataset.all_elements.detach().cpu().numpy()) + + log.info(f"mean : {mean}") + log.info(f"std : {std}") + log.info(f"n_neighbor : {n_neighbor}") + log.info(f"all_elements : {elements}") + return mean, std, n_neighbor, elements + + def get_cutoff(p_dict): cutoff = p_dict['cutoff'] cut_dict = p_dict['Model']['CutoffLayer'] @@ -75,6 +129,9 @@ def get_model(p_dict, elements, mean, std, n_neighbor): target_way["site_energy"] = 0 if "dipole" in target: target_way["dipole"] = 1 + if "polarizability" in target: + target_way["polar_diagonal"] = 0 + target_way["polar_off_diagonal"] = 2 if "direct_forces" in target: assert "forces" not in target_way, "Cannot learn forces and direct_forces at the same time" target_way["direct_forces"] = 1 @@ -96,218 +153,6 @@ def get_model(p_dict, elements, mean, std, n_neighbor): return model -def get_dataset(p_dict): - data_dict = p_dict['Data'] - if data_dict['type'] == 'rmd17': - dataset = RevisedMD17(data_dict['path'], - data_dict['name'], - cutoff=p_dict['cutoff'], - device=p_dict['device']) - if data_dict['type'] == 'ase': - if 'name' in data_dict: - dataset = ASEData(root=data_dict['path'], - name=data_dict['name'], - cutoff=p_dict['cutoff'], - device=p_dict['device']) - else: - dataset = None - return dataset - - -def split_dataset(dataset, p_dict): - data_dict = p_dict['Data'] - if ("trainSplit" in data_dict) and ("testSplit" in data_dict): - log.info("Load split from {} and {}".format(data_dict["trainSplit"], data_dict["testSplit"])) - return dataset.load_split(data_dict["trainSplit"], data_dict["testSplit"]) - if ("trainNum" in data_dict) and (("testNum" in data_dict)): - log.info("Random split, train num: {}, test num: {}".format(data_dict["trainNum"], data_dict["testNum"])) - return dataset.random_split(data_dict["trainNum"], data_dict["testNum"]) - if ("trainSet" in data_dict) and ("testSet" in data_dict): - assert data_dict['type'] == 'ase', "trainset must can be read by ase!" 
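The new `polar_diagonal` (way 0) and `polar_off_diagonal` (way 2) targets registered in `get_model` above are recombined into one polarizability tensor in `AtomicModule.forward` (the `base.py` hunk later in this diff). A rough sketch of that combination, assuming a per-structure scalar channel and a `[n_structures, 3, 3]` rank-2 channel:

```python
import torch

def assemble_polarizability(polar_diagonal: torch.Tensor,
                            polar_off_diagonal: torch.Tensor) -> torch.Tensor:
    # polar_diagonal: [n_structures] scalar channel (assumed shape)
    # polar_off_diagonal: [n_structures, 3, 3] anisotropic channel
    shape = polar_off_diagonal.shape
    eye = torch.eye(shape[1], device=polar_diagonal.device).expand(*shape)
    # isotropic part from the scalar readout plus the rank-2 part
    return polar_diagonal.unsqueeze(-1).unsqueeze(-1) * eye + polar_off_diagonal
```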
- trainset = ASEData(root=data_dict['path'], - name=data_dict['trainSet'], - cutoff=p_dict['cutoff'], - device=p_dict['device']) - testset = ASEData(root=data_dict['path'], - name=data_dict['testSet'], - cutoff=p_dict['cutoff'], - device=p_dict['device']) - return trainset, testset - raise Exception("No splitting!") - - -def get_loss_calculator(p_dict): - train_dict = p_dict['Train'] - target = train_dict['targetProp'] - weight = train_dict['weight'] - weights = {p: w for p, w in zip(target, weight)} - if "direct_forces" in weights: - weights["forces"] = weights.pop("direct_forces") # direct forces use the same key of forces - if train_dict['allowMissing']: - # TODO: rewrite loss function - if train_dict['forceScale'] > 0: - raise Exception("Now forceScale not support allowMissing!") - return MissingValueLoss(weights, loss_fn=F.mse_loss) - else: - if train_dict['forceScale'] > 0: - return ForceScaledLoss(weights, loss_fn=F.mse_loss, scaled=train_dict['forceScale']) - else: - return Loss(weights, loss_fn=F.mse_loss) - - -def eval(model, properties, loss_calculator, data_loader): - total = [] - prop_loss = {} - for i_batch, batch_data in enumerate(data_loader): - model(batch_data, properties, create_graph=False) - loss, loss_dict = loss_calculator.get_loss(batch_data, verbose=True) - total.append(loss.detach().cpu().numpy()) - for prop in loss_dict: - if prop not in prop_loss: - prop_loss[prop] = [] - prop_loss[prop].append(loss_dict[prop].detach().cpu().numpy()) - t1 = np.sqrt(np.mean(total)) - for prop in loss_dict: - prop_loss[prop] = np.sqrt(np.mean(prop_loss[prop])) - return t1, prop_loss - - -def train(model, loss_calculator, optimizer, lr_scheduler, ema, train_loader, test_loader, p_dict): - min_loss = 10000 - t2 = 10000 - t = time.time() - content = f"{'epoch':^10}|{'time':^10}|{'lr':^10}|{'total':^21}" - for prop in p_dict["Train"]['targetProp']: - content += f"|{prop:^21}" - log.info(content) - for epoch in range(p_dict["Train"]["epoch"]): - for i_batch, batch_data in enumerate(train_loader): - optimizer.zero_grad() - model(batch_data, p_dict["Train"]['targetProp']) - loss = loss_calculator.get_loss(batch_data) - loss.backward() - if p_dict["Train"]["maxGradNorm"] is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=p_dict["Train"]["maxGradNorm"]) - optimizer.step() - if ema is not None: - ema.update_parameters(model) - lr_scheduler.step(epoch=epoch, metrics=t2) - if epoch % p_dict["Train"]["logInterval"] == 0: - lr = optimizer.param_groups[0]["lr"] - t1, prop_loss1 = eval(model, p_dict["Train"]['targetProp'], loss_calculator, train_loader) - if p_dict["Train"]["evalTest"]: - if ema: - t2, prop_loss2 = eval(ema, p_dict["Train"]['targetProp'], loss_calculator, test_loader) - else: - t2, prop_loss2 = eval(model, p_dict["Train"]['targetProp'], loss_calculator, test_loader) - else: - t2, prop_loss2 = t1, prop_loss1 - content = f"{epoch:^10}|{time.time() - t:^10.2f}|{lr:^10.2e}|{t1:^10.4f}/{t2:^10.4f}" - t = time.time() - for prop in p_dict["Train"]['targetProp']: - prop = "forces" if prop == "direct_forces" else prop - content += f"|{prop_loss1[prop]:^10.4f}/{prop_loss2[prop]:^10.4f}" - log.info(content) - if t2 < min_loss: - min_loss = t2 - save_checkpoints(p_dict["outputDir"], "best", model, ema, optimizer, lr_scheduler) - if epoch > p_dict["Train"]["saveStart"] and epoch % p_dict["Train"]["saveInterval"] == 0: - save_checkpoints(p_dict["outputDir"], epoch, model, ema, optimizer, lr_scheduler) - - -def save_checkpoints(path, name, model, ema, optimizer, 
lr_scheduler): - checkpoint = { - "model": model.state_dict(), - "optimizer": optimizer.state_dict(), - "lr_scheduler": lr_scheduler.state_dict(), - } - if ema is not None: - checkpoint["ema"] = ema.state_dict() - torch.save(checkpoint, os.path.join(path, f"state_dict-{name}.pt")) - torch.save(model, os.path.join(path, f"model-{name}.pt")) - - -def get_optimizer(p_dict, model): - opt_dict = p_dict["Train"]["Optimizer"] - decay_interactions = {} - no_decay_interactions = {} - for name, param in model.son_equivalent_layers.named_parameters(): - if "weight" in name: - decay_interactions[name] = param - else: - no_decay_interactions[name] = param - - param_options = dict( - params=[ - { - "name": "embedding", - "params": model.embedding_layer.parameters(), - "weight_decay": 0.0, - }, - { - "name": "interactions_decay", - "params": list(decay_interactions.values()), - "weight_decay": opt_dict["weightDecay"], - }, - { - "name": "interactions_no_decay", - "params": list(no_decay_interactions.values()), - "weight_decay": 0.0, - }, - { - "name": "readouts", - "params": model.readout_layer.parameters(), - "weight_decay": 0.0, - }, - ], - lr=opt_dict["learningRate"], - amsgrad=opt_dict["amsGrad"], - ) - - if opt_dict['type'] == "Adam": - return torch.optim.Adam(**param_options) - elif opt_dict['type'] == "AdamW": - return torch.optim.AdamW(**param_options) - else: - raise Exception("Unsupported optimizer: {}!".format(opt_dict["type"])) - -def get_lr_scheduler(p_dict, optimizer): - class LrScheduler: - def __init__(self, p_dict, optimizer) -> None: - lr_dict = p_dict["Train"]["LrScheduler"] - self.mode = lr_dict['type'] - if lr_dict['type'] == "exponential": - self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, - gamma=lr_dict['gamma']) - elif lr_dict['type'] == "reduceOnPlateau": - self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, - factor=lr_dict['lrFactor'], - patience=lr_dict['patience']) - elif lr_dict['type'] == "constant": - self.lr_scheduler = None - else: - raise Exception("Unsupported LrScheduler: {}!".format(lr_dict['type'])) - - def step(self, metrics=None, epoch=None): - if self.mode == "exponential": - self.lr_scheduler.step(epoch=epoch) - elif self.mode == "reduceOnPlateau": - self.lr_scheduler.step(metrics=metrics, epoch=epoch) - - def state_dict(self): - if self.lr_scheduler: - return {key: value for key, value in self.lr_scheduler.__dict__.items() if key != 'optimizer'} - - def load_state_dict(self, state_dict): - if self.lr_scheduler: - self.lr_scheduler.__dict__.update(state_dict) - - def __repr__(self): - return self.lr_scheduler.__repr__() - - return LrScheduler(p_dict, optimizer) - - def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None, **kwargs): # Default values p_dict = { @@ -320,6 +165,11 @@ def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None, "trainBatch": 32, "testBatch": 32, "std": "force", + "mean": None, + "nNeighbor": None, + "elements": None, + "numWorkers": 0, + "pinMemory": False, }, "Model": { "mode": "normal", @@ -342,20 +192,23 @@ def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None, } }, "Train": { + "maxEpoch": 10000, + "maxStep": 1000000, + "learningRate": 0.001, "allowMissing": False, "targetProp": ["energy", "forces"], "weight": [0.1, 1.0], "forceScale": 0., - "logInterval": 100, - "saveInterval": 500, + "evalStepInterval": 50, + "evalEpochInterval": 1, + "logInterval": 50, "saveStart": 1000, "evalTest": True, - 
"maxGradNorm": None, + "gradClip": None, "Optimizer": { "type": "Adam", "amsGrad": True, "weightDecay": 0., - "learningRate": 0.001, }, "LrScheduler": { "type": "constant", @@ -373,65 +226,52 @@ def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None, setup_seed(p_dict["seed"]) log.info("Using seed {}".format(p_dict["seed"])) - - dataset = get_dataset(p_dict) - trainset, testset = split_dataset(dataset, p_dict) - train_loader = DataLoader(trainset, batch_size=p_dict["Data"]["trainBatch"], shuffle=True) - test_loader = DataLoader(testset, batch_size=p_dict["Data"]["testBatch"], shuffle=False) - if dataset is None: - dataset = trainset - try: - mean = dataset.per_energy_mean.detach().cpu().numpy() - except: - mean = 0. - if p_dict["Data"]["std"] == "force": - std = dataset.forces_std.detach().cpu().numpy() - elif p_dict["Data"]["std"] == "energy": - std = dataset.per_energy_std.detach().cpu().numpy() - else: - std = p_dict["Data"]["std"] - n_neighbor = dataset.n_neighbor_mean - elements = list(dataset.all_elements.detach().cpu().numpy()) - log.info(f"mean : {mean}") - log.info(f"std : {std}") - log.info(f"n_neighbor : {n_neighbor}") - log.info(f"all_elements : {elements}") - if load_model is not None: + log.info(f"Preparing data...") + dataset = LitAtomsDataset(p_dict) + dataset.setup() + mean, std, n_neighbor, elements = get_stats(p_dict["Data"], dataset) + + if load_model is not None and 'ckpt' not in load_model: + log.info(f"Load model from {load_model}") model = torch.load(load_model) else: model = get_model(p_dict, elements, mean, std, n_neighbor) model.register_buffer('all_elements', torch.tensor(elements, dtype=torch.long)) model.register_buffer('cutoff', torch.tensor(p_dict["cutoff"], dtype=torch.float64)) - optimizer = get_optimizer(p_dict, model) - lr_scheduler = get_lr_scheduler(p_dict, optimizer) + if load_model is not None and 'ckpt' in load_model: + lit_model = LitAtomicModule.load_from_checkpoint(load_model, model=model, p_dict=p_dict) + else: + lit_model = LitAtomicModule(model=model, p_dict=p_dict) + + logger = pl.loggers.TensorBoardLogger(save_dir=p_dict["outputDir"]) + callbacks = [ + SaveModelCheckpoint( + dirpath=p_dict["outputDir"], + filename='{epoch}-{step}-{val_loss:.4f}', + save_top_k=5, + monitor="val_loss" + ), + LearningRateMonitor() + ] + trainer = pl.Trainer( + logger=logger, + callbacks=callbacks, + default_root_dir='.', + max_epochs=p_dict["Train"]["maxEpoch"], + max_steps=p_dict["Train"]["maxStep"], + enable_progress_bar=False, + log_every_n_steps=p_dict["Train"]["logInterval"], + val_check_interval=p_dict["Train"]["evalStepInterval"], + check_val_every_n_epoch=p_dict["Train"]["evalEpochInterval"], + gradient_clip_val=p_dict["Train"]["gradClip"], + ) if load_checkpoint is not None: - state_dict = torch.load(load_checkpoint) - model.load_state_dict(state_dict["model"]) - optimizer.load_state_dict(state_dict["optimizer"]) - lr_scheduler.load_state_dict(state_dict["lr_scheduler"]) - - log.info(" Network Architecture ".center(100, "=")) - log.info(model) - log.info(f"Number of parameters: {sum([p.numel() for p in model.parameters()])}") - log.info(" Optimizer ".center(100, "=")) - log.info(optimizer) - # log.info(" LRScheduler ".center(100, "=")) - # log.info(lr_scheduler) - - ema_decay = p_dict["Train"]["emaDecay"] - if ema_decay > 0: - ema_avg = lambda averaged_para, para, n: ema_decay * averaged_para + (1 - ema_decay) * para - ema = AveragedModel(model=model, device=p_dict["device"], avg_fn=ema_avg, use_buffers=False) - # log.info(" 
ExponentialMovingAverage ".center(80, "=")) - # log.info(ema) + log.info(f"Load checkpoints from {load_checkpoint}") + trainer.fit(lit_model, datamodule=dataset, ckpt_path=load_checkpoint) else: - ema = None - log.info("=" * 100) - loss_calculator = get_loss_calculator(p_dict) - train(model, loss_calculator, optimizer, lr_scheduler, ema, train_loader, test_loader, p_dict) - + trainer.fit(lit_model, datamodule=dataset) if __name__ == "__main__": main() diff --git a/tensornet/layer/embedding.py b/tensornet/layer/embedding.py index 84dd1b1fceb28cb2af00088880127bbca3366edd..b6e4c88bdf8403a35218cc7399d7c5e51ce1e560 100644 --- a/tensornet/layer/embedding.py +++ b/tensornet/layer/embedding.py @@ -104,7 +104,8 @@ class BehlerG1(EmbeddingLayer): batch_data : Dict[str, torch.Tensor], ) -> torch.Tensor: n_atoms = batch_data['atomic_number'].shape[0] - idx_i, idx_j = batch_data['edge_index'] + idx_i = batch_data['idx_i'] + idx_j = batch_data['idx_j'] _, dij, _ = find_distances(batch_data) zij = batch_data['atomic_number'][idx_j].unsqueeze(-1) # [n_edge, 1] dij = dij.unsqueeze(-1) diff --git a/tensornet/layer/equivalent.py b/tensornet/layer/equivalent.py index f75aaee3ce7eeb8ae02403fd3cd8c43870656f9a..39b32ba47a9d9680acc8ac4a2b48e8b5e0f2fdb0 100644 --- a/tensornet/layer/equivalent.py +++ b/tensornet/layer/equivalent.py @@ -9,7 +9,7 @@ from torch import nn from typing import Dict, Callable, Union from .base import RadialLayer, CutoffLayer from .activate import TensorActivateDict -from ..utils import find_distances, expand_to, multi_outer_product, _scatter_add, find_moment +from ..utils import find_distances, find_moment, _scatter_add, _aggregate # input_tensors be like: @@ -53,10 +53,10 @@ class TensorLinear(nn.Module): class SimpleTensorAggregateLayer(nn.Module): - """ + """ In this type of layer, the rbf mixing only different if r_way different """ - def __init__(self, + def __init__(self, radial_fn : RadialLayer, n_channel : int, max_in_way : int=2, @@ -85,36 +85,23 @@ class SimpleTensorAggregateLayer(nn.Module): ) -> Dict[int, torch.Tensor]: # These 3 rows are required by torch script output_tensors = torch.jit.annotate(Dict[int, torch.Tensor], {}) - idx_i = batch_data['edge_index'][0] - idx_j = batch_data['edge_index'][1] + idx_i = batch_data['idx_i'] + idx_j = batch_data['idx_j'] n_atoms = batch_data['atomic_number'].shape[0] _, dij, uij = find_distances(batch_data) rbf_ij = self.radial_fn(dij) # [n_edge, n_rbf] - for r_way in range(self.max_r_way + 1): - fn = self.rbf_mixing_dict[str(r_way)](rbf_ij) # [n_edge, n_channel] + for r_way, rbf_mixing in self.rbf_mixing_dict.items(): + r_way = int(r_way) + fn = rbf_mixing(rbf_ij) # [n_edge, n_channel] # TODO: WHY!!!!!!!!!! CAO! # fn = fn * input_tensor_dict[0] moment_tensor = find_moment(batch_data, r_way) # [n_edge, n_dim, ...] - filter_tensor_ = moment_tensor.unsqueeze(1) * expand_to(fn, n_dim=r_way + 2) # [n_edge, n_channel, n_dim, n_dim, ...] for in_way, out_way in self.inout_combinations[r_way]: input_tensor = input_tensors[in_way][idx_j] # [n_edge, n_channel, n_dim, n_dim, ...] 
- coupling_way = (in_way + r_way - out_way) // 2 - # method 1 - n_way = in_way + r_way - coupling_way + 2 - input_tensor = expand_to(input_tensor, n_way, dim=-1) - filter_tensor = expand_to(filter_tensor_, n_way, dim=2) - output_tensor = input_tensor * filter_tensor - # input_tensor: [n_edge, n_channel, n_dim, n_dim, ..., 1] - # filter_tensor: [n_edge, n_channel, 1, 1, ..., n_dim] - # with (in_way + r_way - coupling_way) dim after n_channel - # We should sum up (coupling_way) n_dim - if coupling_way > 0: - sum_axis = [i for i in range(in_way - coupling_way + 2, in_way + 2)] - output_tensor = torch.sum(output_tensor, dim=sum_axis) + output_tensor = _aggregate(moment_tensor, fn, input_tensor, in_way, r_way, out_way) output_tensor = _scatter_add(output_tensor, idx_i, dim_size=n_atoms) / self.norm_factor - # output_tensor = segment_coo(output_tensor, idx_i, dim_size=batch_data.num_nodes, reduce="sum") if out_way not in output_tensors: output_tensors[out_way] = output_tensor @@ -124,7 +111,7 @@ class SimpleTensorAggregateLayer(nn.Module): class TensorAggregateLayer(nn.Module): - def __init__(self, + def __init__(self, radial_fn : RadialLayer, n_channel : int, max_in_way : int=2, @@ -134,7 +121,7 @@ class TensorAggregateLayer(nn.Module): ) -> None: super().__init__() # get all possible "i, r, o" combinations - self.all_combinations = [] + self.all_combinations = {} self.rbf_mixing_dict = nn.ModuleDict() for in_way in range(max_in_way + 1): for r_way in range(max_r_way + 1): @@ -142,7 +129,7 @@ class TensorAggregateLayer(nn.Module): out_way = in_way + r_way - 2 * z_way if out_way <= max_out_way: comb = (in_way, r_way, out_way) - self.all_combinations.append(comb) + self.all_combinations[str(comb)] = comb self.rbf_mixing_dict[str(comb)] = nn.Linear(radial_fn.n_features, n_channel, bias=False) self.radial_fn = radial_fn @@ -154,36 +141,22 @@ class TensorAggregateLayer(nn.Module): ) -> Dict[int, torch.Tensor]: # These 3 rows are required by torch script output_tensors = torch.jit.annotate(Dict[int, torch.Tensor], {}) - idx_i = batch_data['edge_index'][0] - idx_j = batch_data['edge_index'][1] + idx_i = batch_data['idx_i'] + idx_j = batch_data['idx_j'] n_atoms = batch_data['atomic_number'].shape[0] _, dij, uij = find_distances(batch_data) rbf_ij = self.radial_fn(dij) # [n_edge, n_rbf] - for in_way, r_way, out_way in self.all_combinations: - fn = self.rbf_mixing_dict[str((in_way, r_way, out_way))](rbf_ij) # [n_edge, n_channel] + for comb, rbf_mixing in self.rbf_mixing_dict.items(): + in_way, r_way, out_way = self.all_combinations[comb] + fn = rbf_mixing(rbf_ij) # [n_edge, n_channel] # TODO: WHY!!!!!!!!!! CAO! # fn = fn * input_tensor_dict[0] moment_tensor = find_moment(batch_data, r_way) # [n_edge, n_dim, ...] - filter_tensor = moment_tensor.unsqueeze(1) * expand_to(fn, n_dim=r_way + 2) # [n_edge, n_channel, n_dim, n_dim, ...] input_tensor = input_tensors[in_way][idx_j] # [n_edge, n_channel, n_dim, n_dim, ...] 
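Both aggregate layers now delegate the moment/feature contraction to a shared `_aggregate` helper imported from `tensornet.utils`. A minimal sketch of what that helper presumably does, reconstructed from the inline code removed above (the actual implementation may differ):

```python
import torch
from tensornet.utils import expand_to  # pads size-1 axes; see the utils.py hunk below

def _aggregate_sketch(moment_tensor: torch.Tensor,  # [n_edge, n_dim, ...] (r_way spatial axes)
                      fn: torch.Tensor,             # [n_edge, n_channel] radial mixing coefficients
                      input_tensor: torch.Tensor,   # [n_edge, n_channel, ...] (in_way spatial axes)
                      in_way: int, r_way: int, out_way: int) -> torch.Tensor:
    # filter = outer product of the r_way-th moment with the radial coefficients
    filter_tensor = moment_tensor.unsqueeze(1) * expand_to(fn, n_dim=r_way + 2)
    # number of spatial axes contracted away between input and filter
    coupling_way = (in_way + r_way - out_way) // 2
    n_way = in_way + r_way - coupling_way + 2
    input_tensor = expand_to(input_tensor, n_way, dim=-1)   # [n_edge, n_channel, in_way dims, 1, ...]
    filter_tensor = expand_to(filter_tensor, n_way, dim=2)  # [n_edge, n_channel, 1, ..., r_way dims]
    output_tensor = input_tensor * filter_tensor
    if coupling_way > 0:
        # sum over the coupling_way axes shared by input and filter
        sum_axis = [i for i in range(in_way - coupling_way + 2, in_way + 2)]
        output_tensor = torch.sum(output_tensor, dim=sum_axis)
    return output_tensor                                    # [n_edge, n_channel, out_way spatial axes]
```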
- coupling_way = (in_way + r_way - out_way) // 2 - # method 1 - n_way = in_way + r_way - coupling_way + 2 - input_tensor = expand_to(input_tensor, n_way, dim=-1) - filter_tensor = expand_to(filter_tensor, n_way, dim=2) - output_tensor = input_tensor * filter_tensor - # input_tensor: [n_edge, n_channel, n_dim, n_dim, ..., 1] - # filter_tensor: [n_edge, n_channel, 1, 1, ..., n_dim] - # with (in_way + r_way - coupling_way) dim after n_channel - # We should sum up (coupling_way) n_dim - if coupling_way > 0: - sum_axis = [i for i in range(in_way - coupling_way + 2, in_way + 2)] - output_tensor = torch.sum(output_tensor, dim=sum_axis) + output_tensor = _aggregate(moment_tensor, fn, input_tensor, in_way, r_way, out_way) output_tensor = _scatter_add(output_tensor, idx_i, dim_size=n_atoms) / self.norm_factor - # output_tensor = segment_coo(output_tensor, idx_i, dim_size=batch_data.num_nodes, reduce="sum") - if out_way not in output_tensors: output_tensors[out_way] = output_tensor else: @@ -193,7 +166,7 @@ class TensorAggregateLayer(nn.Module): class SelfInteractionLayer(nn.Module): - def __init__(self, + def __init__(self, input_dim : int, max_in_way : int, output_dim : int=10, @@ -219,7 +192,7 @@ class SelfInteractionLayer(nn.Module): # TODO: cat different way together and use Linear layer to got factor of every channel class NonLinearLayer(nn.Module): - def __init__(self, + def __init__(self, max_in_way : int, input_dim : int, activate_fn : str='jilu', @@ -254,20 +227,20 @@ class SOnEquivalentLayer(nn.Module): self.tensor_aggregate = TensorAggregateLayer(radial_fn=radial_fn, n_channel=input_dim, max_in_way=max_in_way, - max_out_way=max_out_way, + max_out_way=max_out_way, max_r_way=max_r_way, norm_factor=norm_factor,) elif mode == 'simple': self.tensor_aggregate = SimpleTensorAggregateLayer(radial_fn=radial_fn, n_channel=input_dim, max_in_way=max_in_way, - max_out_way=max_out_way, + max_out_way=max_out_way, max_r_way=max_r_way, norm_factor=norm_factor,) # input for SelfInteractionLayer and NonLinearLayer is the output of TensorAggregateLayer # so the max_in_way should equal to max_out_way of TensorAggregateLayer - self.self_interact = SelfInteractionLayer(input_dim=input_dim, - max_in_way=max_out_way, + self.self_interact = SelfInteractionLayer(input_dim=input_dim, + max_in_way=max_out_way, output_dim=output_dim) self.non_linear = NonLinearLayer(activate_fn=activate_fn, max_in_way=max_out_way, @@ -285,7 +258,7 @@ class SOnEquivalentLayer(nn.Module): input_tensors : Dict[int, torch.Tensor], batch_data : Dict[str, torch.Tensor], ) -> Dict[int, torch.Tensor]: - output_tensors = self.tensor_aggregate(input_tensors=input_tensors, + output_tensors = self.tensor_aggregate(input_tensors=input_tensors, batch_data=batch_data) # resnet for r_way in input_tensors.keys(): diff --git a/tensornet/layer/radial.py b/tensornet/layer/radial.py index b63c6281e986d208aa6d38cc6ba6b8c273b252b3..327d9fe34cd3d89b68bb6713007410a296a85e11 100644 --- a/tensornet/layer/radial.py +++ b/tensornet/layer/radial.py @@ -12,7 +12,7 @@ __all__ = ["ChebyshevPoly", class ChebyshevPoly(RadialLayer): - def __init__(self, + def __init__(self, r_max : float, r_min : float=0.5, n_max : int=8, diff --git a/tensornet/layer/readout.py b/tensornet/layer/readout.py index 6c48a0487cb8219d8fe0aa5d4cadf47487453160..c40af4b39cc604371df16badeb5519398c020a21 100644 --- a/tensornet/layer/readout.py +++ b/tensornet/layer/readout.py @@ -10,7 +10,7 @@ __all__ = ["ReadoutLayer"] class ReadoutLayer(nn.Module): - def __init__(self, + def __init__(self, n_dim 
: int, target_way : Dict[str, int]={"site_energy": 0}, activate_fn : str="jilu", @@ -26,7 +26,7 @@ class ReadoutLayer(nn.Module): for prop, way in target_way.items() }) - def forward(self, + def forward(self, input_tensors : Dict[int, torch.Tensor], ) -> Dict[str, torch.Tensor]: output_tensors = torch.jit.annotate(Dict[str, torch.Tensor], {}) diff --git a/tensornet/loss.py b/tensornet/loss.py index 0e6045f9bb948c422cedb282f6ca12a690ec2138..3fef1f103df483ec4cdb7ec4e005c11ef437dc17 100644 --- a/tensornet/loss.py +++ b/tensornet/loss.py @@ -7,7 +7,7 @@ from .utils import expand_to class Loss: atom_prop = ["forces"] - structure_prop = ["energy", "virial", "dipole"] + structure_prop = ["energy", "virial", "dipole", "polarizability"] def __init__(self, weight : Dict[str, float]={"energy": 1.0, "forces": 1.0}, @@ -16,7 +16,7 @@ class Loss: self.weight = weight self.loss_fn = loss_fn - def get_loss(self, + def get_loss(self, batch_data : Dict[str, torch.Tensor], verbose : bool=False): loss = {} @@ -42,12 +42,12 @@ class Loss: prop : str, ) -> torch.Tensor: n_atoms = expand_to(batch_data['n_atoms'], len(batch_data[f'{prop}_p'].shape)) - return self.loss_fn(batch_data[f'{prop}_p'] / n_atoms, + return self.loss_fn(batch_data[f'{prop}_p'] / n_atoms, batch_data[f'{prop}_t'] / n_atoms) class ForceScaledLoss(Loss): - + def __init__(self, weight : Dict[str, float]={"energy": 1.0, "forces": 1.0}, loss_fn : Callable=F.mse_loss, @@ -72,22 +72,27 @@ class MissingValueLoss(Loss): batch_data : Dict[str, torch.Tensor], prop : str, ) -> torch.Tensor: - idx = batch_data[f'has_{prop}'][batch_data['batch']] - if not torch.any(idx): - return torch.tensor(0.) - if torch.all(idx): - return super().atom_prop_loss(batch_data, prop) - return self.loss_fn(batch_data[f'{prop}_p'][idx], batch_data[f'{prop}_t'][idx]) + # idx = batch_data[f'{prop}_weight'][batch_data['batch']] + # if not torch.any(idx): + # return torch.tensor(0.) + # if torch.all(idx): + # return super().atom_prop_loss(batch_data, prop) + # return self.loss_fn(batch_data[f'{prop}_p'][idx], batch_data[f'{prop}_t'][idx]) + return self.loss_fn(batch_data[f'{prop}_p'] * batch_data[f'{prop}_weight'], + batch_data[f'{prop}_t'] * batch_data[f'{prop}_weight']) def structure_prop_loss(self, batch_data : Dict[str, torch.Tensor], prop : str, ) -> torch.Tensor: - idx = batch_data[f'has_{prop}'] - if not torch.any(idx): - return torch.tensor(0.) - if torch.all(idx): - return super().structure_prop_loss(batch_data, prop) - n_atoms = expand_to(batch_data['n_atoms'][idx], len(batch_data[f'{prop}_p'].shape)) - return self.loss_fn(batch_data[f'{prop}_p'][idx] / n_atoms, - batch_data[f'{prop}_t'][idx] / n_atoms) + # idx = batch_data[f'{prop}_weight'] + # if not torch.any(idx): + # return torch.tensor(0.) 
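The `has_<prop>` flags are replaced by per-property `<prop>_weight` tensors built in `atoms_to_data` (ones where a label exists, zeros where it is missing), so `MissingValueLoss` masks the MSE instead of indexing. A small illustrative sketch of the masked atom-level loss under that convention, with hypothetical toy values:

```python
import torch
import torch.nn.functional as F

# toy batch of two structures (3 + 2 atoms); only the first one has reference forces
forces_p = torch.randn(5, 3)                 # predictions
forces_t = torch.zeros(5, 3)                 # zero-padded targets, as produced by atoms_to_data
forces_t[:3] = torch.randn(3, 3)
forces_weight = torch.zeros(5, 3)
forces_weight[:3] = 1.0                      # label present only for the first structure

# masked MSE as in MissingValueLoss.atom_prop_loss: missing entries contribute 0 - 0
loss = F.mse_loss(forces_p * forces_weight, forces_t * forces_weight)
```

Note that with the default `mean` reduction the zeroed entries still count in the denominator, so the scaling differs slightly from the commented-out indexing approach.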
+ # if torch.all(idx): + # return super().structure_prop_loss(batch_data, prop) + # n_atoms = expand_to(batch_data['n_atoms'], len(batch_data[f'{prop}_p'].shape)) + # return self.loss_fn(batch_data[f'{prop}_p'][idx] / n_atoms, + # batch_data[f'{prop}_t'][idx] / n_atoms) + weight = batch_data[f'{prop}_weight'] / expand_to(batch_data['n_atoms'], len(batch_data[f'{prop}_p'].shape)) + return self.loss_fn(batch_data[f'{prop}_p'] * weight, + batch_data[f'{prop}_t'] * weight) diff --git a/tensornet/model/__init__.py b/tensornet/model/__init__.py index 88b1d019cb1b3a729266f2f0e492ddbcb5b30f46..f9bd3344187480359acf64d3a1b3106736713d29 100644 --- a/tensornet/model/__init__.py +++ b/tensornet/model/__init__.py @@ -1,2 +1,3 @@ from .ani import ANI -from .miao import MiaoNet \ No newline at end of file +from .miao import MiaoNet +from .model_interface import LitAtomicModule \ No newline at end of file diff --git a/tensornet/model/base.py b/tensornet/model/base.py index 8ac5defe81ec9c96acdb7816af710d584a962602..c477d4ae0085a632f2e18e380ffb5c06052dc832 100644 --- a/tensornet/model/base.py +++ b/tensornet/model/base.py @@ -6,7 +6,7 @@ from tensornet.utils import _scatter_add, find_distances, add_scaling class AtomicModule(nn.Module): - def __init__(self, + def __init__(self, mean : float=0., std : float=1., ) -> None: @@ -14,7 +14,7 @@ class AtomicModule(nn.Module): self.register_buffer("mean", torch.tensor(mean).float()) self.register_buffer("std", torch.tensor(std).float()) - def forward(self, + def forward(self, batch_data : Dict[str, torch.Tensor], properties : Optional[List[str]]=None, create_graph : bool=True, @@ -42,6 +42,13 @@ class AtomicModule(nn.Module): ####################################### if 'dipole' in output_tensors: batch_data['dipole_p'] = _scatter_add(output_tensors['dipole'], batch_data['batch']) + if 'polar_diagonal' in output_tensors: + polar_diagonal = _scatter_add(output_tensors['polar_diagonal'], batch_data['batch']) + polar_off_diagonal = _scatter_add(output_tensors['polar_off_diagonal'], batch_data['batch']) + # polar_off_diagonal = polar_off_diagonal + polar_off_diagonal.transpose(1, 2) + shape = polar_off_diagonal.shape + batch_data['polarizability_p'] = polar_diagonal.unsqueeze(-1).unsqueeze(-1) * torch.eye(shape[1], device=polar_diagonal.device).expand(*shape) + polar_off_diagonal + if 'direct_forces' in output_tensors: batch_data['forces_p'] = output_tensors['direct_forces'] * self.std if ('site_energy' in properties) or ('energies' in properties): diff --git a/tensornet/model/miao.py b/tensornet/model/miao.py index 4ffb43561dbd681a0f2633811e91fed601cac007..86f9c82c1f8665766c5bc73cd157255fe20012b7 100644 --- a/tensornet/model/miao.py +++ b/tensornet/model/miao.py @@ -26,6 +26,7 @@ class MiaoNet(AtomicModule): mode : str='normal', ): super().__init__(mean=mean, std=std) + self.register_buffer("norm_factor", torch.tensor(norm_factor).float()) self.embedding_layer = embedding_layer max_r_way = expand_para(max_r_way, n_layers) max_out_way = expand_para(max_out_way, n_layers) diff --git a/tensornet/model/model_interface.py b/tensornet/model/model_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..985b5744ac2a42c20442536da882489ec87d4343 --- /dev/null +++ b/tensornet/model/model_interface.py @@ -0,0 +1,171 @@ +from typing import Optional, Dict, List, Type, Any +from .base import AtomicModule +from ..loss import Loss, MissingValueLoss, ForceScaledLoss +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + + +class 
LitAtomicModule(pl.LightningModule): + + def __init__(self, + model: AtomicModule, + p_dict: Dict, + ): + super().__init__() + self.p_dict = p_dict + self.model = model + self.loss_calculator = self.get_loss_calculator() + + grad_prop = set(['forces', 'virial', 'stress']) + self.required_derivatives = len(grad_prop.intersection(self.p_dict["Train"]['targetProp'])) > 0 + # self.save_hyperparameters(ignore=['model']) + + def forward(self, + batch_data : Dict[str, torch.Tensor], + properties : Optional[List[str]]=None, + create_graph : bool=True, + ) -> Dict[str, torch.Tensor]: + results = self.model(batch_data, properties, create_graph) + return results + + def get_loss_calculator(self): + train_dict = self.p_dict['Train'] + target = train_dict['targetProp'] + weight = train_dict['weight'] + weights = {p: w for p, w in zip(target, weight)} + if "direct_forces" in weights: + weights["forces"] = weights.pop("direct_forces") # direct forces use the same key of forces + if train_dict['allowMissing']: + # TODO: rewrite loss function + if train_dict['forceScale'] > 0: + raise Exception("Now forceScale not support allowMissing!") + return MissingValueLoss(weights, loss_fn=F.mse_loss) + else: + if train_dict['forceScale'] > 0: + return ForceScaledLoss(weights, loss_fn=F.mse_loss, scaled=train_dict['forceScale']) + else: + return Loss(weights, loss_fn=F.mse_loss) + + def training_step(self, batch, batch_idx): + self.model(batch, self.p_dict["Train"]['targetProp']) + loss, loss_dict = self.loss_calculator.get_loss(batch, verbose=True) + self.log("train_loss", loss) + for prop in loss_dict: + self.log(f'train_{prop}', loss_dict[prop]) + return loss + + def validation_step(self, batch, batch_idx): + torch.set_grad_enabled(self.required_derivatives) + self.model(batch, self.p_dict["Train"]['targetProp'], create_graph=False) + loss, loss_dict = self.loss_calculator.get_loss(batch, verbose=True) + self.log("val_loss", loss, batch_size=batch['n_atoms'].shape[0]) + for prop in loss_dict: + self.log(f'val_{prop}', loss_dict[prop], batch_size=batch['n_atoms'].shape[0]) + + def test_step(self, batch, batch_idx): + torch.set_grad_enabled(self.required_derivatives) + self.model(batch, self.p_dict["Train"]['targetProp'], create_graph=False) + loss, loss_dict = self.loss_calculator.get_loss(batch, verbose=True) + loss_dict['test_loss'] = loss + self.log_dict(loss_dict) + return loss_dict + + def get_optimizer(self): + opt_dict = self.p_dict["Train"]["Optimizer"] + decay_interactions = {} + no_decay_interactions = {} + for name, param in self.model.son_equivalent_layers.named_parameters(): + if "weight" in name: + decay_interactions[name] = param + else: + no_decay_interactions[name] = param + + param_options = dict( + params=[ + { + "name": "embedding", + "params": self.model.embedding_layer.parameters(), + "weight_decay": 0.0, + }, + { + "name": "interactions_decay", + "params": list(decay_interactions.values()), + "weight_decay": opt_dict["weightDecay"], + }, + { + "name": "interactions_no_decay", + "params": list(no_decay_interactions.values()), + "weight_decay": 0.0, + }, + { + "name": "readouts", + "params": self.model.readout_layer.parameters(), + "weight_decay": 0.0, + }, + ], + lr=self.p_dict["Train"]["learningRate"], + amsgrad=opt_dict["amsGrad"], + ) + + if opt_dict['type'] == "Adam": + return torch.optim.Adam(**param_options) + elif opt_dict['type'] == "AdamW": + return torch.optim.AdamW(**param_options) + else: + raise Exception("Unsupported optimizer: {}!".format(opt_dict["type"])) + + def 
get_lr_scheduler(self, optimizer): + lr_dict = self.p_dict["Train"]["LrScheduler"] + if lr_dict['type'] == "exponential": + return torch.optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, + gamma=lr_dict['gamma'] + ) + elif lr_dict['type'] == "reduceOnPlateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer=optimizer, + factor=lr_dict['lrFactor'], + patience=lr_dict['patience'] + ) + if self.p_dict["Train"]["evalEpochInterval"] == 1: + lr_scheduler_config = { + "scheduler": scheduler, + "interval": "step", + "monitor": "val_loss", + "frequency": self.p_dict["Train"]["evalStepInterval"], + } + else: + lr_scheduler_config = { + "scheduler": scheduler, + "interval": "epoch", + "monitor": "val_loss", + "frequency": self.p_dict["Train"]["evalEpochInterval"], + } + return lr_scheduler_config + elif lr_dict['type'] == "constant": + return None + else: + raise Exception("Unsupported LrScheduler: {}!".format(lr_dict['type'])) + + def configure_optimizers(self): + optimizer = self.get_optimizer() + scheduler = self.get_lr_scheduler(optimizer) + if scheduler: + return [optimizer], [scheduler] + else: + return [optimizer] + + # Learning rate warm-up + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): + # update params + optimizer.step(closure=optimizer_closure) + + # manually warm up lr without a scheduler + if self.trainer.global_step < self.p_dict["Train"]["warmupSteps"]: + lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.p_dict["Train"]["warmupSteps"]) + for pg in optimizer.param_groups: + pg["lr"] = lr_scale * self.p_dict["Train"]["learningRate"] + + # def lr_scheduler_step(self, scheduler, metric): + # scheduler.step(epoch=self.current_epoch) # timm's scheduler need the epoch value diff --git a/tensornet/utils.py b/tensornet/utils.py index 377d2a20669637e9a28d3d9d1c1f40be72759e21..15d57cfa7cc577008d97ce4d2715d8c8ff3ba95c 100644 --- a/tensornet/utils.py +++ b/tensornet/utils.py @@ -13,14 +13,14 @@ def setup_seed(seed): torch.backends.cudnn.deterministic = True -def expand_to(t : torch.Tensor, - n_dim : int, +def expand_to(t : torch.Tensor, + n_dim : int, dim : int=-1) -> torch.Tensor: """Expand dimension of the input tensor t at location 'dim' until the total dimention arrive 'n_dim' Args: t (torch.Tensor): Tensor to expand - n_dim (int): target dimension + n_dim (int): target dimension dim (int, optional): location to insert axis. Defaults to -1. 
Returns: @@ -31,7 +31,7 @@ def expand_to(t : torch.Tensor, return t -def multi_outer_product(v: torch.Tensor, +def multi_outer_product(v: torch.Tensor, n: int) -> torch.Tensor: """Calculate 'n' times outer product of vector 'v' @@ -51,7 +51,7 @@ def multi_outer_product(v: torch.Tensor, def add_scaling(batch_data : Dict[str, torch.Tensor],) -> Dict[str, torch.Tensor]: if 'has_add_scaling' not in batch_data: idx_m = batch_data['batch'] - idx_i = batch_data['edge_index'][0] + idx_i = batch_data['idx_i'] batch_data['coordinate'] = torch.matmul(batch_data['coordinate'][:, None, :], batch_data['scaling'][idx_m]).squeeze(1) if 'offset' in batch_data: @@ -63,11 +63,11 @@ def add_scaling(batch_data : Dict[str, torch.Tensor],) -> Dict[str, torch.Tenso def find_distances(batch_data : Dict[str, torch.Tensor],) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if 'rij' not in batch_data: - idx_i = batch_data["edge_index"][0] + idx_i = batch_data["idx_i"] if 'ghost_neigh' in batch_data: # neighbor for lammps calculation idx_j = batch_data["ghost_neigh"] else: - idx_j = batch_data["edge_index"][1] + idx_j = batch_data["idx_j"] if 'offset' in batch_data: batch_data['rij'] = batch_data['coordinate'][idx_j] + batch_data['offset'] - batch_data['coordinate'][idx_i] else: @@ -170,3 +170,27 @@ def progress_bar(i: int, n: int, interval: int=100): if i % interval == 0: ii = int(i / n * 100) print(f"\r{ii}%[{'*' * ii}{'-' * (100 - ii)}]", end=' ', file=sys.stderr) + + +@torch.jit.script +def _aggregate(moment_tensor: torch.Tensor, + fn : torch.Tensor, + input_tensor:torch.Tensor, + in_way : int, + r_way : int, + out_way: int + ) -> torch.Tensor: + filter_tensor = moment_tensor.unsqueeze(1) * expand_to(fn, n_dim=r_way + 2) # [n_edge, n_channel, n_dim, n_dim, ...] + coupling_way = (in_way + r_way - out_way) // 2 + n_way = in_way + r_way - coupling_way + 2 + input_tensor = expand_to(input_tensor, n_way, dim=-1) + filter_tensor = expand_to(filter_tensor, n_way, dim=2) + output_tensor = input_tensor * filter_tensor + # input_tensor: [n_edge, n_channel, n_dim, n_dim, ..., 1] + # filter_tensor: [n_edge, n_channel, 1, 1, ..., n_dim] + # with (in_way + r_way - coupling_way) dim after n_channel + # We should sum up (coupling_way) n_dim + if coupling_way > 0: + sum_axis = [i for i in range(in_way - coupling_way + 2, in_way + 2)] + output_tensor = torch.sum(output_tensor, dim=sum_axis) + return output_tensor diff --git a/tools/convert.py b/tools/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..197e5ce69645fcbcac9837062db25ae67808a8c6 --- /dev/null +++ b/tools/convert.py @@ -0,0 +1,49 @@ +import torch +from train import get_model +import yaml +import sys + + +def convert(ckptfile, device='cpu'): + ckpt = torch.load(ckptfile, map_location=device) + state_dict = ckpt['state_dict'] + elements = state_dict['model.all_elements'].detach().cpu().numpy() + mean = state_dict['model.mean'].detach().cpu().numpy() + std = state_dict['model.std'].detach().cpu().numpy() + n_neighbor = state_dict['model.norm_factor'].detach().cpu().numpy() + cutoff = state_dict['model.cutoff'].detach().cpu().numpy() + + p_dict = { + "cutoff": cutoff, + "Model": { + "mode": "normal", + "activateFn": "silu", + "nEmbedding": 64, + "nLayer": 5, + "maxRWay": 2, + "maxOutWay": 2, + "nHidden": 64, + "targetWay": {0 : 'site_energy'}, + "CutoffLayer": { + "type": "poly", + "p": 5, + }, + "RadialLayer": { + "type": "besselMLP", + "nBasis": 8, + "nHidden": [64, 64, 64], + "activateFn": "silu", + } + }, + } + + with open('input.yaml') as 
f:
+        p_dict['Model'].update(yaml.load(f, Loader=yaml.FullLoader))
+    model = get_model(p_dict, elements=elements, mean=mean, std=std, n_neighbor=n_neighbor)
+    model.load_state_dict(state_dict)
+    torch.save(model, 'model.pt')
+
+
+if __name__ == "__main__":
+    ckptfile = sys.argv[1]
+    convert(ckptfile)
diff --git a/tools/eval.py b/tools/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e81cbafe37eee543a9f8fb54de613e7796fedbcb
--- /dev/null
+++ b/tools/eval.py
@@ -0,0 +1,49 @@
+from ase.io import read
+import numpy as np
+import torch
+from torch_geometric.loader import DataLoader
+from tensornet.data import ASEData, ASEDBData, atoms_collate_fn
+
+
+def eval(model, data_loader, properties, device='cpu'):
+    output = {prop: [] for prop in properties}
+    target = {prop: [] for prop in properties}
+    n_atoms = []
+    for batch_data in data_loader:
+        for key in batch_data:
+            batch_data[key] = batch_data[key].to(device)
+        model(batch_data, properties, create_graph=False)
+        n_atoms.extend(batch_data['n_atoms'].detach().cpu().numpy())
+        for prop in properties:
+            output[prop].extend(batch_data[f'{prop}_p'].detach().cpu().numpy())
+            if f'{prop}_t' in batch_data:
+                target[prop].extend(batch_data[f'{prop}_t'].detach().cpu().numpy())
+    for prop in properties:
+        np.save(f'output_{prop}.npy', np.array(output[prop]))
+        np.save(f'target_{prop}.npy', np.array(target[prop]))
+    np.save('n_atoms.npy', np.array(n_atoms))
+    return None
+
+
+def main(*args, cutoff=None, modelfile='model.pt', device='cpu', datafile='data.traj', format=None,
+         properties=["energy", "forces"], batchsize=32, num_workers=4, pin_memory=True, **kwargs):
+    model = torch.load(modelfile, map_location=device)
+    cutoff = float(model.cutoff.detach().cpu().numpy())
+    if '.db' in datafile:
+        dataset = ASEDBData(datapath=datafile,
+                            properties=properties,
+                            cutoff=cutoff)
+    else:
+        frames = read(datafile, index=':', format=format)
+        dataset = ASEData(frames=frames, cutoff=cutoff, properties=properties)
+    data_loader = DataLoader(dataset,
+                             batch_size=batchsize,
+                             shuffle=False,
+                             collate_fn=atoms_collate_fn,
+                             num_workers=num_workers,
+                             pin_memory=pin_memory)
+    eval(model, data_loader, properties, device)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/train.py b/tools/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8543942d9c1f0f29dbf24b4f26afc9c7f5b42bbf
--- /dev/null
+++ b/tools/train.py
@@ -0,0 +1,271 @@
+import logging, time, yaml, os
+import numpy as np
+import torch
+from torch import nn
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint
+import torch.nn.functional as F
+from torch.optim.swa_utils import AveragedModel
+from ase.data import atomic_numbers
+from tensornet.utils import setup_seed
+from tensornet.model import MiaoNet, LitAtomicModule
+from tensornet.layer.cutoff import *
+from tensornet.layer.embedding import AtomicEmbedding
+from tensornet.layer.radial import *
+from tensornet.data import LitAtomsDataset
+
+from tensornet.logger import set_logger
+set_logger(log_path='log.txt', level='DEBUG')
+log = logging.getLogger(__name__)
+
+
+class LogAllLoss(pl.Callback):
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        if trainer.global_rank == 0:
+            epoch = trainer.current_epoch
+
+            if epoch == 0:
+                content = f"{'epoch':^10}|{'lr':^10}|{'total':^21}"
+                for prop in pl_module.p_dict["Train"]['targetProp']:
+                    content += f"|{prop:^21}"
+                log.info(content)
+
+            if epoch % pl_module.p_dict["Train"]["evalInterval"] == 0:
+                lr = 
trainer.optimizers[0].param_groups[0]["lr"] + loss_metrics = trainer.callback_metrics + train_loss = loss_metrics['train_loss'] + val_loss = loss_metrics['val_loss'] + content = f"{epoch:^10}|{lr:^10.2e}|{train_loss:^10.4f}/{val_loss:^10.4f}" + for prop in pl_module.p_dict["Train"]['targetProp']: + prop = "forces" if prop == "direct_forces" else prop + content += f"|{loss_metrics[f'train_{prop}']:^10.4f}/{loss_metrics[f'val_{prop}']:^10.4f}" + log.info(content) + +def update_dict(d1, d2): + for key in d2: + if key in d1 and isinstance(d1[key], dict): + update_dict(d1[key], d2[key]) + else: + d1[key] = d2[key] + return d1 + + +def get_cutoff(p_dict): + cutoff = p_dict['cutoff'] + cut_dict = p_dict['Model']['CutoffLayer'] + if cut_dict['type'] == "cos": + return CosineCutoff(cutoff=cutoff) + elif cut_dict['type'] == "cos2": + return SmoothCosineCutoff(cutoff=cutoff, cutoff_smooth=cut_dict['smoothCutoff']) + elif cut_dict['type'] == "poly": + return PolynomialCutoff(cutoff=cutoff, p=cut_dict['p']) + else: + raise Exception("Unsupported cutoff type: {}, please choose from cos, cos2, and poly!".format(cut_dict['type'])) + + +def get_radial(p_dict, cutoff_fn): + cutoff = p_dict['cutoff'] + radial_dict = p_dict['Model']['RadialLayer'] + if "bessel" in radial_dict['type']: + radial_fn = BesselPoly(r_max=cutoff, n_max=radial_dict['nBasis'], cutoff_fn=cutoff_fn) + elif "chebyshev" in radial_dict['type']: + if "minDist" in radial_dict: + r_min = radial_dict['minDist'] + else: + r_min = 0.5 + log.warning("You are using chebyshev poly as basis function, but does not given 'minDist', " + "this may cause some problems!") + radial_fn = ChebyshevPoly(r_max=cutoff, r_min=r_min, n_max=radial_dict['nBasis'], cutoff_fn=cutoff_fn) + else: + raise Exception("Unsupported radial type: {}!".format(radial_dict['type'])) + if "MLP" in radial_dict['type']: + if radial_dict["activateFn"] == "silu": + activate_fn = nn.SiLU() + elif radial_dict["activateFn"] == "relu": + activate_fn = nn.ReLU() + else: + raise Exception("Unsupported activate function in radial type: {}!".format(radial_dict["activateFn"])) + return MLPPoly(n_hidden=radial_dict['nHidden'], radial_fn=radial_fn, activate_fn=activate_fn) + else: + return radial_fn + + +def get_model(p_dict, elements, mean, std, n_neighbor): + model_dict = p_dict['Model'] + target = p_dict['Train']['targetProp'] + target_way = {} + if ("energy" in target) or ("forces" in target) or ("virial" in target): + target_way["site_energy"] = 0 + if "dipole" in target: + target_way["dipole"] = 1 + if "direct_forces" in target: + assert "forces" not in target_way, "Cannot learn forces and direct_forces at the same time" + target_way["direct_forces"] = 1 + cut_fn = get_cutoff(p_dict) + emb = AtomicEmbedding(elements, model_dict['nEmbedding']) # only support atomic embedding now + radial_fn = get_radial(p_dict, cut_fn) + model = MiaoNet(embedding_layer=emb, + radial_fn=radial_fn, + n_layers=model_dict['nLayer'], + max_r_way=model_dict['maxRWay'], + max_out_way=model_dict['maxOutWay'], + output_dim=model_dict['nHidden'], + activate_fn=model_dict['activateFn'], + target_way=target_way, + mean=mean, + std=std, + norm_factor=n_neighbor, + mode=model_dict['mode']) + return model + + +def save_checkpoints(path, name, model, ema, optimizer, lr_scheduler): + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + } + if ema is not None: + checkpoint["ema"] = ema.state_dict() + torch.save(checkpoint, os.path.join(path, 
f"state_dict-{name}.pt")) + torch.save(model, os.path.join(path, f"model-{name}.pt")) + +def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None, **kwargs): + # Default values + p_dict = { + "workDir": os.getcwd(), + "seed": np.random.randint(0, 100000000), + "device": "cuda" if torch.cuda.is_available() else "cpu", + "outputDir": os.path.join(os.getcwd(), "outDir"), + "Data": { + "path": os.getcwd(), + "trainBatch": 32, + "testBatch": 32, + "std": "force", + "numWorkers": 0, + "pinMemory": False, + }, + "Model": { + "mode": "normal", + "activateFn": "silu", + "nEmbedding": 64, + "nLayer": 5, + "maxRWay": 2, + "maxOutWay": 2, + "nHidden": 64, + "targetWay": {0 : 'site_energy'}, + "CutoffLayer": { + "type": "poly", + "p": 5, + }, + "RadialLayer": { + "type": "besselMLP", + "nBasis": 8, + "nHidden": [64, 64, 64], + "activateFn": "silu", + } + }, + "Train": { + "learningRate": 0.001, + "allowMissing": False, + "targetProp": ["energy", "forces"], + "weight": [0.1, 1.0], + "forceScale": 0., + "evalInterval": 10, + "saveInterval": 500, + "saveStart": 1000, + "evalTest": True, + "maxGradNorm": None, + "Optimizer": { + "type": "Adam", + "amsGrad": True, + "weightDecay": 0., + }, + "LrScheduler": { + "type": "constant", + }, + "emaDecay": 0., + }, + } + with open(input_file) as f: + update_dict(p_dict, yaml.load(f, Loader=yaml.FullLoader)) + + os.makedirs(p_dict["outputDir"], exist_ok=True) + + with open("allpara.yaml", "w") as f: + yaml.dump(p_dict, f) + + setup_seed(p_dict["seed"]) + log.info("Using seed {}".format(p_dict["seed"])) + + log.info(f"Preparing data...") + dataset = LitAtomsDataset(p_dict) + dataset.setup() + + try: + mean = dataset.per_energy_mean.detach().cpu().numpy() + except: + mean = 0. + if p_dict["Data"]["std"] == "force": + std = dataset.forces_std.detach().cpu().numpy() + elif p_dict["Data"]["std"] == "energy": + std = dataset.per_energy_std.detach().cpu().numpy() + else: + assert type(std) is float, "std must be 'force', 'energy' or a float!" 
+        std = p_dict["Data"]["std"]
+    n_neighbor = dataset.n_neighbor_mean.detach().cpu().numpy()
+    elements = list(dataset.all_elements.detach().cpu().numpy())
+    log.info(f"mean : {mean}")
+    log.info(f"std : {std}")
+    log.info(f"n_neighbor : {n_neighbor}")
+    log.info(f"all_elements : {elements}")
+    if load_model is not None:
+        model = torch.load(load_model)
+    else:
+        model = get_model(p_dict, elements, mean, std, n_neighbor)
+    model.register_buffer('all_elements', torch.tensor(elements, dtype=torch.long))
+    model.register_buffer('cutoff', torch.tensor(p_dict["cutoff"], dtype=torch.float64))
+
+    # log.info(" Network Architecture ".center(100, "="))
+    # log.info(model)
+    # log.info(f"Number of parameters: {sum([p.numel() for p in model.parameters()])}")
+    # log.info("=" * 100)
+
+    lit_model = LitAtomicModule(model=model, p_dict=p_dict)
+    from pytorch_lightning.profilers import PyTorchProfiler
+    profiler = PyTorchProfiler(
+        on_trace_ready=torch.profiler.tensorboard_trace_handler('.'),
+        schedule=torch.profiler.schedule(skip_first=10,
+                                         wait=1,
+                                         warmup=1,
+                                         active=20,
+                                         repeat=1))
+
+    logger = pl.loggers.TensorBoardLogger(save_dir='.')
+    callbacks = [
+        ModelCheckpoint(
+            dirpath='outDir',
+            filename='{epoch}-{val_loss:.2f}',
+            every_n_epochs=p_dict["Train"]["evalInterval"],
+            save_top_k=1,
+            monitor="val_loss"
+        ),
+        LogAllLoss(),
+    ]
+    trainer = pl.Trainer(
+        profiler=profiler,
+        logger=logger,
+        callbacks=callbacks,
+        default_root_dir='.',
+        max_epochs=100,
+        enable_progress_bar=False,
+        log_every_n_steps=50,
+        # num_nodes=1,
+        strategy='ddp_find_unused_parameters_true'
+    )
+
+    trainer.fit(lit_model, datamodule=dataset)
+
+if __name__ == "__main__":
+    main()
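A note on the reworked `MissingValueLoss`: instead of indexing with boolean `has_<prop>` masks, each property now carries a per-structure `<prop>_weight` that multiplies both prediction and target, so structures without a reference label contribute zero error. A minimal sketch of the idea, with made-up tensors that are not part of the repository:

```python
import torch
import torch.nn.functional as F

# Three hypothetical structures; the second one has no reference dipole,
# so its dipole_weight is 0 and it drops out of the squared error.
pred   = torch.tensor([[0.1, 0.0, 0.2], [9.9, 9.9, 9.9], [0.3, 0.1, 0.0]])
target = torch.tensor([[0.1, 0.0, 0.1], [0.0, 0.0, 0.0], [0.3, 0.2, 0.0]])
weight = torch.tensor([[1.0], [0.0], [1.0]])  # broadcasts over the vector components

loss = F.mse_loss(pred * weight, target * weight)
```

Zero-weight structures still count toward the mean's denominator, so the weighted loss is slightly rescaled compared to the old approach of slicing the missing entries out.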
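The polarizability head in `AtomicModule.forward` assembles the full tensor from a scalar (diagonal) part and a rank-2 (off-diagonal) part per structure. A small sketch of that reconstruction with illustrative numbers (the values and shapes are assumptions, not repository data):

```python
import torch

# Hypothetical per-structure outputs after scattering atomic contributions.
polar_diagonal = torch.tensor([2.0, 3.0])        # shape [n_structures]
polar_off_diagonal = torch.zeros(2, 3, 3)        # shape [n_structures, 3, 3]
polar_off_diagonal[0, 0, 1] = 0.5

shape = polar_off_diagonal.shape
identity = torch.eye(shape[1]).expand(*shape)    # one 3x3 identity per structure
polarizability = polar_diagonal.unsqueeze(-1).unsqueeze(-1) * identity + polar_off_diagonal
print(polarizability.shape)                      # torch.Size([2, 3, 3])
```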
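`LitAtomicModule.optimizer_step` applies a linear learning-rate warm-up before the configured scheduler takes over. A standalone sketch of the ramp, assuming `warmupSteps: 1000` and `learningRate: 1.0e-3` in the `Train` section (note that `warmupSteps` has no default in `train.py`, so it has to be supplied via `input.yaml`):

```python
# Linear warm-up: the scale grows from 1/warmup_steps to 1 and then stays at 1.
warmup_steps = 1000
base_lr = 1.0e-3

for global_step in (0, 99, 499, 999, 5000):
    lr_scale = min(1.0, float(global_step + 1) / warmup_steps)
    print(f"step {global_step:>5}: lr = {lr_scale * base_lr:.2e}")
```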
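A possible workflow with the new tools, sketched as a small driver script: both `train.py` and `convert.py` read `input.yaml` from the working directory, `convert.py` rebuilds the model from a Lightning checkpoint and writes a plain `model.pt`, and `eval.py` dumps per-property `output_*.npy` / `target_*.npy` arrays. The paths and the checkpoint name below are illustrative assumptions, and the script is meant to run from `tools/`:

```python
# Hypothetical end-to-end usage of the new tool scripts.
from convert import convert
from eval import main as evaluate

convert("outDir/epoch=99-val_loss=0.01.ckpt")             # writes model.pt in the working directory
evaluate(modelfile="model.pt", datafile="test.traj",
         properties=["energy", "forces"], device="cpu")   # writes output_*.npy, target_*.npy, n_atoms.npy
```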