diff --git a/README.md b/README.md
index a052f83e559d15c353522f90f471e3e25c274bc1..5411221311771d8983eb2f17c1743bba037ff839 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,7 @@
 # Tensor Passing Network
+## Install
+
+```shell
+$ pip install git+ssh://git@git.nju.edu.cn/bigd4/tpn.git
+```
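+
+As a minimal sketch (not an official example), a trained model can be driven through the ASE calculator defined in `interface/ase/ase.py`; the import path and the `model.pt` / `structure.xyz` filenames below are assumptions to adapt to your setup:
+
+```python
+from ase.io import read
+from interface.ase.ase import MiaoCalculator  # adjust to wherever MiaoCalculator lives in your install
+
+atoms = read("structure.xyz")  # any ASE-readable structure
+atoms.calc = MiaoCalculator(model_file="model.pt", device="cpu")
+print(atoms.get_potential_energy())  # "energy" result
+print(atoms.get_forces())            # "forces" result
+```
+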
+## Lammps
diff --git a/interface/ase/ase.py b/interface/ase/ase.py
index 92f2599fae15734551716f9ea3bd84c5900f4138..847b49fffbec132051e142becfea6c5b4e1cb35e 100644
--- a/interface/ase/ase.py
+++ b/interface/ase/ase.py
@@ -11,18 +11,19 @@ class MiaoCalculator(Calculator):
         "energies",
         "forces",
         "stress",
+        "dipole",
+        "polarizability",
     ]
 
-    def __init__(self, 
-                 cutoff     : float,
+    def __init__(self,
                  model_file : str="model.pt",
                  device     : str="cpu",
                  **kwargs,
                  ) -> None:
         Calculator.__init__(self, **kwargs)
-        self.cutoff = cutoff
         self.device = device
-        self.model = torch.load(model_file).to(device).double()
+        self.model = torch.load(model_file, map_location=device).double()
+        self.cutoff = float(self.model.cutoff.detach().cpu().numpy())
 
     def calculate(
         self,
@@ -35,22 +36,17 @@ class MiaoCalculator(Calculator):
         Calculator.calculate(self, atoms, properties, system_changes)
 
         idx_i, idx_j, offsets = neighbor_list("ijS", atoms, self.cutoff, self_interaction=False)
-        bonds = np.array([idx_i, idx_j])
-        offsets = np.array(offsets) @ atoms.get_cell()
-        atomic_number = torch.tensor(atoms.numbers, dtype=torch.long, device=self.device)
-        edge_index = torch.tensor(bonds, dtype=torch.long, device=self.device)
-        offset = torch.tensor(offsets, dtype=torch.double, device=self.device)
-        coordinate = torch.tensor(atoms.positions, dtype=torch.double, device=self.device)
-        n_atoms = torch.tensor(len(atoms), dtype=torch.double, device=self.device)
-        batch = torch.zeros(len(atoms), dtype=torch.long, device=self.device)
-        
+        offset = np.array(offsets) @ atoms.get_cell()
+
         data = {
-            "atomic_number" : atomic_number,
-            "edge_index"    : edge_index,
-            "offset"        : offset,
-            "coordinate"    : coordinate,
-            "n_atoms"       : n_atoms,
-            "batch"         : batch,
+            "atomic_number": torch.tensor(atoms.numbers, dtype=torch.long, device=self.device),
+            "idx_i"        : torch.tensor(idx_i, dtype=torch.long, device=self.device),
+            "idx_j"        : torch.tensor(idx_j, dtype=torch.long, device=self.device),
+            "coordinate"   : torch.tensor(atoms.positions, dtype=torch.double, device=self.device),
+            "n_atoms"      : torch.tensor([len(atoms)], dtype=torch.long, device=self.device),
+            "offset"       : torch.tensor(offset, dtype=torch.double, device=self.device),
+            "scaling"      : torch.eye(3, dtype=torch.double, device=self.device).view(1, 3, 3),
+            "batch"        : torch.zeros(len(atoms), dtype=torch.long, device=self.device),
         }
 
         self.model(data, properties, create_graph=False)
@@ -61,12 +57,13 @@ class MiaoCalculator(Calculator):
         if "forces" in properties:
             self.results["forces"] = data["forces_p"].detach().cpu().numpy()
         if "stress" in properties:
-            raise Exception("ni shi gu yi zhao cha shi bu shi?")
-            # virial = np.sum(
-            #     np.array(self.calc.getVirials()).reshape(9, -1), axis=1)
-            # if sum(atoms.get_pbc()) > 0:
-            #     stress = -0.5 * (virial.copy() +
-            #                      virial.copy().T) / atoms.get_volume()
-            #     self.results['stress'] = stress.flat[[0, 4, 8, 5, 2, 1]]
-            # else:
-            #     raise PropertyNotImplementedError
+            virial = data["virial_p"].detach().cpu().numpy().reshape(-1)
+            if sum(atoms.get_pbc()) > 0:
+                stress = 0.5 * (virial.copy() + virial.copy().T) / atoms.get_volume()
+                self.results['stress'] = stress.flat[[0, 4, 8, 5, 2, 1]]
+            else:
+                raise PropertyNotImplementedError
+        if "dipole" in properties:
+            self.results["dipole"] = data["dipole_p"].detach().cpu().numpy()
+        if "polarizability" in properties:
+            self.results["polarizability"] = data["polarizability_p"].detach().cpu().numpy()
diff --git a/interface/lammps/cmake/CMakeLists.txt b/interface/lammps/cmake/CMakeLists.txt
index 61abcf6f2ce4e77489a9f6a56fbd094dcd853760..0b2f139b0ddafed7b59031b226b2127737c94089 100644
--- a/interface/lammps/cmake/CMakeLists.txt
+++ b/interface/lammps/cmake/CMakeLists.txt
@@ -50,6 +50,7 @@ endif()
 include(LAMMPSUtils)
 
 get_lammps_version(${LAMMPS_SOURCE_DIR}/version.h LAMMPS_VERSION_NUMBER)
+add_compile_definitions(LAMMPS_VERSION_NUMBER=${LAMMPS_VERSION_NUMBER})
 
 include(PreventInSourceBuilds)
 
diff --git a/interface/lammps/src/.pair_miao.cpp.swp b/interface/lammps/src/.pair_miao.cpp.swp
deleted file mode 100644
index 7c9fda6926489cbfa8b95ae2296e8a11da72299b..0000000000000000000000000000000000000000
Binary files a/interface/lammps/src/.pair_miao.cpp.swp and /dev/null differ
diff --git a/interface/lammps/src/pair_miao.cpp b/interface/lammps/src/pair_miao.cpp
index 7282572ab50badd59e04d27d408754e366dde920..f66d0938ec0a1b5a41ef4b604ebde2d58d81ef45 100644
--- a/interface/lammps/src/pair_miao.cpp
+++ b/interface/lammps/src/pair_miao.cpp
@@ -165,8 +165,9 @@ void PairMIAO::compute(int eflag, int vflag)
     int numneigh_atom = accumulate(numneigh, numneigh + inum , 0);
 
     std::vector<double> cart(nall * 3);
-    std::vector<long> atom_index(numneigh_atom * 2);
-    std::vector<long> ghost_neigh(numneigh_atom);
+    std::vector<long> idx_i(numneigh_atom);            // atom i
+    std::vector<long> idx_j(numneigh_atom);            // atom j
+    std::vector<long> ghost_neigh(numneigh_atom);      // ghost index of atom j, used to compute distances
     std::vector<long> local_species(inum);
     std::vector<long> batch(nall);
     double dx, dy, dz, d2;
@@ -192,8 +193,8 @@ void PairMIAO::compute(int eflag, int vflag)
             d2 = dx * dx + dy * dy + dz * dz;
             if (d2 < cutoffsq)
             {
-                atom_index[totneigh * 2] = i;
-                atom_index[totneigh * 2 + 1] = atom->map(tag[j]);
+                idx_i[totneigh] = i;
+                idx_j[totneigh] = atom->map(tag[j]);
                 ghost_neigh[totneigh] = j;
                 ++totneigh;
             }
@@ -233,17 +234,19 @@ void PairMIAO::compute(int eflag, int vflag)
     species_file.close();
     */
     auto cart_ = torch::from_blob(cart.data(), {nall, 3}, option1).to(device, true).to(tensor_type);
-    auto atom_index_ = torch::from_blob(atom_index.data(), {totneigh, 2}, option2).transpose(1, 0).to(device, true);
+    auto idx_i_ = torch::from_blob(idx_i.data(), {totneigh}, option2).to(device, true);
+    auto idx_j_ = torch::from_blob(idx_j.data(), {totneigh}, option2).to(device, true);
     auto ghost_neigh_ = torch::from_blob(ghost_neigh.data(), {totneigh}, option2).to(device, true);
     auto local_species_ = torch::from_blob(local_species.data(), {inum}, option2).to(device, true);
     auto batch_ = torch::from_blob(batch.data(), {inum}, option2).to(device, true);
     batch_data.insert("coordinate", cart_);
     batch_data.insert("atomic_number", local_species_);
-    batch_data.insert("edge_index", atom_index_);
+    batch_data.insert("idx_i", idx_i_);
+    batch_data.insert("idx_j", idx_j_);
     batch_data.insert("batch", batch_);
     batch_data.insert("ghost_neigh", ghost_neigh_);
 
-    torch::IValue output = module.forward({batch_data, properties});
+    torch::IValue output = module.forward({batch_data, properties, false});
     results = c10::impl::toTypedDict<std::string, torch::Tensor>(output.toGenericDict());
 
     auto forces_tensor = results.at("forces_p").to(torch::kDouble).cpu().reshape({-1});
diff --git a/setup.py b/setup.py
index b8f27133cb7f3f95df124c0f94e8849db5d97bc3..aca606ac03386191d87c3b63ef1a6d930d344584 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ setup(
         "ase",
         "pyyaml",
         "torch",
-        "torch-geometric",
+        "lightning",
     ],
     license="MIT",
     description="MiaoNet: Moment tensor InterAggregate Operation Net",
diff --git a/tensornet/data/__init__.py b/tensornet/data/__init__.py
index b26dfd0230b04c60d1319e9a939ddc17d16d8c82..a4faf1c54f074b114b318e062a999e305e6e3528 100644
--- a/tensornet/data/__init__.py
+++ b/tensornet/data/__init__.py
@@ -1,4 +1,5 @@
 from .base import *
-from .md17 import *
-from .asedata import *
-from .ptdata import *
+# from .md17 import *
+from .asedata import ASEData, ASEDBData
+# from .ptdata import *
+from .data_interface import LitAtomsDataset
\ No newline at end of file
diff --git a/tensornet/data/asedata.py b/tensornet/data/asedata.py
index e133e2b22e5d3e557c6bb443538a4afa34f3d28d..591dd60823a262b22624997cbfe0c81239903df8 100644
--- a/tensornet/data/asedata.py
+++ b/tensornet/data/asedata.py
@@ -1,44 +1,70 @@
-from .base import AtomsData
-from ..utils import progress_bar
+from .base import AtomsDataset
 from typing import List, Optional
 from ase import Atoms
 from ase.io import read
-import torch
-import os
+from ase.db import connect
 
 
-class ASEData(AtomsData):
+class ASEData(AtomsDataset):
+
     def __init__(self,
-                 frames : Optional[List[Atoms]]=None,
-                 root   : Optional[str]=None,
-                 name   : Optional[str]=None,
-                 format : Optional[str]=None,
-                 cutoff : float=4.0,
-                 device : str="cpu",
+                 frames     : Optional[List[Atoms]]=None,
+                 indices    : Optional[List[int]]=None,
+                 properties : Optional[List[str]]=['energy', 'forces'],
+                 cutoff     : float=4.0,
                  ) -> None:
-        self.cutoff = cutoff
-        self.device = device
-        if frames is None:
-            frames = read(os.path.join(root, name), format=format, index=':')
+        super().__init__(indices=indices, cutoff=cutoff)
         self.frames = frames
-        self.name = name or "processed"
-        if root is None:
-            root = os.getcwd()
-        super().__init__(root)
-        self.data, self.slices = torch.load(self.processed_paths[0], map_location=device)
-
-    @property
-    def processed_file_names(self) -> str:
-        return f"{self.name}.pt"
-
-    def process(self):
-        n_data = len(self.frames)
-        data_list = []
-        for i in range(n_data):
-            progress_bar(i, n_data)
-            data = self.atoms_to_graph(self.frames[i], self.cutoff, self.device)
-            data_list.append(data)
-        torch.save(self.collate(data_list), self.processed_paths[0])
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}({len(self)}, name='{self.name}')"
+        self.properties = properties
+
+    def __len__(self):
+        if self.indices is None:
+            return len(self.frames)
+        else:
+            return len(self.indices)
+
+    def __getitem__(self, idx):
+        if self.indices is not None:
+            idx = self.indices[idx]
+        data = self.atoms_to_data(self.frames[idx],
+                                  properties=self.properties,
+                                  cutoff=self.cutoff)
+        return data
+
+    def extend(self, frames):
+        self.frames.extend(frames)
+
+
+class ASEDBData(AtomsDataset):
+
+    def __init__(self,
+                 datapath   : Optional[str]=None,
+                 indices    : Optional[List[int]]=None,
+                 properties : Optional[List[str]]=['energy', 'forces'],
+                 cutoff     : float=4.0,
+                 ) -> None:
+        super().__init__(indices=indices, cutoff=cutoff)
+        self.datapath = datapath
+        self.conn = connect(self.datapath, use_lock_file=False)
+        self.properties = properties
+
+    def __len__(self):
+        if self.indices is None:
+            return self.conn.count()
+        else:
+            return len(self.indices)
+
+    def __getitem__(self, idx):
+        if self.indices is not None:
+            idx = int(self.indices[idx])
+        row = self.conn.get(idx + 1)
+        atoms = Atoms(numbers=row['numbers'],
+                      cell=row['cell'],
+                      positions=row['positions'],
+                      pbc=row['pbc'],
+                      info=row.data
+                      )
+        data = self.atoms_to_data(atoms,
+                                  properties=self.properties,
+                                  cutoff=self.cutoff)
+        return data
diff --git a/tensornet/data/base.py b/tensornet/data/base.py
index 82e9a9a83105d0c942eafb4e2908557483dc3b51..6682c1eb232e7ca160fe91aae8628a567d29ea4e 100644
--- a/tensornet/data/base.py
+++ b/tensornet/data/base.py
@@ -1,87 +1,91 @@
 import torch
-import os
+import copy
+import abc
 import numpy as np
 from tensornet.utils import EnvPara
-from torch_geometric.data import Data, InMemoryDataset
+from torch.utils.data import Dataset
 from ase.neighborlist import neighbor_list
+from typing import Optional, List
 
+# TODO: offset and scaling for different conditions
+class AtomsDataset(Dataset, abc.ABC):
 
-class AtomsData(InMemoryDataset):
     @staticmethod
-    def atoms_to_graph(atoms, cutoff, device='cpu'):
+    def atoms_to_data(atoms, cutoff, properties=['energy', 'forces']):
         dim = len(atoms.get_cell())
         idx_i, idx_j, offset = neighbor_list("ijS", atoms, cutoff, self_interaction=False)
-        bonds = np.array([idx_i, idx_j])
         offset = np.array(offset) @ atoms.get_cell()
-        index = torch.arange(len(atoms), dtype=torch.long, device=device)
-        atomic_number = torch.tensor(atoms.numbers, dtype=torch.long, device=device)
-        edge_index = torch.tensor(bonds, dtype=torch.long, device=device)
-        offset = torch.tensor(offset, dtype=EnvPara.FLOAT_PRECISION, device=device)
-        coordinate = torch.tensor(atoms.positions, dtype=EnvPara.FLOAT_PRECISION, device=device)
-        scaling = torch.eye(dim, dtype=EnvPara.FLOAT_PRECISION, device=device).view(1, dim, dim)
-        n_atoms = torch.tensor(len(atoms), dtype=EnvPara.FLOAT_PRECISION, device=device)
-        graph = Data(x=index,
-                    atomic_number=atomic_number, 
-                    edge_index=edge_index,
-                    offset=offset,
-                    coordinate=coordinate,
-                    n_atoms=n_atoms,
-                    scaling=scaling,
-                    )
+
+        data = {
+            "atomic_number": torch.tensor(atoms.numbers, dtype=torch.long),
+            "idx_i": torch.tensor(idx_i, dtype=torch.long),
+            "idx_j": torch.tensor(idx_j, dtype=torch.long),
+            "coordinate": torch.tensor(atoms.positions, dtype=EnvPara.FLOAT_PRECISION),
+            "n_atoms": torch.tensor([len(atoms)], dtype=torch.long),
+            "offset": torch.tensor(offset, dtype=EnvPara.FLOAT_PRECISION),
+            "scaling": torch.eye(dim, dtype=EnvPara.FLOAT_PRECISION).view(1, dim, dim)
+        }
+
         padding_shape = {
             'site_energy' : (len(atoms)),
             'energy'      : (1),
             'forces'      : (len(atoms), dim),
             'virial'      : (1, dim, dim),
             'dipole'      : (1, dim),
+            'polarizability': (1, dim, dim)
         }
-        for key in ['site_energy', 'energy', 'forces', 'virial', 'dipole']:
+        for key in properties:
             if key in atoms.info:
-                graph[key + '_t'] = torch.tensor(atoms.info[key], dtype=EnvPara.FLOAT_PRECISION, device=device).reshape(padding_shape[key])
-                graph['has_' + key] = True
+                data[key + '_t'] = torch.tensor(atoms.info[key], dtype=EnvPara.FLOAT_PRECISION).reshape(padding_shape[key])
+                data[key + '_weight'] = torch.ones(padding_shape[key], dtype=EnvPara.FLOAT_PRECISION)
             else:
-                graph[key + '_t'] = torch.zeros(padding_shape[key], dtype=EnvPara.FLOAT_PRECISION, device=device)
-                graph['has_' + key] = False
-        return graph
+                data[key + '_t'] = torch.zeros(padding_shape[key], dtype=EnvPara.FLOAT_PRECISION)
+                data[key + '_weight'] = torch.zeros(padding_shape[key], dtype=EnvPara.FLOAT_PRECISION)
+        return data
+
+    def __init__(self,
+                 indices: Optional[List[int]]=None,
+                 cutoff : float=4.0,
+                 ) -> None:
+        self.indices = indices
+        self.cutoff = cutoff
 
-    @property
-    def processed_dir(self) -> str:
-        return os.path.join(self.root, f"processed_{self.cutoff:.2f}".replace(".", "_"))
+    def __len__(self):
+        if self.indices is not None:
+            return len(self.indices)
+        raise NotImplementedError
 
-    @property
-    def per_energy_mean(self):
-        per_energy = self.data["energy_t"] / self.data["n_atoms"]
-        return torch.mean(per_energy)
+    @abc.abstractmethod
+    def __getitem__(self, idx: int):
+        pass
 
-    @property
-    def per_energy_std(self):
-        per_energy = self.data["energy_t"] / self.data["n_atoms"]
-        return torch.std(per_energy)
+    def subset(self, indices: List[int]):
+        ds = copy.copy(self)
+        if ds.indices:
+            ds.indices = [ds.indices[i] for i in indices]
+        else:
+            ds.indices = indices
+        return ds
 
-    @property
-    def forces_std(self):
-        return torch.std(self.data["forces_t"])
+def atoms_collate_fn(batch):
 
-    @property
-    def n_neighbor_mean(self):
-        n_neighbor = self.data['edge_index'].shape[1] / len(self.data['x'])
-        return n_neighbor
+    elem = batch[0]
+    coll_batch = {}
 
-    @property
-    def all_elements(self):
-        return torch.unique(self.data['atomic_number'])
+    for key in elem:
+        if key not in ["idx_i", "idx_j"]:
+            coll_batch[key] = torch.cat([d[key] for d in batch], dim=0)
 
-    def load(self, name: str, device: str="cpu") -> None:
-        self.data, self.slices = torch.load(name, map_location=device)
+    # idx_i and idx_j should be converted like
+    # [0, 0, 1, 1] + [0, 0, 1, 2] -> [0, 0, 1, 1, 2, 2, 3, 4]
+    for key in ["idx_i", "idx_j"]:
+        coll_batch[key] = torch.cat(
+            [batch[i][key] + torch.sum(coll_batch["n_atoms"][:i]) for i in range(len(batch))], dim=0
+        )
 
-    def load_split(self, train_split: str, test_split: str):
-        train_idx = np.loadtxt(train_split, dtype=int)
-        test_idx  = np.loadtxt(test_split, dtype=int)
-        return self.copy(train_idx), self.copy(test_idx)
+    coll_batch["batch"] = torch.repeat_interleave(
+        torch.arange(len(batch)),
+        repeats=coll_batch["n_atoms"].to(torch.long),
+        dim=0
+    )
 
-    def random_split(self, train_num: int, test_num: int):
-        assert train_num + test_num <= len(self)
-        idx = np.random.choice(len(self), train_num + test_num, replace=False)
-        train_idx = idx[:train_num]
-        test_idx  = idx[train_num:]
-        return self.copy(train_idx), self.copy(test_idx)
+    return coll_batch
diff --git a/tensornet/data/data_interface.py b/tensornet/data/data_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5a28b1c142b9037ba4a408d17a16a20e97a0b63
--- /dev/null
+++ b/tensornet/data/data_interface.py
@@ -0,0 +1,182 @@
+import logging
+import os
+import numpy as np
+from copy import copy
+from typing import Optional, List, Dict, Tuple, Union
+from . import *
+from ase.io import read
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import DataLoader
+
+
+log = logging.getLogger(__name__)
+
+
+class LitAtomsDataset(pl.LightningDataModule):
+
+    def __init__(self, p_dict):
+        super().__init__()
+        self.p_dict = p_dict
+        self._train_dataloader = None
+        self._test_dataloader = None
+        self.stats = {}
+
+    def setup(self, stage: Optional[str] = None):
+        self.dataset = self.get_dataset()
+        self.trainset, self.testset = self.split_dataset()
+        # self.calculate_stats()
+
+    def get_dataset(self):
+        data_dict = self.p_dict['Data']
+        # if data_dict['type'] == 'rmd17':
+        #     dataset = RevisedMD17(data_dict['path'], 
+        #                         data_dict['name'], 
+        #                         cutoff=p_dict['cutoff'], 
+        #                         device=p_dict['device'])
+        if data_dict['type'] == 'ase':
+            if 'name' in data_dict:
+                frames = read(os.path.join(data_dict['path'], data_dict['name']), index=':')
+            else:
+                frames = []
+            dataset = ASEData(frames=frames,
+                              properties=self.p_dict['Train']['targetProp'],
+                              cutoff=self.p_dict['cutoff'])
+        elif data_dict['type'] == 'ase-db':
+            dataset = ASEDBData(datapath=os.path.join(data_dict['path'], data_dict['name']),
+                                properties=self.p_dict['Train']['targetProp'],
+                                cutoff=self.p_dict['cutoff'])
+        return dataset
+
+    def split_dataset(self):
+        data_dict = self.p_dict['Data']
+        if ("trainSplit" in data_dict) and ("testSplit" in data_dict):
+            log.info("Load split from {} and {}".format(data_dict["trainSplit"], data_dict["testSplit"]))
+            train_idx = np.loadtxt(data_dict["trainSplit"], dtype=int)
+            test_idx  = np.loadtxt(data_dict["testSplit"], dtype=int)
+            return self.dataset.subset(train_idx), self.dataset.subset(test_idx)
+
+        if ("trainNum" in data_dict) and (("testNum" in data_dict)):
+            log.info("Random split, train num: {}, test num: {}".format(data_dict["trainNum"], data_dict["testNum"]))
+            assert data_dict['trainNum'] + data_dict['testNum'] <= len(self.dataset)
+            idx = np.random.choice(len(self.dataset), data_dict['trainNum'] + data_dict['testNum'], replace=False)
+            train_idx = idx[:data_dict['trainNum']]
+            test_idx  = idx[data_dict['trainNum']:]
+            return self.dataset.subset(train_idx), self.dataset.subset(test_idx)
+
+        if ("trainSet" in data_dict) and ("testSet" in data_dict):
+            assert data_dict['type'] == 'ase', "trainSet and testSet must be readable by ase!"
+            trainset = read(os.path.join(data_dict['path'], data_dict['trainSet']), index=':')
+            self.dataset.extend(trainset)
+            testset = read(os.path.join(data_dict['path'], data_dict['testSet']), index=':')
+            self.dataset.extend(testset)
+            train_idx = [i for i in range(len(trainset))]
+            test_idx = [i for i in range(len(trainset), len(trainset) + len(testset))]
+            return self.dataset.subset(train_idx), self.dataset.subset(test_idx)
+
+        raise Exception("No splitting!")
+
+    def train_dataloader(self):
+        if self._train_dataloader is None:
+            self._train_dataloader = DataLoader(self.trainset,
+                          batch_size=self.p_dict["Data"]["trainBatch"],
+                          shuffle=True,
+                          collate_fn=atoms_collate_fn,
+                          num_workers=self.p_dict["Data"]["numWorkers"],
+                          pin_memory=self.p_dict["Data"]["pinMemory"])
+            log.debug(f'numWorkers: {self.p_dict["Data"]["numWorkers"]}')
+        return self._train_dataloader
+
+    def val_dataloader(self):
+        return self.test_dataloader()
+
+    def test_dataloader(self):
+        if self._test_dataloader is None:
+            self._test_dataloader = DataLoader(self.testset,
+                          batch_size=self.p_dict["Data"]["testBatch"],
+                          shuffle=False,
+                          collate_fn=atoms_collate_fn,
+                          num_workers=self.p_dict["Data"]["numWorkers"],
+                          pin_memory=self.p_dict["Data"]["pinMemory"])
+        return self._test_dataloader
+
+    def calculate_stats(self):
+        # Note that we assume the average force is always 0,
+        # so the final result may differ from the actual variance
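+        # (i.e. forces_std below is computed as sqrt(mean(F**2)), not the centered standard deviation)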
+
+        N_batch = 0
+        N_forces = 0
+        per_energy_mean = 0.
+        n_neighbor_mean = 0.
+        per_energy_std = 0.
+        forces_std = 0.
+        all_elements = torch.tensor([], dtype=torch.long)#, device=self.p_dict['device'])
+
+        for batch_data in self.train_dataloader():
+            # all elements
+            all_elements = torch.unique(torch.cat((all_elements, batch_data['atomic_number'])))
+
+            # per_energy_mean
+            batch_size = batch_data["energy_t"].numel()
+            pe = batch_data["energy_t"] / batch_data["n_atoms"]
+            pe_mean = torch.mean(pe)
+            delta_pe_mean = pe_mean - per_energy_mean
+            per_energy_mean += delta_pe_mean * batch_size / (N_batch + batch_size)
+
+            # per_energy_std
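+            # accumulate the sum of squared deviations online across batches
+            # (within-batch M2 plus a mean-shift correction term)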
+            pe_m2 = torch.sum((pe - pe_mean) ** 2)
+            pe_corr = batch_size * N_batch / (N_batch + batch_size)
+            per_energy_std += pe_m2 + delta_pe_mean ** 2 * pe_corr
+
+            # n_neighbor_mean
+            nn_mean = batch_data["idx_i"].shape[0] / torch.sum(batch_data["n_atoms"])
+            delta_nn_mean = nn_mean - n_neighbor_mean
+            n_neighbor_mean += delta_nn_mean * batch_size / (N_batch + batch_size)
+            N_batch += batch_size
+
+            # forces_std
+            if 'forces_t' in batch_data:
+                forces_size = batch_data["forces_t"].numel()
+                forces_m2 = torch.sum(batch_data["forces_t"] ** 2)
+                forces_std += forces_m2
+                N_forces += forces_size
+
+        per_energy_std = torch.sqrt(per_energy_std / N_batch)
+        if N_forces > 0:
+            forces_std = torch.sqrt(forces_std / N_forces)
+
+        self.stats["per_energy_mean"] = per_energy_mean
+        self.stats["per_energy_std"] = per_energy_std
+        self.stats["n_neighbor_mean"] = n_neighbor_mean
+        self.stats["forces_std"] = forces_std
+        self.stats["all_elements"] = all_elements
+
+    @property
+    def per_energy_mean(self):
+        if "per_energy_mean" not in self.stats:
+            self.calculate_stats()
+        return self.stats["per_energy_mean"]
+
+    @property
+    def per_energy_std(self):
+        if "per_energy_std" not in self.stats:
+            self.calculate_stats()
+        return self.stats["per_energy_std"]
+
+    @property
+    def forces_std(self):
+        if "forces_std" not in self.stats:
+            self.calculate_stats()
+        return self.stats["forces_std"]
+
+    @property
+    def n_neighbor_mean(self):
+        if "forces_std" not in self.stats:
+            self.calculate_stats()
+        return self.stats["n_neighbor_mean"]
+
+    @property
+    def all_elements(self):
+        if "all_elements" not in self.stats:
+            self.calculate_stats()
+        return self.stats["all_elements"]
diff --git a/tensornet/entrypoints/eval.py b/tensornet/entrypoints/eval.py
index 9c66d98758ef567ae067e73f1dfdbea8273b1401..5ce7b04ed1c18013a766a2b13f52b9c8562c72e1 100644
--- a/tensornet/entrypoints/eval.py
+++ b/tensornet/entrypoints/eval.py
@@ -1,8 +1,8 @@
-from ase.io import read                                                                                        
+from ase.io import read
 import numpy as np
 import torch
-from torch_geometric.loader import DataLoader
-from tensornet.data import ASEData, PtData
+from torch.utils.data import DataLoader
+from tensornet.data import ASEData, ASEDBData, atoms_collate_fn
 
 
 def eval(model, data_loader, properties):
@@ -23,18 +23,33 @@ def eval(model, data_loader, properties):
     return None
 
 
-def main(*args, cutoff=None, model='model.pt', device='cpu', dataset='data.traj', format=None, 
-         properties=["energy", "forces"], batchsize=32, **kwargs):
-    if '.pt' in dataset:
-        dataset = PtData(dataset, device=device)
+def main(*args, modelfile='model.pt', indices=None, device='cpu', datafile='data.traj',
+         format=None, properties=["energy", "forces"], batchsize=32, num_workers=4, pin_memory=True,
+         **kwargs):
+    model = torch.load(modelfile, map_location=device)
+    cutoff = float(model.cutoff.detach().cpu().numpy())
+    if indices is not None:
+        indices = np.loadtxt(indices, dtype=int)
+
+    if '.db' in datafile:
+        dataset = ASEDBData(datapath=datafile,
+                            indices=indices,
+                            properties=properties,
+                            cutoff=cutoff)
     else:
-        assert cutoff is not None, "Must have cutoff!!"
-        frames = read(dataset, index=':', format=format)
-        dataset = ASEData(frames, name="eval_process", cutoff=cutoff, device=device)
-    data_loader = DataLoader(dataset, batch_size=batchsize, shuffle=False)
-    model = torch.load(model) 
-    eval(model, data_loader, properties)
+        frames = read(datafile, index=':', format=format)
+        dataset = ASEData(frames=frames,
+                          indices=indices,
+                          cutoff=cutoff,
+                          properties=properties)
 
+    data_loader = DataLoader(dataset,
+                             batch_size=batchsize,
+                             shuffle=False,
+                             collate_fn=atoms_collate_fn,
+                             num_workers=num_workers,
+                             pin_memory=pin_memory)
+    eval(model, data_loader, properties)
 
 if __name__ == "__main__":
     main()
diff --git a/tensornet/entrypoints/main.py b/tensornet/entrypoints/main.py
index 41887de2ad010a6c3941d1e531a80416d77dadc3..c75c3350f355af1f7d7c93775bd7ca8be4bf24ff 100644
--- a/tensornet/entrypoints/main.py
+++ b/tensornet/entrypoints/main.py
@@ -62,22 +62,22 @@ def parse_args():
     # eval
     parser_eval = subparsers.add_parser(
         "eval",
-        help="train",
+        help="eval",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser_eval.add_argument(
         "-m",
-        "--model",
-        type=str, 
+        "--modelfile",
+        type=str,
         default="model.pt",
         help="model"
     )
     parser_eval.add_argument(
-        "-c",
-        "--cutoff",
-        type=float, 
+        "-i",
+        "--indices",
+        type=str,
         default=None,
-        help="cutoff"
+        help="indices"
     )
     parser_eval.add_argument(
         "--device",
@@ -87,22 +87,22 @@ def parse_args():
     )
     parser_eval.add_argument(
         "-d",
-        "--dataset",
-        type=str, 
+        "--datafile",
+        type=str,
         default="data.traj",
         help="dataset"
     )
     parser_eval.add_argument(
         "-f",
         "--format",
-        type=str, 
+        type=str,
         default=None,
         help="format"
     )
     parser_eval.add_argument(
         "-b",
         "--batchsize",
-        type=int, 
+        type=int,
         default=32,
         help="batchsize"
     )
@@ -114,6 +114,17 @@ def parse_args():
         default=["energy", "forces"],
         help="target properties"
     )
+    parser_eval.add_argument(
+        "--num_workers",
+        type=int,
+        default=4,
+        help="num workder"
+    )
+    parser_eval.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="pin memory"
+    )
     # clean
     parser_clean = subparsers.add_parser(
         "clean",
diff --git a/tensornet/entrypoints/train.py b/tensornet/entrypoints/train.py
index 356b8622587e2cc9655d6a042cb238394c52ed5c..d8f43bd608f1b9988194a0a83c186aceb623dff0 100644
--- a/tensornet/entrypoints/train.py
+++ b/tensornet/entrypoints/train.py
@@ -1,23 +1,42 @@
 import logging, time, yaml, os
 import numpy as np
-from torch_geometric.loader import DataLoader
 import torch
 from torch import nn
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
 import torch.nn.functional as F
 from torch.optim.swa_utils import AveragedModel
 from ase.data import atomic_numbers
 from ..utils import setup_seed
-from ..loss import Loss, MissingValueLoss, ForceScaledLoss
-from ..model import MiaoNet
+from ..model import MiaoNet, LitAtomicModule
 from ..layer.cutoff import *
 from ..layer.embedding import AtomicEmbedding
 from ..layer.radial import *
-from ..data import *
+from ..data import LitAtomsDataset
 
 
+torch.set_float32_matmul_precision("high")
 log = logging.getLogger(__name__)
 
 
+class SaveModelCheckpoint(ModelCheckpoint):
+    """
+    Saves model.pt for eval
+    """
+    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+        super()._save_checkpoint(trainer, filepath)
+        modelpath = filepath[:-4] + "pt"
+        if trainer.is_global_zero:
+            torch.save(trainer.lightning_module.model, modelpath)
+
+    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+        super()._remove_checkpoint(trainer, filepath)
+        modelpath = filepath[:-4] + "pt"
+        if trainer.is_global_zero:
+            if os.path.exists(modelpath):
+                os.remove(modelpath)
+
+
 def update_dict(d1, d2):
     for key in d2:
         if key in d1 and isinstance(d1[key], dict):
@@ -27,6 +46,41 @@ def update_dict(d1, d2):
     return d1
 
 
+def get_stats(data_dict, dataset):
+
+    if type(data_dict["mean"]) is float:
+        mean = data_dict["mean"]
+    else:
+        try:
+            mean = dataset.per_energy_mean.detach().cpu().numpy()
+        except Exception:
+            mean = 0.
+
+    if data_dict["std"] == "force":
+        std = dataset.forces_std.detach().cpu().numpy()
+    elif data_dict["std"] == "energy":
+        std = dataset.per_energy_std.detach().cpu().numpy()
+    else:
+        assert type(data_dict["std"]) is float, "std must be 'force', 'energy' or a float!" 
+        std = data_dict["std"]
+
+    if type(data_dict["nNeighbor"]) is float:
+        n_neighbor = data_dict["nNeighbor"]
+    else:
+        n_neighbor = dataset.n_neighbor_mean.detach().cpu().numpy()
+
+    if isinstance(data_dict["elements"], list):
+        elements = data_dict["elements"]
+    else:
+        elements = list(dataset.all_elements.detach().cpu().numpy())
+
+    log.info(f"mean  : {mean}")
+    log.info(f"std   : {std}")
+    log.info(f"n_neighbor   : {n_neighbor}")
+    log.info(f"all_elements : {elements}")
+    return mean, std, n_neighbor, elements
+
+
 def get_cutoff(p_dict):
     cutoff = p_dict['cutoff']
     cut_dict = p_dict['Model']['CutoffLayer']
@@ -75,6 +129,9 @@ def get_model(p_dict, elements, mean, std, n_neighbor):
         target_way["site_energy"] = 0
     if "dipole" in target:
         target_way["dipole"] = 1
+    if "polarizability" in target:
+        target_way["polar_diagonal"] = 0
+        target_way["polar_off_diagonal"] = 2
     if "direct_forces" in target:
         assert "forces" not in target_way, "Cannot learn forces and direct_forces at the same time"
         target_way["direct_forces"] = 1
@@ -96,218 +153,6 @@ def get_model(p_dict, elements, mean, std, n_neighbor):
     return model
 
 
-def get_dataset(p_dict):
-    data_dict = p_dict['Data']
-    if data_dict['type'] == 'rmd17':
-        dataset = RevisedMD17(data_dict['path'], 
-                              data_dict['name'], 
-                              cutoff=p_dict['cutoff'], 
-                              device=p_dict['device'])
-    if data_dict['type'] == 'ase':
-        if 'name' in data_dict:
-            dataset = ASEData(root=data_dict['path'],
-                              name=data_dict['name'],
-                              cutoff=p_dict['cutoff'],
-                              device=p_dict['device'])
-        else:
-            dataset = None
-    return dataset
-
-
-def split_dataset(dataset, p_dict):
-    data_dict = p_dict['Data']
-    if ("trainSplit" in data_dict) and ("testSplit" in data_dict):
-        log.info("Load split from {} and {}".format(data_dict["trainSplit"], data_dict["testSplit"]))
-        return dataset.load_split(data_dict["trainSplit"], data_dict["testSplit"])
-    if ("trainNum" in data_dict) and (("testNum" in data_dict)):
-        log.info("Random split, train num: {}, test num: {}".format(data_dict["trainNum"], data_dict["testNum"]))
-        return dataset.random_split(data_dict["trainNum"], data_dict["testNum"])
-    if ("trainSet" in data_dict) and ("testSet" in data_dict):
-        assert data_dict['type'] == 'ase', "trainset must can be read by ase!"
-        trainset = ASEData(root=data_dict['path'], 
-                           name=data_dict['trainSet'],
-                           cutoff=p_dict['cutoff'],
-                           device=p_dict['device'])
-        testset = ASEData(root=data_dict['path'], 
-                          name=data_dict['testSet'],
-                          cutoff=p_dict['cutoff'],
-                          device=p_dict['device'])
-        return trainset, testset
-    raise Exception("No splitting!")
-
-
-def get_loss_calculator(p_dict):
-    train_dict = p_dict['Train']
-    target = train_dict['targetProp']
-    weight = train_dict['weight']
-    weights = {p: w for p, w in zip(target, weight)}
-    if "direct_forces" in weights:
-        weights["forces"] = weights.pop("direct_forces") # direct forces use the same key of forces
-    if train_dict['allowMissing']:
-        # TODO: rewrite loss function
-        if train_dict['forceScale'] > 0:
-            raise Exception("Now forceScale not support allowMissing!")
-        return MissingValueLoss(weights, loss_fn=F.mse_loss)
-    else:
-        if train_dict['forceScale'] > 0:
-            return ForceScaledLoss(weights, loss_fn=F.mse_loss, scaled=train_dict['forceScale'])
-        else:
-            return Loss(weights, loss_fn=F.mse_loss)
-
-
-def eval(model, properties, loss_calculator, data_loader):
-    total = []
-    prop_loss = {}
-    for i_batch, batch_data in enumerate(data_loader):
-        model(batch_data, properties, create_graph=False)
-        loss, loss_dict = loss_calculator.get_loss(batch_data, verbose=True)
-        total.append(loss.detach().cpu().numpy())
-        for prop in loss_dict:
-            if prop not in prop_loss:
-                prop_loss[prop] = []
-            prop_loss[prop].append(loss_dict[prop].detach().cpu().numpy())
-    t1 = np.sqrt(np.mean(total))
-    for prop in loss_dict:
-        prop_loss[prop] = np.sqrt(np.mean(prop_loss[prop]))
-    return t1, prop_loss
-
-    
-def train(model, loss_calculator, optimizer, lr_scheduler, ema, train_loader, test_loader, p_dict):
-    min_loss = 10000
-    t2 = 10000
-    t = time.time()
-    content = f"{'epoch':^10}|{'time':^10}|{'lr':^10}|{'total':^21}"
-    for prop in p_dict["Train"]['targetProp']:
-        content += f"|{prop:^21}"
-    log.info(content)
-    for epoch in range(p_dict["Train"]["epoch"]):
-        for i_batch, batch_data in enumerate(train_loader):
-            optimizer.zero_grad()
-            model(batch_data, p_dict["Train"]['targetProp'])
-            loss = loss_calculator.get_loss(batch_data)
-            loss.backward()
-            if p_dict["Train"]["maxGradNorm"] is not None:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=p_dict["Train"]["maxGradNorm"])
-            optimizer.step()
-            if ema is not None:
-                ema.update_parameters(model)
-        lr_scheduler.step(epoch=epoch, metrics=t2)
-        if epoch % p_dict["Train"]["logInterval"] == 0:
-            lr = optimizer.param_groups[0]["lr"]
-            t1, prop_loss1 = eval(model, p_dict["Train"]['targetProp'], loss_calculator, train_loader)
-            if p_dict["Train"]["evalTest"]:
-                if ema:
-                    t2, prop_loss2 = eval(ema, p_dict["Train"]['targetProp'], loss_calculator, test_loader)
-                else:
-                    t2, prop_loss2 = eval(model, p_dict["Train"]['targetProp'], loss_calculator, test_loader)
-            else:
-                t2, prop_loss2 = t1, prop_loss1
-            content = f"{epoch:^10}|{time.time() - t:^10.2f}|{lr:^10.2e}|{t1:^10.4f}/{t2:^10.4f}"
-            t = time.time()
-            for prop in p_dict["Train"]['targetProp']:
-                prop = "forces" if prop == "direct_forces" else prop
-                content += f"|{prop_loss1[prop]:^10.4f}/{prop_loss2[prop]:^10.4f}"
-            log.info(content)
-            if t2 < min_loss:
-                min_loss = t2
-                save_checkpoints(p_dict["outputDir"], "best", model, ema, optimizer, lr_scheduler)
-        if epoch > p_dict["Train"]["saveStart"] and epoch % p_dict["Train"]["saveInterval"] == 0:
-            save_checkpoints(p_dict["outputDir"], epoch, model, ema, optimizer, lr_scheduler)
-
-
-def save_checkpoints(path, name, model, ema, optimizer, lr_scheduler):
-    checkpoint = {
-        "model": model.state_dict(),
-        "optimizer": optimizer.state_dict(),
-        "lr_scheduler": lr_scheduler.state_dict(),
-    }
-    if ema is not None:
-        checkpoint["ema"] = ema.state_dict()
-    torch.save(checkpoint, os.path.join(path, f"state_dict-{name}.pt"))
-    torch.save(model, os.path.join(path, f"model-{name}.pt"))
-
-
-def get_optimizer(p_dict, model):
-    opt_dict = p_dict["Train"]["Optimizer"]
-    decay_interactions = {}
-    no_decay_interactions = {}
-    for name, param in model.son_equivalent_layers.named_parameters():
-        if "weight" in name:
-            decay_interactions[name] = param
-        else:
-            no_decay_interactions[name] = param
-
-    param_options = dict(
-        params=[
-            {
-                "name": "embedding",
-                "params": model.embedding_layer.parameters(),
-                "weight_decay": 0.0,
-            },
-            {
-                "name": "interactions_decay",
-                "params": list(decay_interactions.values()),
-                "weight_decay": opt_dict["weightDecay"],
-            },
-            {
-                "name": "interactions_no_decay",
-                "params": list(no_decay_interactions.values()),
-                "weight_decay": 0.0,
-            },
-            {
-                "name": "readouts",
-                "params": model.readout_layer.parameters(),
-                "weight_decay": 0.0,
-            },
-        ],
-        lr=opt_dict["learningRate"],
-        amsgrad=opt_dict["amsGrad"],
-    )
-
-    if opt_dict['type'] == "Adam":
-        return torch.optim.Adam(**param_options)
-    elif opt_dict['type'] == "AdamW":
-        return torch.optim.AdamW(**param_options)
-    else:
-        raise Exception("Unsupported optimizer: {}!".format(opt_dict["type"]))
-
-def get_lr_scheduler(p_dict, optimizer):
-    class LrScheduler:
-        def __init__(self, p_dict, optimizer) -> None:
-            lr_dict = p_dict["Train"]["LrScheduler"]
-            self.mode = lr_dict['type']
-            if lr_dict['type'] == "exponential":
-                self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, 
-                                                                           gamma=lr_dict['gamma'])
-            elif lr_dict['type'] == "reduceOnPlateau":
-                self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
-                                                                               factor=lr_dict['lrFactor'],
-                                                                               patience=lr_dict['patience'])
-            elif lr_dict['type'] == "constant":
-                self.lr_scheduler = None
-            else:
-                raise Exception("Unsupported LrScheduler: {}!".format(lr_dict['type']))
-
-        def step(self, metrics=None, epoch=None):
-            if self.mode == "exponential":
-                self.lr_scheduler.step(epoch=epoch)
-            elif self.mode == "reduceOnPlateau":
-                self.lr_scheduler.step(metrics=metrics, epoch=epoch)
-
-        def state_dict(self):
-            if self.lr_scheduler:
-                return {key: value for key, value in self.lr_scheduler.__dict__.items() if key != 'optimizer'}
-
-        def load_state_dict(self, state_dict):
-            if self.lr_scheduler:
-                self.lr_scheduler.__dict__.update(state_dict)
-
-        def __repr__(self):
-            return self.lr_scheduler.__repr__()
-
-    return LrScheduler(p_dict, optimizer)
-
-
 def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None, **kwargs):
     # Default values
     p_dict = {
@@ -320,6 +165,11 @@ def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None,
             "trainBatch": 32,
             "testBatch": 32,
             "std": "force",
+            "mean": None,
+            "nNeighbor": None,
+            "elements": None,
+            "numWorkers": 0,
+            "pinMemory": False,
         },
         "Model": {
             "mode": "normal",
@@ -342,20 +192,23 @@ def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None,
             }
         },
         "Train": {
+            "maxEpoch": 10000,
+            "maxStep": 1000000,
+            "learningRate": 0.001,
             "allowMissing": False,
             "targetProp": ["energy", "forces"],
             "weight": [0.1, 1.0],
             "forceScale": 0.,
-            "logInterval": 100,
-            "saveInterval": 500,
+            "evalStepInterval": 50,
+            "evalEpochInterval": 1,
+            "logInterval": 50,
             "saveStart": 1000,
             "evalTest": True,
-            "maxGradNorm": None,
+            "gradClip": None,
             "Optimizer": {
                 "type": "Adam",
                 "amsGrad": True,
                 "weightDecay": 0.,
-                "learningRate": 0.001,
                 },
             "LrScheduler": {
                 "type": "constant",
@@ -373,65 +226,52 @@ def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None,
 
     setup_seed(p_dict["seed"])
     log.info("Using seed {}".format(p_dict["seed"]))
-    
-    dataset = get_dataset(p_dict)
-    trainset, testset = split_dataset(dataset, p_dict)
-    train_loader = DataLoader(trainset, batch_size=p_dict["Data"]["trainBatch"], shuffle=True)
-    test_loader = DataLoader(testset, batch_size=p_dict["Data"]["testBatch"], shuffle=False)
-    if dataset is None:
-        dataset = trainset
 
-    try:
-        mean = dataset.per_energy_mean.detach().cpu().numpy()
-    except:
-        mean = 0.
-    if p_dict["Data"]["std"] == "force": 
-        std = dataset.forces_std.detach().cpu().numpy()
-    elif p_dict["Data"]["std"] == "energy":
-        std = dataset.per_energy_std.detach().cpu().numpy()
-    else:
-        std = p_dict["Data"]["std"]
-    n_neighbor = dataset.n_neighbor_mean
-    elements = list(dataset.all_elements.detach().cpu().numpy())
-    log.info(f"mean  : {mean}")
-    log.info(f"std   : {std}")
-    log.info(f"n_neighbor   : {n_neighbor}")
-    log.info(f"all_elements : {elements}")
-    if load_model is not None:
+    log.info(f"Preparing data...")
+    dataset = LitAtomsDataset(p_dict)
+    dataset.setup()
+    mean, std, n_neighbor, elements = get_stats(p_dict["Data"], dataset)
+
+    if load_model is not None and 'ckpt' not in load_model:
+        log.info(f"Load model from {load_model}")
         model = torch.load(load_model)
     else:
         model = get_model(p_dict, elements, mean, std, n_neighbor)
         model.register_buffer('all_elements', torch.tensor(elements, dtype=torch.long))
         model.register_buffer('cutoff', torch.tensor(p_dict["cutoff"], dtype=torch.float64))
 
-    optimizer = get_optimizer(p_dict, model)
-    lr_scheduler = get_lr_scheduler(p_dict, optimizer)
+    if load_model is not None and 'ckpt' in load_model:
+        lit_model = LitAtomicModule.load_from_checkpoint(load_model, model=model, p_dict=p_dict)
+    else:
+        lit_model = LitAtomicModule(model=model, p_dict=p_dict)
+
+    logger = pl.loggers.TensorBoardLogger(save_dir=p_dict["outputDir"])
+    callbacks = [
+        SaveModelCheckpoint(
+            dirpath=p_dict["outputDir"],
+            filename='{epoch}-{step}-{val_loss:.4f}',
+            save_top_k=5,
+            monitor="val_loss"
+        ),
+        LearningRateMonitor()
+    ]
+    trainer = pl.Trainer(
+        logger=logger,
+        callbacks=callbacks,
+        default_root_dir='.',
+        max_epochs=p_dict["Train"]["maxEpoch"],
+        max_steps=p_dict["Train"]["maxStep"],
+        enable_progress_bar=False,
+        log_every_n_steps=p_dict["Train"]["logInterval"],
+        val_check_interval=p_dict["Train"]["evalStepInterval"],
+        check_val_every_n_epoch=p_dict["Train"]["evalEpochInterval"],
+        gradient_clip_val=p_dict["Train"]["gradClip"],
+        )
     if load_checkpoint is not None:
-        state_dict = torch.load(load_checkpoint)
-        model.load_state_dict(state_dict["model"])
-        optimizer.load_state_dict(state_dict["optimizer"])
-        lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
-
-    log.info(" Network Architecture ".center(100, "="))
-    log.info(model)
-    log.info(f"Number of parameters: {sum([p.numel() for p in model.parameters()])}")
-    log.info(" Optimizer ".center(100, "="))
-    log.info(optimizer)
-    # log.info(" LRScheduler ".center(100, "="))
-    # log.info(lr_scheduler)
-
-    ema_decay = p_dict["Train"]["emaDecay"]
-    if ema_decay > 0:
-        ema_avg = lambda averaged_para, para, n: ema_decay * averaged_para + (1 - ema_decay) * para
-        ema = AveragedModel(model=model, device=p_dict["device"], avg_fn=ema_avg, use_buffers=False)
-        # log.info(" ExponentialMovingAverage ".center(80, "="))
-        # log.info(ema)
+        log.info(f"Load checkpoints from {load_checkpoint}")
+        trainer.fit(lit_model, datamodule=dataset, ckpt_path=load_checkpoint)
     else:
-        ema = None
-    log.info("=" * 100)
-    loss_calculator = get_loss_calculator(p_dict)
-    train(model, loss_calculator, optimizer, lr_scheduler, ema, train_loader, test_loader, p_dict)
-
+        trainer.fit(lit_model, datamodule=dataset)
 
 if __name__ == "__main__":
     main()
diff --git a/tensornet/layer/embedding.py b/tensornet/layer/embedding.py
index 84dd1b1fceb28cb2af00088880127bbca3366edd..b6e4c88bdf8403a35218cc7399d7c5e51ce1e560 100644
--- a/tensornet/layer/embedding.py
+++ b/tensornet/layer/embedding.py
@@ -104,7 +104,8 @@ class BehlerG1(EmbeddingLayer):
                 batch_data  : Dict[str, torch.Tensor],
                 ) -> torch.Tensor:
         n_atoms = batch_data['atomic_number'].shape[0]
-        idx_i, idx_j = batch_data['edge_index']
+        idx_i = batch_data['idx_i']
+        idx_j = batch_data['idx_j']
         _, dij, _ = find_distances(batch_data)
         zij = batch_data['atomic_number'][idx_j].unsqueeze(-1)                      # [n_edge, 1]
         dij = dij.unsqueeze(-1)
diff --git a/tensornet/layer/equivalent.py b/tensornet/layer/equivalent.py
index f75aaee3ce7eeb8ae02403fd3cd8c43870656f9a..39b32ba47a9d9680acc8ac4a2b48e8b5e0f2fdb0 100644
--- a/tensornet/layer/equivalent.py
+++ b/tensornet/layer/equivalent.py
@@ -9,7 +9,7 @@ from torch import nn
 from typing import Dict, Callable, Union
 from .base import RadialLayer, CutoffLayer
 from .activate import TensorActivateDict
-from ..utils import find_distances, expand_to, multi_outer_product, _scatter_add, find_moment
+from ..utils import find_distances, find_moment, _scatter_add, _aggregate
 
 
 # input_tensors be like:
@@ -53,10 +53,10 @@ class TensorLinear(nn.Module):
 
 
 class SimpleTensorAggregateLayer(nn.Module):
-    """ 
+    """
     In this type of layer, the rbf mixing only different if r_way different
     """
-    def __init__(self, 
+    def __init__(self,
                  radial_fn      : RadialLayer,
                  n_channel      : int,
                  max_in_way     : int=2,
@@ -85,36 +85,23 @@ class SimpleTensorAggregateLayer(nn.Module):
                 ) -> Dict[int, torch.Tensor]:
         # These 3 rows are required by torch script
         output_tensors = torch.jit.annotate(Dict[int, torch.Tensor], {})
-        idx_i = batch_data['edge_index'][0]
-        idx_j = batch_data['edge_index'][1]
+        idx_i = batch_data['idx_i']
+        idx_j = batch_data['idx_j']
 
         n_atoms = batch_data['atomic_number'].shape[0]
         _, dij, uij = find_distances(batch_data)
         rbf_ij = self.radial_fn(dij)    # [n_edge, n_rbf]
 
-        for r_way in range(self.max_r_way + 1):
-            fn = self.rbf_mixing_dict[str(r_way)](rbf_ij) # [n_edge, n_channel]
+        for r_way, rbf_mixing in self.rbf_mixing_dict.items():
+            r_way = int(r_way)
+            fn = rbf_mixing(rbf_ij) # [n_edge, n_channel]
             # TODO: WHY!!!!!!!!!! CAO!
             # fn = fn * input_tensor_dict[0]
             moment_tensor = find_moment(batch_data, r_way)  # [n_edge, n_dim, ...]
-            filter_tensor_ = moment_tensor.unsqueeze(1) * expand_to(fn, n_dim=r_way + 2) # [n_edge, n_channel, n_dim, n_dim, ...]
             for in_way, out_way in self.inout_combinations[r_way]:
                 input_tensor = input_tensors[in_way][idx_j]      # [n_edge, n_channel, n_dim, n_dim, ...]
-                coupling_way = (in_way + r_way - out_way) // 2
-                # method 1 
-                n_way = in_way + r_way - coupling_way + 2
-                input_tensor  = expand_to(input_tensor, n_way, dim=-1)
-                filter_tensor = expand_to(filter_tensor_, n_way, dim=2)
-                output_tensor = input_tensor * filter_tensor
-                # input_tensor:  [n_edge, n_channel, n_dim, n_dim, ...,     1] 
-                # filter_tensor: [n_edge, n_channel,     1,     1, ..., n_dim]  
-                # with (in_way + r_way - coupling_way) dim after n_channel
-                # We should sum up (coupling_way) n_dim
-                if coupling_way > 0:
-                    sum_axis = [i for i in range(in_way - coupling_way + 2, in_way + 2)]
-                    output_tensor = torch.sum(output_tensor, dim=sum_axis)
+                output_tensor = _aggregate(moment_tensor, fn, input_tensor, in_way, r_way, out_way)
                 output_tensor = _scatter_add(output_tensor, idx_i, dim_size=n_atoms) / self.norm_factor
-                # output_tensor = segment_coo(output_tensor, idx_i, dim_size=batch_data.num_nodes, reduce="sum")
 
                 if out_way not in output_tensors:
                     output_tensors[out_way] = output_tensor
@@ -124,7 +111,7 @@ class SimpleTensorAggregateLayer(nn.Module):
 
 
 class TensorAggregateLayer(nn.Module):
-    def __init__(self, 
+    def __init__(self,
                  radial_fn      : RadialLayer,
                  n_channel      : int,
                  max_in_way     : int=2,
@@ -134,7 +121,7 @@ class TensorAggregateLayer(nn.Module):
                  ) -> None:
         super().__init__()
         # get all possible "i, r, o" combinations
-        self.all_combinations = []
+        self.all_combinations = {}
         self.rbf_mixing_dict = nn.ModuleDict()
         for in_way in range(max_in_way + 1):
             for r_way in range(max_r_way + 1):
@@ -142,7 +129,7 @@ class TensorAggregateLayer(nn.Module):
                     out_way = in_way + r_way - 2 * z_way
                     if out_way <= max_out_way:
                         comb = (in_way, r_way, out_way)
-                        self.all_combinations.append(comb)
+                        self.all_combinations[str(comb)] = comb
                         self.rbf_mixing_dict[str(comb)] = nn.Linear(radial_fn.n_features, n_channel, bias=False)
 
         self.radial_fn = radial_fn
@@ -154,36 +141,22 @@ class TensorAggregateLayer(nn.Module):
                 ) -> Dict[int, torch.Tensor]:
         # These 3 rows are required by torch script
         output_tensors = torch.jit.annotate(Dict[int, torch.Tensor], {})
-        idx_i = batch_data['edge_index'][0]
-        idx_j = batch_data['edge_index'][1]
+        idx_i = batch_data['idx_i']
+        idx_j = batch_data['idx_j']
 
         n_atoms = batch_data['atomic_number'].shape[0]
         _, dij, uij = find_distances(batch_data)
         rbf_ij = self.radial_fn(dij)    # [n_edge, n_rbf]
 
-        for in_way, r_way, out_way in self.all_combinations:
-            fn = self.rbf_mixing_dict[str((in_way, r_way, out_way))](rbf_ij) # [n_edge, n_channel]
+        for comb, rbf_mixing in self.rbf_mixing_dict.items():
+            in_way, r_way, out_way = self.all_combinations[comb]
+            fn = rbf_mixing(rbf_ij) # [n_edge, n_channel]
             # TODO: figure out why this is needed
             # fn = fn * input_tensor_dict[0]
             moment_tensor = find_moment(batch_data, r_way)  # [n_edge, n_dim, ...]
-            filter_tensor = moment_tensor.unsqueeze(1) * expand_to(fn, n_dim=r_way + 2) # [n_edge, n_channel, n_dim, n_dim, ...]
             input_tensor = input_tensors[in_way][idx_j]      # [n_edge, n_channel, n_dim, n_dim, ...]
-            coupling_way = (in_way + r_way - out_way) // 2
-            # method 1 
-            n_way = in_way + r_way - coupling_way + 2
-            input_tensor  = expand_to(input_tensor, n_way, dim=-1)
-            filter_tensor = expand_to(filter_tensor, n_way, dim=2)
-            output_tensor = input_tensor * filter_tensor
-            # input_tensor:  [n_edge, n_channel, n_dim, n_dim, ...,     1] 
-            # filter_tensor: [n_edge, n_channel,     1,     1, ..., n_dim]  
-            # with (in_way + r_way - coupling_way) dim after n_channel
-            # We should sum up (coupling_way) n_dim
-            if coupling_way > 0:
-                sum_axis = [i for i in range(in_way - coupling_way + 2, in_way + 2)]
-                output_tensor = torch.sum(output_tensor, dim=sum_axis)
+            output_tensor = _aggregate(moment_tensor, fn, input_tensor, in_way, r_way, out_way)
             output_tensor = _scatter_add(output_tensor, idx_i, dim_size=n_atoms) / self.norm_factor
-            # output_tensor = segment_coo(output_tensor, idx_i, dim_size=batch_data.num_nodes, reduce="sum")
-
             if out_way not in output_tensors:
                 output_tensors[out_way] = output_tensor
             else:
@@ -193,7 +166,7 @@ class TensorAggregateLayer(nn.Module):
 
 
 class SelfInteractionLayer(nn.Module):
-    def __init__(self, 
+    def __init__(self,
                  input_dim  : int,
                  max_in_way : int,
                  output_dim : int=10,
@@ -219,7 +192,7 @@ class SelfInteractionLayer(nn.Module):
 
 # TODO: concatenate different ways together and use a Linear layer to get a factor for every channel
 class NonLinearLayer(nn.Module):
-    def __init__(self, 
+    def __init__(self,
                  max_in_way  : int,
                  input_dim   : int,
                  activate_fn : str='jilu',
@@ -254,20 +227,20 @@ class SOnEquivalentLayer(nn.Module):
             self.tensor_aggregate = TensorAggregateLayer(radial_fn=radial_fn,
                                                         n_channel=input_dim,
                                                         max_in_way=max_in_way,
-                                                        max_out_way=max_out_way, 
+                                                        max_out_way=max_out_way,
                                                         max_r_way=max_r_way,
                                                         norm_factor=norm_factor,)
         elif mode == 'simple':
             self.tensor_aggregate = SimpleTensorAggregateLayer(radial_fn=radial_fn,
                                                                n_channel=input_dim,
                                                                max_in_way=max_in_way,
-                                                               max_out_way=max_out_way, 
+                                                               max_out_way=max_out_way,
                                                                max_r_way=max_r_way,
                                                                norm_factor=norm_factor,)
         # input for SelfInteractionLayer and NonLinearLayer is the output of TensorAggregateLayer
         # so the max_in_way should equal the max_out_way of TensorAggregateLayer
-        self.self_interact = SelfInteractionLayer(input_dim=input_dim, 
-                                                  max_in_way=max_out_way, 
+        self.self_interact = SelfInteractionLayer(input_dim=input_dim,
+                                                  max_in_way=max_out_way,
                                                   output_dim=output_dim)
         self.non_linear = NonLinearLayer(activate_fn=activate_fn,
                                          max_in_way=max_out_way,
@@ -285,7 +258,7 @@ class SOnEquivalentLayer(nn.Module):
                               input_tensors : Dict[int, torch.Tensor],
                               batch_data    : Dict[str, torch.Tensor],
                               ) -> Dict[int, torch.Tensor]:
-        output_tensors =  self.tensor_aggregate(input_tensors=input_tensors, 
+        output_tensors =  self.tensor_aggregate(input_tensors=input_tensors,
                                                 batch_data=batch_data)
         # resnet
         for r_way in input_tensors.keys():
diff --git a/tensornet/layer/radial.py b/tensornet/layer/radial.py
index b63c6281e986d208aa6d38cc6ba6b8c273b252b3..327d9fe34cd3d89b68bb6713007410a296a85e11 100644
--- a/tensornet/layer/radial.py
+++ b/tensornet/layer/radial.py
@@ -12,7 +12,7 @@ __all__ = ["ChebyshevPoly",
 
 
 class ChebyshevPoly(RadialLayer):
-    def __init__(self, 
+    def __init__(self,
                  r_max      : float,
                  r_min      : float=0.5,
                  n_max      : int=8,
diff --git a/tensornet/layer/readout.py b/tensornet/layer/readout.py
index 6c48a0487cb8219d8fe0aa5d4cadf47487453160..c40af4b39cc604371df16badeb5519398c020a21 100644
--- a/tensornet/layer/readout.py
+++ b/tensornet/layer/readout.py
@@ -10,7 +10,7 @@ __all__ = ["ReadoutLayer"]
 
 class ReadoutLayer(nn.Module):
 
-    def __init__(self, 
+    def __init__(self,
                  n_dim       : int,
                  target_way  : Dict[str, int]={"site_energy": 0},
                  activate_fn : str="jilu",
@@ -26,7 +26,7 @@ class ReadoutLayer(nn.Module):
             for prop, way in target_way.items()
             })
 
-    def forward(self, 
+    def forward(self,
                 input_tensors : Dict[int, torch.Tensor],
                 ) -> Dict[str, torch.Tensor]:
         output_tensors = torch.jit.annotate(Dict[str, torch.Tensor], {})
diff --git a/tensornet/loss.py b/tensornet/loss.py
index 0e6045f9bb948c422cedb282f6ca12a690ec2138..3fef1f103df483ec4cdb7ec4e005c11ef437dc17 100644
--- a/tensornet/loss.py
+++ b/tensornet/loss.py
@@ -7,7 +7,7 @@ from .utils import expand_to
 class Loss:
 
     atom_prop = ["forces"]
-    structure_prop = ["energy", "virial", "dipole"]
+    structure_prop = ["energy", "virial", "dipole", "polarizability"]
 
     def __init__(self,
                  weight  : Dict[str, float]={"energy": 1.0, "forces": 1.0},
@@ -16,7 +16,7 @@ class Loss:
         self.weight = weight
         self.loss_fn = loss_fn
 
-    def get_loss(self, 
+    def get_loss(self,
                  batch_data : Dict[str, torch.Tensor],
                  verbose    : bool=False):
         loss = {}
@@ -42,12 +42,12 @@ class Loss:
                             prop       : str,
                             ) -> torch.Tensor:
         n_atoms = expand_to(batch_data['n_atoms'], len(batch_data[f'{prop}_p'].shape))
-        return self.loss_fn(batch_data[f'{prop}_p'] / n_atoms, 
+        return self.loss_fn(batch_data[f'{prop}_p'] / n_atoms,
                             batch_data[f'{prop}_t'] / n_atoms)
 
 
 class ForceScaledLoss(Loss):
-    
+
     def __init__(self,
                  weight  : Dict[str, float]={"energy": 1.0, "forces": 1.0},
                  loss_fn : Callable=F.mse_loss,
@@ -72,22 +72,27 @@ class MissingValueLoss(Loss):
                        batch_data : Dict[str, torch.Tensor],
                        prop       : str,
                        ) -> torch.Tensor:
-        idx = batch_data[f'has_{prop}'][batch_data['batch']]
-        if not torch.any(idx):
-            return torch.tensor(0.)
-        if torch.all(idx):
-            return super().atom_prop_loss(batch_data, prop)
-        return self.loss_fn(batch_data[f'{prop}_p'][idx], batch_data[f'{prop}_t'][idx])
+        # idx = batch_data[f'{prop}_weight'][batch_data['batch']]
+        # if not torch.any(idx):
+        #     return torch.tensor(0.)
+        # if torch.all(idx):
+        #     return super().atom_prop_loss(batch_data, prop)
+        # return self.loss_fn(batch_data[f'{prop}_p'][idx], batch_data[f'{prop}_t'][idx])
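+        # Missing targets are masked by multiplying both prediction and target
+        # by the per-sample weight (presumably zero where the property is
+        # absent), so those entries add nothing to the squared error; note that
+        # the mean is still taken over all entries.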
+        return self.loss_fn(batch_data[f'{prop}_p'] * batch_data[f'{prop}_weight'],
+                            batch_data[f'{prop}_t'] * batch_data[f'{prop}_weight'])
 
     def structure_prop_loss(self,
                             batch_data : Dict[str, torch.Tensor],
                             prop       : str,
                             ) -> torch.Tensor:
-        idx = batch_data[f'has_{prop}']
-        if not torch.any(idx):
-            return torch.tensor(0.)
-        if torch.all(idx):
-            return super().structure_prop_loss(batch_data, prop)
-        n_atoms = expand_to(batch_data['n_atoms'][idx], len(batch_data[f'{prop}_p'].shape))
-        return self.loss_fn(batch_data[f'{prop}_p'][idx] / n_atoms, 
-                            batch_data[f'{prop}_t'][idx] / n_atoms)
+        # idx = batch_data[f'{prop}_weight']
+        # if not torch.any(idx):
+        #     return torch.tensor(0.)
+        # if torch.all(idx):
+        #     return super().structure_prop_loss(batch_data, prop)
+        # n_atoms = expand_to(batch_data['n_atoms'], len(batch_data[f'{prop}_p'].shape))
+        # return self.loss_fn(batch_data[f'{prop}_p'][idx] / n_atoms, 
+        #                     batch_data[f'{prop}_t'][idx] / n_atoms)
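+        # The weight is additionally divided by n_atoms so that structure
+        # properties are compared on a per-atom scale, matching the base class.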
+        weight = batch_data[f'{prop}_weight'] / expand_to(batch_data['n_atoms'], len(batch_data[f'{prop}_p'].shape))
+        return self.loss_fn(batch_data[f'{prop}_p'] * weight,
+                            batch_data[f'{prop}_t'] * weight)
diff --git a/tensornet/model/__init__.py b/tensornet/model/__init__.py
index 88b1d019cb1b3a729266f2f0e492ddbcb5b30f46..f9bd3344187480359acf64d3a1b3106736713d29 100644
--- a/tensornet/model/__init__.py
+++ b/tensornet/model/__init__.py
@@ -1,2 +1,3 @@
 from .ani import ANI
-from .miao import MiaoNet
\ No newline at end of file
+from .miao import MiaoNet
+from .model_interface import LitAtomicModule
\ No newline at end of file
diff --git a/tensornet/model/base.py b/tensornet/model/base.py
index 8ac5defe81ec9c96acdb7816af710d584a962602..c477d4ae0085a632f2e18e380ffb5c06052dc832 100644
--- a/tensornet/model/base.py
+++ b/tensornet/model/base.py
@@ -6,7 +6,7 @@ from tensornet.utils import _scatter_add, find_distances, add_scaling
 
 class AtomicModule(nn.Module):
 
-    def __init__(self, 
+    def __init__(self,
                  mean  : float=0.,
                  std   : float=1.,
                  ) -> None:
@@ -14,7 +14,7 @@ class AtomicModule(nn.Module):
         self.register_buffer("mean", torch.tensor(mean).float())
         self.register_buffer("std", torch.tensor(std).float())
 
-    def forward(self, 
+    def forward(self,
                 batch_data   : Dict[str, torch.Tensor],
                 properties   : Optional[List[str]]=None,
                 create_graph : bool=True,
@@ -42,6 +42,13 @@ class AtomicModule(nn.Module):
         #######################################
         if 'dipole' in output_tensors:
             batch_data['dipole_p'] = _scatter_add(output_tensors['dipole'], batch_data['batch'])
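+        # The polarizability is assembled from a scalar (way-0) readout that
+        # scales the identity and a rank-2 (way-2) readout that supplies the
+        # anisotropic part.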
+        if 'polar_diagonal' in output_tensors:
+            polar_diagonal = _scatter_add(output_tensors['polar_diagonal'], batch_data['batch'])
+            polar_off_diagonal = _scatter_add(output_tensors['polar_off_diagonal'], batch_data['batch'])
+            # polar_off_diagonal = polar_off_diagonal + polar_off_diagonal.transpose(1, 2)
+            shape = polar_off_diagonal.shape
+            batch_data['polarizability_p'] = polar_diagonal.unsqueeze(-1).unsqueeze(-1) * torch.eye(shape[1], device=polar_diagonal.device).expand(*shape) + polar_off_diagonal
+
         if 'direct_forces' in output_tensors:
             batch_data['forces_p'] = output_tensors['direct_forces'] * self.std
         if ('site_energy' in properties) or ('energies' in properties):
diff --git a/tensornet/model/miao.py b/tensornet/model/miao.py
index 4ffb43561dbd681a0f2633811e91fed601cac007..86f9c82c1f8665766c5bc73cd157255fe20012b7 100644
--- a/tensornet/model/miao.py
+++ b/tensornet/model/miao.py
@@ -26,6 +26,7 @@ class MiaoNet(AtomicModule):
                  mode            : str='normal',
                  ):
         super().__init__(mean=mean, std=std)
+        self.register_buffer("norm_factor", torch.tensor(norm_factor).float())
         self.embedding_layer = embedding_layer
         max_r_way = expand_para(max_r_way, n_layers)
         max_out_way = expand_para(max_out_way, n_layers)
diff --git a/tensornet/model/model_interface.py b/tensornet/model/model_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..985b5744ac2a42c20442536da882489ec87d4343
--- /dev/null
+++ b/tensornet/model/model_interface.py
@@ -0,0 +1,171 @@
+from typing import Optional, Dict, List, Type, Any
+from .base import AtomicModule
+from ..loss import Loss, MissingValueLoss, ForceScaledLoss
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+
+
+class LitAtomicModule(pl.LightningModule):
+
+    def __init__(self,
+                 model: AtomicModule,
+                 p_dict: Dict,
+    ):
+        super().__init__()
+        self.p_dict = p_dict
+        self.model = model
+        self.loss_calculator = self.get_loss_calculator()
+
+        grad_prop = set(['forces', 'virial', 'stress'])
+        self.required_derivatives = len(grad_prop.intersection(self.p_dict["Train"]['targetProp'])) > 0
+        # self.save_hyperparameters(ignore=['model'])
+
+    def forward(self,
+                batch_data   : Dict[str, torch.Tensor],
+                properties   : Optional[List[str]]=None,
+                create_graph : bool=True,
+                ) -> Dict[str, torch.Tensor]:
+        results = self.model(batch_data, properties, create_graph)
+        return results
+
+    def get_loss_calculator(self):
+        train_dict = self.p_dict['Train']
+        target = train_dict['targetProp']
+        weight = train_dict['weight']
+        weights = {p: w for p, w in zip(target, weight)}
+        if "direct_forces" in weights:
+            weights["forces"] = weights.pop("direct_forces") # direct forces use the same key of forces
+        if train_dict['allowMissing']:
+            # TODO: rewrite loss function
+            if train_dict['forceScale'] > 0:
+                raise Exception("Now forceScale not support allowMissing!")
+            return MissingValueLoss(weights, loss_fn=F.mse_loss)
+        else:
+            if train_dict['forceScale'] > 0:
+                return ForceScaledLoss(weights, loss_fn=F.mse_loss, scaled=train_dict['forceScale'])
+            else:
+                return Loss(weights, loss_fn=F.mse_loss)
+
+    def training_step(self, batch, batch_idx):
+        self.model(batch, self.p_dict["Train"]['targetProp'])
+        loss, loss_dict = self.loss_calculator.get_loss(batch, verbose=True)
+        self.log("train_loss", loss)
+        for prop in loss_dict:
+            self.log(f'train_{prop}', loss_dict[prop])
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        torch.set_grad_enabled(self.required_derivatives)
+        self.model(batch, self.p_dict["Train"]['targetProp'], create_graph=False)
+        loss, loss_dict = self.loss_calculator.get_loss(batch, verbose=True)
+        self.log("val_loss", loss, batch_size=batch['n_atoms'].shape[0])
+        for prop in loss_dict:
+            self.log(f'val_{prop}', loss_dict[prop], batch_size=batch['n_atoms'].shape[0])
+
+    def test_step(self, batch, batch_idx):
+        torch.set_grad_enabled(self.required_derivatives)
+        self.model(batch, self.p_dict["Train"]['targetProp'], create_graph=False)
+        loss, loss_dict = self.loss_calculator.get_loss(batch, verbose=True)
+        loss_dict['test_loss'] = loss
+        self.log_dict(loss_dict)
+        return loss_dict
+
+    def get_optimizer(self):
+        opt_dict = self.p_dict["Train"]["Optimizer"]
+        decay_interactions = {}
+        no_decay_interactions = {}
+        for name, param in self.model.son_equivalent_layers.named_parameters():
+            if "weight" in name:
+                decay_interactions[name] = param
+            else:
+                no_decay_interactions[name] = param
+
+        param_options = dict(
+            params=[
+                {
+                    "name": "embedding",
+                    "params": self.model.embedding_layer.parameters(),
+                    "weight_decay": 0.0,
+                },
+                {
+                    "name": "interactions_decay",
+                    "params": list(decay_interactions.values()),
+                    "weight_decay": opt_dict["weightDecay"],
+                },
+                {
+                    "name": "interactions_no_decay",
+                    "params": list(no_decay_interactions.values()),
+                    "weight_decay": 0.0,
+                },
+                {
+                    "name": "readouts",
+                    "params": self.model.readout_layer.parameters(),
+                    "weight_decay": 0.0,
+                },
+            ],
+            lr=self.p_dict["Train"]["learningRate"],
+            amsgrad=opt_dict["amsGrad"],
+        )
+
+        if opt_dict['type'] == "Adam":
+            return torch.optim.Adam(**param_options)
+        elif opt_dict['type'] == "AdamW":
+            return torch.optim.AdamW(**param_options)
+        else:
+            raise Exception("Unsupported optimizer: {}!".format(opt_dict["type"]))
+
+    def get_lr_scheduler(self, optimizer):
+        lr_dict = self.p_dict["Train"]["LrScheduler"]
+        if lr_dict['type'] == "exponential":
+            return torch.optim.lr_scheduler.ExponentialLR(
+                optimizer=optimizer,
+                gamma=lr_dict['gamma']
+                )
+        elif lr_dict['type'] == "reduceOnPlateau":
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                    optimizer=optimizer,
+                    factor=lr_dict['lrFactor'],
+                    patience=lr_dict['patience']
+                    )
+            if self.p_dict["Train"]["evalEpochInterval"] == 1:
+                lr_scheduler_config = {
+                        "scheduler": scheduler,
+                        "interval": "step",
+                        "monitor": "val_loss",
+                        "frequency": self.p_dict["Train"]["evalStepInterval"],
+                        }
+            else:
+                lr_scheduler_config = {
+                        "scheduler": scheduler,
+                        "interval": "epoch",
+                        "monitor": "val_loss",
+                        "frequency": self.p_dict["Train"]["evalEpochInterval"],
+                        }
+            return lr_scheduler_config
+        elif lr_dict['type'] == "constant":
+            return None
+        else:
+            raise Exception("Unsupported LrScheduler: {}!".format(lr_dict['type']))
+
+    def configure_optimizers(self):
+        optimizer = self.get_optimizer()
+        scheduler = self.get_lr_scheduler(optimizer)
+        if scheduler:
+            return [optimizer], [scheduler]
+        else:
+            return [optimizer]
+
+    # Learning rate warm-up
+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
+        # update params
+        optimizer.step(closure=optimizer_closure)
+
+        # manually warm up lr without a scheduler
+        if self.trainer.global_step < self.p_dict["Train"]["warmupSteps"]:
+            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.p_dict["Train"]["warmupSteps"])
+            for pg in optimizer.param_groups:
+                pg["lr"] = lr_scale * self.p_dict["Train"]["learningRate"]
+
+    # def lr_scheduler_step(self, scheduler, metric):
+    #     scheduler.step(epoch=self.current_epoch)  # timm's scheduler need the epoch value
diff --git a/tensornet/utils.py b/tensornet/utils.py
index 377d2a20669637e9a28d3d9d1c1f40be72759e21..15d57cfa7cc577008d97ce4d2715d8c8ff3ba95c 100644
--- a/tensornet/utils.py
+++ b/tensornet/utils.py
@@ -13,14 +13,14 @@ def setup_seed(seed):
     torch.backends.cudnn.deterministic = True
 
 
-def expand_to(t     : torch.Tensor, 
-              n_dim : int, 
+def expand_to(t     : torch.Tensor,
+              n_dim : int,
               dim   : int=-1) -> torch.Tensor:
     """Expand dimension of the input tensor t at location 'dim' until the total dimention arrive 'n_dim'
 
     Args:
         t (torch.Tensor): Tensor to expand
-        n_dim (int): target dimension 
+        n_dim (int): target dimension
         dim (int, optional): location to insert axis. Defaults to -1.
 
     Returns:
@@ -31,7 +31,7 @@ def expand_to(t     : torch.Tensor,
     return t
 
 
-def multi_outer_product(v: torch.Tensor, 
+def multi_outer_product(v: torch.Tensor,
                         n: int) -> torch.Tensor:
     """Calculate 'n' times outer product of vector 'v'
 
@@ -51,7 +51,7 @@ def multi_outer_product(v: torch.Tensor,
 def add_scaling(batch_data  : Dict[str, torch.Tensor],) -> Dict[str, torch.Tensor]:
     if 'has_add_scaling' not in batch_data:
         idx_m = batch_data['batch']
-        idx_i = batch_data['edge_index'][0]
+        idx_i = batch_data['idx_i']
         batch_data['coordinate'] = torch.matmul(batch_data['coordinate'][:, None, :], 
                                                 batch_data['scaling'][idx_m]).squeeze(1)
         if 'offset' in batch_data:
@@ -63,11 +63,11 @@ def add_scaling(batch_data  : Dict[str, torch.Tensor],) -> Dict[str, torch.Tenso
 
 def find_distances(batch_data  : Dict[str, torch.Tensor],) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     if 'rij' not in batch_data:
-        idx_i = batch_data["edge_index"][0]
+        idx_i = batch_data["idx_i"]
         if 'ghost_neigh' in batch_data:                  # neighbor for lammps calculation
             idx_j = batch_data["ghost_neigh"]
         else:
-            idx_j = batch_data["edge_index"][1]
+            idx_j = batch_data["idx_j"]
         if 'offset' in batch_data:
             batch_data['rij'] = batch_data['coordinate'][idx_j] + batch_data['offset'] - batch_data['coordinate'][idx_i]
         else:
@@ -170,3 +170,27 @@ def progress_bar(i: int, n: int, interval: int=100):
     if i % interval == 0:
         ii = int(i / n * 100)
         print(f"\r{ii}%[{'*' * ii}{'-' * (100 - ii)}]", end=' ', file=sys.stderr)
+
+
+@torch.jit.script
+def _aggregate(moment_tensor : torch.Tensor,
+               fn            : torch.Tensor,
+               input_tensor  : torch.Tensor,
+               in_way        : int,
+               r_way         : int,
+               out_way       : int,
+               ) -> torch.Tensor:
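+    # Shape sketch (illustrative, with n_dim = 3, in_way = 1, r_way = 2, out_way = 1):
+    #   moment_tensor : [n_edge, 3, 3]
+    #   fn            : [n_edge, n_channel]
+    #   input_tensor  : [n_edge, n_channel, 3]
+    #   coupling_way = (1 + 2 - 1) // 2 = 1, so one n_dim axis is contracted and
+    #   the output has shape [n_edge, n_channel, 3].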
+    filter_tensor = moment_tensor.unsqueeze(1) * expand_to(fn, n_dim=r_way + 2) # [n_edge, n_channel, n_dim, n_dim, ...]
+    coupling_way = (in_way + r_way - out_way) // 2
+    n_way = in_way + r_way - coupling_way + 2
+    input_tensor  = expand_to(input_tensor, n_way, dim=-1)
+    filter_tensor = expand_to(filter_tensor, n_way, dim=2)
+    output_tensor = input_tensor * filter_tensor
+    # input_tensor:  [n_edge, n_channel, n_dim, n_dim, ...,     1] 
+    # filter_tensor: [n_edge, n_channel,     1,     1, ..., n_dim]  
+    # with (in_way + r_way - coupling_way) dim after n_channel
+    # We should sum up (coupling_way) n_dim
+    if coupling_way > 0:
+        sum_axis = [i for i in range(in_way - coupling_way + 2, in_way + 2)]
+        output_tensor = torch.sum(output_tensor, dim=sum_axis)
+    return output_tensor
diff --git a/tools/convert.py b/tools/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..197e5ce69645fcbcac9837062db25ae67808a8c6
--- /dev/null
+++ b/tools/convert.py
@@ -0,0 +1,49 @@
+import torch
+from train import get_model
+import yaml
+import sys
+
+
+def convert(ckptfile, device='cpu'):
+    ckpt = torch.load(ckptfile, map_location=device)
+    state_dict = ckpt['state_dict']
+    elements = state_dict['model.all_elements'].detach().cpu().numpy()
+    mean = state_dict['model.mean'].detach().cpu().numpy()
+    std = state_dict['model.std'].detach().cpu().numpy()
+    n_neighbor = state_dict['model.norm_factor'].detach().cpu().numpy()
+    cutoff = state_dict['model.cutoff'].detach().cpu().numpy()
+
+    p_dict = {
+        "cutoff": cutoff,
+        "Model": {
+            "mode": "normal",
+            "activateFn": "silu",
+            "nEmbedding": 64,
+            "nLayer": 5,
+            "maxRWay": 2,
+            "maxOutWay": 2,
+            "nHidden": 64,
+            "targetWay": {0 : 'site_energy'},
+            "CutoffLayer": {
+                "type": "poly",
+                "p": 5,
+            },
+            "RadialLayer": {
+                "type": "besselMLP",
+                "nBasis": 8,
+                "nHidden": [64, 64, 64],
+                "activateFn": "silu",
+            }
+        },
+    }
+    
+    with open('input.yaml') as f:
+        p_dict.update(yaml.load(f, Loader=yaml.FullLoader))
+    model = get_model(p_dict, elements=elements, mean=mean, std=std, n_neighbor=n_neighbor)
+    model.load_state_dict(state_dict)
+    torch.save(model, 'model.pt')
+    
+
+if __name__ == "__main__":
+    ckptfile = sys.argv[1]
+    convert(ckptfile)
diff --git a/tools/eval.py b/tools/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e81cbafe37eee543a9f8fb54de613e7796fedbcb
--- /dev/null
+++ b/tools/eval.py
@@ -0,0 +1,49 @@
+from ase.io import read
+import numpy as np
+import torch
+from torch_geometric.loader import DataLoader
+from tensornet.data import ASEData, ASEDBData, atoms_collate_fn
+
+
+def eval(model, data_loader, properties, device='cpu'):
+    output = {prop: [] for prop in properties}
+    target = {prop: [] for prop in properties}
+    n_atoms = []
+    for batch_data in data_loader:
+        for key in batch_data:
+            batch_data[key] = batch_data[key].to(device)
+        model(batch_data, properties, create_graph=False)
+        n_atoms.extend(batch_data['n_atoms'].detach().cpu().numpy())
+        for prop in properties:
+            output[prop].extend(batch_data[f'{prop}_p'].detach().cpu().numpy())
+            if f'{prop}_t' in batch_data:
+                target[prop].extend(batch_data[f'{prop}_t'].detach().cpu().numpy())
+    for prop in properties:
+        np.save(f'output_{prop}.npy', np.array(output[prop]))
+        np.save(f'target_{prop}.npy', np.array(target[prop]))
+    np.save('n_atoms.npy', np.array(n_atoms))
+    return None
+
+
+def main(*args, cutoff=None, modelfile='model.pt', device='cpu', datafile='data.traj', format=None, 
+         properties=["energy", "forces"], batchsize=32, num_workers=4, pin_memory=True, **kwargs):
+    model = torch.load(modelfile, map_location=device)
+    cutoff = float(model.cutoff.detach().cpu().numpy())
+    if '.db' in datafile:
+        dataset = ASEDBData(datapath=datafile,
+                            properties=properties,
+                            cutoff=cutoff)
+    else:
+        frames = read(datafile, index=':', format=format)
+        dataset = ASEData(frames=frames, cutoff=cutoff, properties=properties)
+    data_loader = DataLoader(dataset,
+                             batch_size=batchsize,
+                             shuffle=False,
+                             collate_fn=atoms_collate_fn,
+                             num_workers=num_workers,
+                             pin_memory=pin_memory)
+    eval(model, data_loader, properties, device)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/train.py b/tools/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8543942d9c1f0f29dbf24b4f26afc9c7f5b42bbf
--- /dev/null
+++ b/tools/train.py
@@ -0,0 +1,271 @@
+import logging, time, yaml, os
+import numpy as np
+import torch
+from torch import nn
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint
+import torch.nn.functional as F
+from torch.optim.swa_utils import AveragedModel
+from ase.data import atomic_numbers
+from tensornet.utils import setup_seed
+from tensornet.model import MiaoNet, LitAtomicModule
+from tensornet.layer.cutoff import *
+from tensornet.layer.embedding import AtomicEmbedding
+from tensornet.layer.radial import *
+from tensornet.data import LitAtomsDataset
+
+from tensornet.logger import set_logger
+set_logger(log_path='log.txt', level='DEBUG')
+log = logging.getLogger(__name__)
+
+
+class LogAllLoss(pl.Callback):
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        if trainer.global_rank == 0:
+            epoch = trainer.current_epoch
+
+            if epoch == 0:
+                content = f"{'epoch':^10}|{'lr':^10}|{'total':^21}"
+                for prop in pl_module.p_dict["Train"]['targetProp']:
+                    content += f"|{prop:^21}"
+                log.info(content)
+
+            if epoch % pl_module.p_dict["Train"]["evalInterval"] == 0:
+                lr = trainer.optimizers[0].param_groups[0]["lr"]
+                loss_metrics = trainer.callback_metrics
+                train_loss = loss_metrics['train_loss']
+                val_loss = loss_metrics['val_loss']
+                content = f"{epoch:^10}|{lr:^10.2e}|{train_loss:^10.4f}/{val_loss:^10.4f}"
+                for prop in pl_module.p_dict["Train"]['targetProp']:
+                    prop = "forces" if prop == "direct_forces" else prop
+                    content += f"|{loss_metrics[f'train_{prop}']:^10.4f}/{loss_metrics[f'val_{prop}']:^10.4f}"
+                log.info(content)
+
+def update_dict(d1, d2):
+    for key in d2:
+        if key in d1 and isinstance(d1[key], dict):
+            update_dict(d1[key], d2[key])
+        else:
+            d1[key] = d2[key]
+    return d1
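+# Example: update_dict({"a": {"x": 1}}, {"a": {"y": 2}, "b": 3}) returns
+# {"a": {"x": 1, "y": 2}, "b": 3}; nested dicts are merged recursively and
+# other values are overwritten (d1 is modified in place).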
+
+
+def get_cutoff(p_dict):
+    cutoff = p_dict['cutoff']
+    cut_dict = p_dict['Model']['CutoffLayer']
+    if cut_dict['type'] == "cos":
+        return CosineCutoff(cutoff=cutoff)
+    elif cut_dict['type'] == "cos2":
+        return SmoothCosineCutoff(cutoff=cutoff, cutoff_smooth=cut_dict['smoothCutoff'])
+    elif cut_dict['type'] == "poly":
+        return PolynomialCutoff(cutoff=cutoff, p=cut_dict['p'])
+    else:
+        raise Exception("Unsupported cutoff type: {}, please choose from cos, cos2, and poly!".format(cut_dict['type']))
+    
+
+def get_radial(p_dict, cutoff_fn):
+    cutoff = p_dict['cutoff']
+    radial_dict = p_dict['Model']['RadialLayer']
+    if "bessel" in radial_dict['type']:
+        radial_fn = BesselPoly(r_max=cutoff, n_max=radial_dict['nBasis'], cutoff_fn=cutoff_fn)
+    elif "chebyshev" in radial_dict['type']:
+        if "minDist" in radial_dict:
+            r_min = radial_dict['minDist']
+        else:
+            r_min = 0.5
+            log.warning("You are using chebyshev poly as basis function, but does not given 'minDist', "
+                        "this may cause some problems!")
+        radial_fn = ChebyshevPoly(r_max=cutoff, r_min=r_min, n_max=radial_dict['nBasis'], cutoff_fn=cutoff_fn)
+    else:
+        raise Exception("Unsupported radial type: {}!".format(radial_dict['type']))
+    if "MLP" in radial_dict['type']:
+        if radial_dict["activateFn"] == "silu":
+            activate_fn = nn.SiLU()
+        elif radial_dict["activateFn"] == "relu":
+            activate_fn = nn.ReLU()
+        else:
+            raise Exception("Unsupported activate function in radial type: {}!".format(radial_dict["activateFn"]))
+        return MLPPoly(n_hidden=radial_dict['nHidden'], radial_fn=radial_fn, activate_fn=activate_fn)
+    else:
+        return radial_fn
+
+
+def get_model(p_dict, elements, mean, std, n_neighbor):
+    model_dict = p_dict['Model']
+    target = p_dict['Train']['targetProp']
+    target_way = {}
+    if ("energy" in target) or ("forces" in target) or ("virial" in target):
+        target_way["site_energy"] = 0
+    if "dipole" in target:
+        target_way["dipole"] = 1
+    if "direct_forces" in target:
+        assert "forces" not in target_way, "Cannot learn forces and direct_forces at the same time"
+        target_way["direct_forces"] = 1
+    cut_fn = get_cutoff(p_dict)
+    emb = AtomicEmbedding(elements, model_dict['nEmbedding'])  # only support atomic embedding now
+    radial_fn = get_radial(p_dict, cut_fn)
+    model = MiaoNet(embedding_layer=emb,
+                    radial_fn=radial_fn,
+                    n_layers=model_dict['nLayer'],
+                    max_r_way=model_dict['maxRWay'],
+                    max_out_way=model_dict['maxOutWay'],
+                    output_dim=model_dict['nHidden'],
+                    activate_fn=model_dict['activateFn'],
+                    target_way=target_way,
+                    mean=mean,
+                    std=std,
+                    norm_factor=n_neighbor,
+                    mode=model_dict['mode'])
+    return model
+
+
+def save_checkpoints(path, name, model, ema, optimizer, lr_scheduler):
+    checkpoint = {
+        "model": model.state_dict(),
+        "optimizer": optimizer.state_dict(),
+        "lr_scheduler": lr_scheduler.state_dict(),
+    }
+    if ema is not None:
+        checkpoint["ema"] = ema.state_dict()
+    torch.save(checkpoint, os.path.join(path, f"state_dict-{name}.pt"))
+    torch.save(model, os.path.join(path, f"model-{name}.pt"))
+
+def main(*args, input_file='input.yaml', load_model=None, load_checkpoint=None, **kwargs):
+    # Default values
+    p_dict = {
+        "workDir": os.getcwd(),
+        "seed": np.random.randint(0, 100000000),
+        "device": "cuda" if torch.cuda.is_available() else "cpu",
+        "outputDir": os.path.join(os.getcwd(), "outDir"),
+        "Data": {
+            "path": os.getcwd(),
+            "trainBatch": 32,
+            "testBatch": 32,
+            "std": "force",
+            "numWorkers": 0,
+            "pinMemory": False,
+        },
+        "Model": {
+            "mode": "normal",
+            "activateFn": "silu",
+            "nEmbedding": 64,
+            "nLayer": 5,
+            "maxRWay": 2,
+            "maxOutWay": 2,
+            "nHidden": 64,
+            "targetWay": {0 : 'site_energy'},
+            "CutoffLayer": {
+                "type": "poly",
+                "p": 5,
+            },
+            "RadialLayer": {
+                "type": "besselMLP",
+                "nBasis": 8,
+                "nHidden": [64, 64, 64],
+                "activateFn": "silu",
+            }
+        },
+        "Train": {
+            "learningRate": 0.001,
+            "allowMissing": False,
+            "targetProp": ["energy", "forces"],
+            "weight": [0.1, 1.0],
+            "forceScale": 0.,
+            "evalInterval": 10,
+            "saveInterval": 500,
+            "saveStart": 1000,
+            "evalTest": True,
+            "maxGradNorm": None,
+            "Optimizer": {
+                "type": "Adam",
+                "amsGrad": True,
+                "weightDecay": 0.,
+                },
+            "LrScheduler": {
+                "type": "constant",
+            },
+            "emaDecay": 0.,
+        },
+    }
+    with open(input_file) as f:
+        update_dict(p_dict, yaml.load(f, Loader=yaml.FullLoader))
+
+    os.makedirs(p_dict["outputDir"], exist_ok=True)
+
+    with open("allpara.yaml", "w") as f:
+        yaml.dump(p_dict, f)
+
+    setup_seed(p_dict["seed"])
+    log.info("Using seed {}".format(p_dict["seed"]))
+
+    log.info(f"Preparing data...")
+    dataset = LitAtomsDataset(p_dict)
+    dataset.setup()
+
+    try:
+        mean = dataset.per_energy_mean.detach().cpu().numpy()
+    except:
+        mean = 0.
+    if p_dict["Data"]["std"] == "force": 
+        std = dataset.forces_std.detach().cpu().numpy()
+    elif p_dict["Data"]["std"] == "energy":
+        std = dataset.per_energy_std.detach().cpu().numpy()
+    else:
+        assert type(std) is float, "std must be 'force', 'energy' or a float!" 
+        std = p_dict["Data"]["std"]
+    n_neighbor = dataset.n_neighbor_mean.detach().cpu().numpy()
+    elements = list(dataset.all_elements.detach().cpu().numpy())
+    log.info(f"mean  : {mean}")
+    log.info(f"std   : {std}")
+    log.info(f"n_neighbor   : {n_neighbor}")
+    log.info(f"all_elements : {elements}")
+    if load_model is not None:
+        model = torch.load(load_model)
+    else:
+        model = get_model(p_dict, elements, mean, std, n_neighbor)
+        model.register_buffer('all_elements', torch.tensor(elements, dtype=torch.long))
+        model.register_buffer('cutoff', torch.tensor(p_dict["cutoff"], dtype=torch.float64))
+
+    # log.info(" Network Architecture ".center(100, "="))
+    # log.info(model)
+    # log.info(f"Number of parameters: {sum([p.numel() for p in model.parameters()])}")
+    # log.info("=" * 100)
+
+    lit_model = LitAtomicModule(model=model, p_dict=p_dict)
+    from pytorch_lightning.profilers import PyTorchProfiler
+    profiler = PyTorchProfiler(
+            on_trace_ready=torch.profiler.tensorboard_trace_handler('.'),
+            schedule=torch.profiler.schedule(skip_first=10,
+                                             wait=1,
+                                             warmup=1,
+                                             active=20,
+                                             repeat=1))
+
+    logger = pl.loggers.TensorBoardLogger(save_dir='.')
+    callbacks = [
+        ModelCheckpoint(
+            dirpath='outDir',
+            filename='{epoch}-{val_loss:.2f}',
+            every_n_epochs=p_dict["Train"]["evalInterval"],
+            save_top_k=1,
+            monitor="val_loss"
+        ),
+        LogAllLoss(),
+    ]
+    trainer = pl.Trainer(
+        profiler=profiler,
+        logger=logger,
+        callbacks=callbacks,
+        default_root_dir='.',
+        max_epochs=100,
+        enable_progress_bar=False,
+        log_every_n_steps=50,
+        #num_nodes=1, 
+        strategy='ddp_find_unused_parameters_true'
+        )
+
+    trainer.fit(lit_model, datamodule=dataset)
+
+if __name__ == "__main__":
+    main()