Skip to content

Hessians for a list of `wrt` arguments #81

@adtzlr

Description

@adtzlr

Idea

Use `hessian()` for the diagonal blocks because it supports the `sym` argument. The mixed partial derivatives are evaluated by hand. Only the upper-triangular blocks are computed; the lower-triangular blocks follow by symmetry of second derivatives.

import tensortrax as tr
import tensortrax.math as tm
import numpy as np
from copy import copy

def fun(F, p, J):
    """Example three-field function used to exercise ``hessians``.

    Evaluates ``det(F)**(-2/3) * (tr(F^T F) - 3) + (J - 1)**2 + p * (J - det(F))``
    using tensortrax math so that it can be differentiated with respect to
    ``F``, ``p`` and ``J``.
    """
    det_f = tm.linalg.det(F)
    right_cauchy_green = tm.dot(tm.transpose(F), F)
    # Terms are summed in the same left-to-right order as the closed form above.
    isochoric = det_f ** (-2 / 3) * (tm.trace(right_cauchy_green) - 3)
    volumetric = (J - 1) ** 2
    constraint = p * (J - det_f)
    return isochoric + volumetric + constraint

def hessians(fun, wrt, ntrax=0, sym=False, parallel=False):
    """Return a function that evaluates the upper-triangular Hessian blocks of ``fun``.

    The returned callable ``inner(*args, **kwargs)`` differentiates ``fun``
    twice with respect to the positional arguments selected by ``wrt``. The
    blocks are returned as a flat list in ``np.triu_indices(len(wrt))`` order.
    For example, ``wrt=[i, j, k]`` gives::

        [H_ii, H_ij, H_ik, H_jj, H_jk, H_kk]

    where ``H_xy`` is the second derivative with respect to ``args[x]`` and
    ``args[y]``.

    Parameters
    ----------
    fun : callable
        Function written with tensortrax operations; it is called with the
        arguments passed to ``inner``.
    wrt : sequence of int
        Indices of the positional arguments to differentiate with respect to.
        Entries must be unique.
    ntrax : int, optional
        Number of trailing axes passed to ``tr.hessian`` / ``tr.Tensor``.
    sym : bool, optional
        Forwarded to ``tr.hessian`` for diagonal blocks, but only for arguments
        whose non-trailing shape is a square matrix. Scalar-like arguments
        (e.g. shape ``(1,)``) are always evaluated with ``sym=False``.
    parallel : bool, optional
        Forwarded to ``tr.hessian`` for the diagonal blocks only.

    Raises
    ------
    ValueError
        If ``wrt`` contains duplicate indices.
    """
    wrt = list(wrt)
    if len(set(wrt)) != len(wrt):
        # Duplicates would overwrite each other when seeding mixed partials.
        raise ValueError("wrt must not contain duplicate argument indices")

    def _is_square_matrix(arg):
        """Return True if ``arg``, ignoring its ``ntrax`` trailing axes, is square 2d."""
        shape = np.shape(arg)
        shape = shape[: len(shape) - ntrax]
        return len(shape) == 2 and shape[0] == shape[1]

    def inner(*args, **kwargs):
        out = []
        for a, b in zip(*np.triu_indices(len(wrt))):
            # Map positions in ``wrt`` to positions in ``args``.
            i, j = wrt[a], wrt[b]
            if a == b:
                # Diagonal block: tr.hessian supports the ``sym`` argument.
                # NOTE(review): assumes ``sym`` is only meaningful for square
                # (e.g. 3x3) tensor arguments -- confirm against tensortrax.
                symlocal = bool(sym) and _is_square_matrix(args[i])
                out.append(
                    tr.hessian(
                        fun, wrt=i, ntrax=ntrax, sym=symlocal, parallel=parallel
                    )(*args, **kwargs)
                )
            else:
                # Mixed block: seed args[i] with δx and args[j] with Δx, then
                # extract the Δδ (second-order mixed) part of the result.
                tensorargs = list(args)
                tensorargs[i] = tr.Tensor(args[i], ntrax=ntrax)
                tensorargs[j] = tr.Tensor(args[j], ntrax=ntrax)

                tensorargs[i].init(hessian=True, δx=True, Δx=False)
                tensorargs[j].init(hessian=True, δx=False, Δx=True)

                out.append(tr.Δδ(fun(*tensorargs, **kwargs)))

        return out

    return inner

# Example inputs: a 3x3 identity F and two single-entry arrays p and J.
J = np.array([3])
p = np.array([5])
F = np.eye(3)

# Differentiate with respect to (p, J, F), i.e. positional arguments 1, 2, 0.
h = hessians(
    fun,
    wrt=[1, 2, 0],
    ntrax=0,
    sym=True,
    parallel=True,
)(F, p, J)

Metadata

Metadata

Assignees

Labels

enhancement: New feature or request

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions