#refine twisted gaussian

from __future__ import print_function
from __future__ import division
from scitbx.array_family import flex
from scitbx.dtmin.minimizer import Minimizer
from scitbx.dtmin.reparams import Reparams
from scitbx.dtmin.bounds import Bounds
import math
from scitbx.dtmin.refinebase import RefineBase
import random
import sys

class RefineTG(RefineBase):
  def __init__(self, x, hessian_type, twist, filename):
    RefineBase.__init__(self)
    self.xy = x
    self.s11 = 1.0
    self.s12 = 1.2
    self.s22 = 2.
    self.twist = twist
    self.hessian_type = hessian_type
    self.filename = filename
    self.fo = open(self.filename, "wb")

  def target(self):
    return twisted_gauss2d0(self.xy, self.s11, self.s12, self.s22, self.twist)

  def get_macrocycle_parameters(self):
    return self.xy

  def set_macrocycle_parameters(self, newx):
    self.xy = newx

  def target_gradient(self):
    f = twisted_gauss2d0(self.xy, self.s11, self.s12, self.s22, self.twist)
    grad_x = analytic_grad_x(self.xy, self.s11, self.s12, self.s22, self.twist)
    grad_y = analytic_grad_y(self.xy, self.s11, self.s12, self.s22, self.twist)
    g = flex.double([grad_x, grad_y])
    return (f, g)

  def target_gradient_hessian(self):
    (f,g) = self.target_gradient()
    h = flex.double(self.nmp * self.nmp, 0)
    h.reshape(flex.grid(self.nmp, self.nmp))
    if self.hessian_type == "identity":
      h[0,0], h[1,1], h[0,1], h[1,0] = 1.,1.,0.,0.
      return (f,g,h,True)
    if self.hessian_type == "macrocycle_large_shifts":
      macrocycle_large_shifts = self.macrocycle_large_shifts()
      h[0,0], h[1,1], h[0,1], h[1,0] = macrocycle_large_shifts[0]**-2, macrocycle_large_shifts[1]**-2.,0.,0.
      return (f,g,h,True)
    if self.hessian_type == "curvatures":
      h[0,0] = analytic_curv_xx(self.xy, self.s11, self.s12, self.s22, self.twist)
      h[1,1] = analytic_curv_yy(self.xy, self.s11, self.s12, self.s22, self.twist)
      h[1,0] = h[0,1] = 0.
      return (f,g,h,True)
    if self.hessian_type == "full":
      h00 = analytic_curv_xx(self.xy, self.s11, self.s12, self.s22, self.twist)
      h11 = analytic_curv_yy(self.xy, self.s11, self.s12, self.s22, self.twist)
      h01 = analytic_curv_xy(self.xy, self.s11, self.s12, self.s22, self.twist)
      # the following check for None is because the analytic curvature calculation implementation taken
      # from lbfgs/dev/twisted_gaussian.py returns None if x == y == 0. This is to avoid divide by zero.
      # There should be an analytical formula for this case but it looks like the original implementer
      # (and myself) gave up trying to find it. In this case we will revert to using the identity matrix.
      if h00 is None or h11 is None or h01 is None:
        h[0,0] = h[1,1] = 1.
        h[0,1] = h[1,0] = 0.
      else:
        h[0,0] = h00
        h[1,1] = h11
        h[0,1] = h01
        h[1,0] = h[0,1]
      return (f,g,h,False)

  def macrocycle_large_shifts(self):
    return [2., 2.]

  def set_macrocycle_protocol(self, macrocycle_protocol):
    if macrocycle_protocol == ["all"]:
      self.nmp = 2
    else:
      self.nmp = 0

  def macrocycle_parameter_names(self):
    return ["parameter 1", "parameter 2"]

#  def reparameterize(self):
#    rep_x = Reparams(True, 5.)
#    rep_y = Reparams(True, 5.)
#    return [rep_x, rep_y]

#  def bounds(self):
#    bnd_x = Bounds()
#    bnd_y = Bounds()
#    bnd_x.on(-5,5)
#    bnd_y.on(-5,5)
#    return [bnd_x, bnd_y]

  def initial_statistics(self):
    #self.fo.write(str(self.xy[0]) + " " + str(self.xy[1]) + "\n")
    pass

  def current_statistics(self):
    self.fo.write(str(self.xy[0]) + " " + str(self.xy[1]) + "\n")
#    print(str(self.xy[0]) + " " + str(self.xy[1]))

  def finalize(self):
    self.fo.close()

def gauss2d0(xy, s11, s12, s22):
  (x,y) = xy
  return -math.log(1/math.sqrt(4*math.pi**2*(s11*s22-s12**2))
           *math.exp(-(s22*x**2-2*s12*x*y+s11*y**2)/(2*(s11*s22-s12**2))))

def twisted_gauss2d0(xy, s11, s12, s22, twist):
  (x,y) = xy
  arg = twist*math.sqrt(x**2+y**2)
  c = math.cos(arg)
  s = math.sin(arg)
  xt = x*c - y*s
  yt = y*c + x*s
  return gauss2d0((xt,yt), s11, s12, s22)

Cos = math.cos
Sin = math.sin
Sqrt = math.sqrt
Pi = math.pi

def analytic_grad_x(xy, s11, s12, s22, twist):
  (x,y) = xy
  if (x == 0 and y == 0): return 0.
  return (
         (-2*s12*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            (Cos(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*x**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s22*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            (Cos(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*x**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s11*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            ((twist*x**2*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           2*s12*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            ((twist*x**2*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)))/
         (2.*(-s12**2 + s11*s22)))

def analytic_grad_y(xy, s11, s12, s22, twist):
  (x,y) = xy
  if (x == 0 and y == 0): return 0.
  return (
         (-2*s12*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            (-((twist*y**2*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
              Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s22*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            (-((twist*y**2*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
              Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s11*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            (Cos(twist*Sqrt(x**2 + y**2)) +
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*y**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           2*s12*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            (Cos(twist*Sqrt(x**2 + y**2)) +
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*y**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)))/
         (2.*(-s12**2 + s11*s22)))

def analytic_curv_xx(xy, s11, s12, s22, twist):
  (x,y) = xy
  if (x == 0 and y == 0): return 0.
  #if (x == 0 and y == 0): return None
  return (
         (-2*s12*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            ((twist*x**2*y*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 -
              (twist**2*x**3*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x**3*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 +
              (twist**2*x**2*y*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (3*twist*x*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s22*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            ((twist*x**2*y*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 -
              (twist**2*x**3*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x**3*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 +
              (twist**2*x**2*y*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (3*twist*x*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s22*(Cos(twist*Sqrt(x**2 + y**2)) -
               (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
               (twist*x**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))**2
             + 2*s11*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            (-((twist*x**3*Cos(twist*Sqrt(x**2 + y**2)))/
                 (x**2 + y**2)**1.5) -
              (twist**2*x**2*y*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) +
              (3*twist*x*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x**2*y*Sin(twist*Sqrt(x**2 + y**2)))/
               (x**2 + y**2)**1.5 -
              (twist**2*x**3*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           2*s12*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            (-((twist*x**3*Cos(twist*Sqrt(x**2 + y**2)))/
                 (x**2 + y**2)**1.5) -
              (twist**2*x**2*y*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) +
              (3*twist*x*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x**2*y*Sin(twist*Sqrt(x**2 + y**2)))/
               (x**2 + y**2)**1.5 -
              (twist**2*x**3*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           4*s12*(Cos(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*x**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))*
            ((twist*x**2*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s11*((twist*x**2*Cos(twist*Sqrt(x**2 + y**2)))/
                Sqrt(x**2 + y**2) + Sin(twist*Sqrt(x**2 + y**2)) -
               (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))**2)
          /(2.*(-s12**2 + s11*s22)))

def analytic_curv_yy(xy, s11, s12, s22, twist):
  (x,y) = xy
  if (x == 0 and y == 0): return 0.
  #if (x == 0 and y == 0): return None
  return (
         (-2*s12*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            ((twist*y**3*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 -
              (twist**2*x*y**2*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (3*twist*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x*y**2*Sin(twist*Sqrt(x**2 + y**2)))/
               (x**2 + y**2)**1.5 +
              (twist**2*y**3*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*x*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s22*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            ((twist*y**3*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 -
              (twist**2*x*y**2*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (3*twist*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x*y**2*Sin(twist*Sqrt(x**2 + y**2)))/
               (x**2 + y**2)**1.5 +
              (twist**2*y**3*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*x*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s11*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            (-((twist*x*y**2*Cos(twist*Sqrt(x**2 + y**2)))/
                 (x**2 + y**2)**1.5) -
              (twist**2*y**3*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) +
              (twist*x*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*y**3*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 -
              (twist**2*x*y**2*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (3*twist*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           2*s12*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            (-((twist*x*y**2*Cos(twist*Sqrt(x**2 + y**2)))/
                 (x**2 + y**2)**1.5) -
              (twist**2*y**3*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) +
              (twist*x*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*y**3*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 -
              (twist**2*x*y**2*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (3*twist*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s22*(-((twist*y**2*Cos(twist*Sqrt(x**2 + y**2)))/
                  Sqrt(x**2 + y**2)) - Sin(twist*Sqrt(x**2 + y**2)) -
               (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))**2\
            - 4*s12*(-((twist*y**2*Cos(twist*Sqrt(x**2 + y**2)))/
                 Sqrt(x**2 + y**2)) - Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))*
            (Cos(twist*Sqrt(x**2 + y**2)) +
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*y**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s11*(Cos(twist*Sqrt(x**2 + y**2)) +
               (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
               (twist*y**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))**2
           )/(2.*(-s12**2 + s11*s22)))

def analytic_curv_xy(xy, s11, s12, s22, twist):
  (x,y) = xy
  if (x == 0 and y == 0): return 0.
  #if (x == 0 and y == 0): return None
  return (
         (2*s11*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            (-((twist*x**2*y*Cos(twist*Sqrt(x**2 + y**2)))/
                 (x**2 + y**2)**1.5) -
              (twist**2*x*y**2*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) +
              (twist*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x*y**2*Sin(twist*Sqrt(x**2 + y**2)))/
               (x**2 + y**2)**1.5 -
              (twist**2*x**2*y*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*x*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           2*s12*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            (-((twist*x**2*y*Cos(twist*Sqrt(x**2 + y**2)))/
                 (x**2 + y**2)**1.5) -
              (twist**2*x*y**2*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) +
              (twist*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x*y**2*Sin(twist*Sqrt(x**2 + y**2)))/
               (x**2 + y**2)**1.5 -
              (twist**2*x**2*y*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*x*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           2*s12*(y*Cos(twist*Sqrt(x**2 + y**2)) +
              x*Sin(twist*Sqrt(x**2 + y**2)))*
            ((twist*x*y**2*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 -
              (twist**2*x**2*y*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*x*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x**2*y*Sin(twist*Sqrt(x**2 + y**2)))/
               (x**2 + y**2)**1.5 +
              (twist**2*x*y**2*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s22*(x*Cos(twist*Sqrt(x**2 + y**2)) -
              y*Sin(twist*Sqrt(x**2 + y**2)))*
            ((twist*x*y**2*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2)**1.5 -
              (twist**2*x**2*y*Cos(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*x*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              (twist*x**2*y*Sin(twist*Sqrt(x**2 + y**2)))/
               (x**2 + y**2)**1.5 +
              (twist**2*x*y**2*Sin(twist*Sqrt(x**2 + y**2)))/(x**2 + y**2) -
              (twist*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s22*(Cos(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*x**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))*
            (-((twist*y**2*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
              Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           2*s12*(-((twist*y**2*Cos(twist*Sqrt(x**2 + y**2)))/
                 Sqrt(x**2 + y**2)) - Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))*
            ((twist*x**2*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) +
              Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) -
           2*s12*(Cos(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*x**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))*
            (Cos(twist*Sqrt(x**2 + y**2)) +
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*y**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)) +
           2*s11*((twist*x**2*Cos(twist*Sqrt(x**2 + y**2)))/
               Sqrt(x**2 + y**2) + Sin(twist*Sqrt(x**2 + y**2)) -
              (twist*x*y*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2))*
            (Cos(twist*Sqrt(x**2 + y**2)) +
              (twist*x*y*Cos(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2) -
              (twist*y**2*Sin(twist*Sqrt(x**2 + y**2)))/Sqrt(x**2 + y**2)))/
         (2.*(-s12**2 + s11*s22)))

def run():
  start_points = [(-2.,0.), (2.,2.), (2.,-2.)]
  filename = ""
  # create inputs for the minimizer's run method
  MACRO1 = ["all"]            # protocol for the first macrocycle
  PROTOCOL = [MACRO1,MACRO1,MACRO1]         # overall minimization protocol
  NCYC = 50                   # maximum number of microcycles per macrocycle
  STUDY_PARAMS = False        # flag for calling studyparams procedure

  #run the minimization
  for twist in [0.,0.5]:
    for MINIMIZER in ["bfgs","newton"]:          # minimizer, bfgs or newton
      for hessian_type in ["identity", "macrocycle_large_shifts", "curvatures", "full"]:
        for x in start_points:
          filename = "data/" + ("untwisted/" if twist==0. else "twisted/") + "dtmin_" + MINIMIZER + "_" + hessian_type + "_" + str(x)
          print(filename)
          minimizer = Minimizer(3)    #create minimizer object, 0 for MUTE output see Minimizer.py
          refine_tg = RefineTG(x, hessian_type, twist, filename)       # create refine object
          minimizer.run(refine_tg, PROTOCOL, NCYC, MINIMIZER, STUDY_PARAMS)

if (__name__ == "__main__"):
  run()
