"""
.. _example-tip4p-water:

4-site water models
===================

.. currentmodule:: torchpme

Several water models (starting from the venerable TIP4P model, and including
its TIP4P/2005 reparametrization by
`J. L. F. Abascal and C. Vega, JCP (2005) <http://doi.org/10.1063/1.2121687>`_)
place the negative charge on a site that is displaced from the O position.
This is easy to implement, and automatic differentiation then yields the forces
that the displaced charge exerts on the O and H positions.
"""

import ase
import torch

import torchpme

# A single water molecule: the O atom comes first, followed by its two H atoms,
# placed in a cubic cell with edge length 6.
water_geometry = [
    [0, 0, 0],  # O
    [0, 1, 0],  # H
    [1, -0.2, 0],  # H
]
structure = ase.Atoms(symbols="OHH", positions=water_geometry, cell=[6, 6, 6])

# Torch tensors built from the ase arrays (``torch.from_numpy`` shares memory
# with the underlying NumPy arrays).
positions = torch.from_numpy(structure.positions)
cell = torch.from_numpy(structure.cell.array)

# %%
# The key step is to create a "fourth site" from the O and H positions and to use
# it, in place of the O position, in the interpolation step.

# Partial charges in the original per-molecule order (O, H, H). The negative
# charge is the one that ends up on the fourth site below. Its dtype is matched
# to ``positions`` (which comes from NumPy via ``torch.from_numpy``) so that
# charges, positions and cell share a single floating-point precision.
charges = torch.tensor([[-1.0], [0.5], [0.5]], dtype=positions.dtype)

# Track gradients with respect to every input of the mesh interpolation.
positions.requires_grad_(True)
charges.requires_grad_(True)
cell.requires_grad_(True)

# Atoms are assumed to be stored molecule by molecule as O, H, H, O, H, H, ...
o_positions = positions[0::3]
h1_positions = positions[1::3]
h2_positions = positions[2::3]

# The fourth ("M") site lies a quarter of the way from O towards the H-H
# midpoint: M = O + (H_mid - O) / 4 = (H_mid + 3 O) / 4.
m_positions = ((h1_positions + h2_positions) * 0.5 + o_positions * 3) / 4

# Interleave (M, H, H) per molecule so that row i of ``positions_4site`` keeps
# matching row i of ``charges`` for any number of molecules. A plain ``vstack``
# would instead group all M sites first, then all H1, then all H2.
positions_4site = torch.stack(
    [m_positions, h1_positions, h2_positions], dim=1
).reshape(-1, 3)

# %%
# .. important::
#
#   For automatic differentiation to work, it is important to build a new tensor
#   (here ``positions_4site``) rather than overwriting the original ``positions``
#   tensor in place.

# Spread the charges of the 4-site model onto a mesh with ``ns = [5, 5, 5]``
# using 3-node Lagrange interpolation, then reduce the mesh to a scalar.
ns = torch.tensor([5, 5, 5])
interpolator = torchpme.lib.MeshInterpolator(
    cell=cell,
    ns_mesh=ns,
    interpolation_nodes=3,
    method="Lagrange",
)

# Interpolation weights come from the displaced fourth site, not from the raw
# ``positions``, so gradients flow through the fourth-site construction.
interpolator.compute_weights(positions_4site)
mesh = interpolator.points_to_mesh(charges)

# Sum of squared mesh values: a simple differentiable scalar for this example.
value = mesh.pow(2).sum()

# %%
# The gradients can be computed by simply calling ``backward()`` on the end
# result. Because the fourth site is built from them, gradients reach the O and H
# positions.

# Backpropagate from the scalar to every tensor marked with ``requires_grad_``.
value.backward()

# Assemble the report line by line. The surrounding empty entries reproduce the
# leading and trailing newlines of the printed text.
report = "\n".join(
    [
        "",
        "Position gradients:",
        f"{positions.grad.T}",
        "",
        "Cell gradients:",
        f"{cell.grad}",
        "",
        "Charges gradients:",
        f"{charges.grad.T}",
        "",
    ]
)
print(report)
