From 549b95b5736b42f3fe74daf9805303a08b1ae01d Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 28 May 2019 16:08:41 +0200 Subject: FGW+gromov changes --- README.md | 2 + examples/plot_fgw.py | 152 +++++++++++++++++++++++++ ot/bregman.py | 2 +- ot/gromov.py | 310 ++++++++++++++++++++++++++++++++++++++++++++++++--- ot/optim.py | 102 ++++++++++++++++- 5 files changed, 546 insertions(+), 22 deletions(-) create mode 100644 examples/plot_fgw.py diff --git a/README.md b/README.md index a22306d..be88f65 100644 --- a/README.md +++ b/README.md @@ -219,3 +219,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + +[18] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML). diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py new file mode 100644 index 0000000..5c2d0e1 --- /dev/null +++ b/examples/plot_fgw.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +============================== +Plot Fused-gromov-Wasserstein +============================== + +This example illustrates the computation of FGW for 1D measures[18]. + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + +""" + +# Author: Titouan Vayer +# +# License: MIT License + +import matplotlib.pyplot as pl +import numpy as np +import ot +from ot.gromov import gromov_wasserstein,fused_gromov_wasserstein + +#%% parameters +# We create two 1D random measures +n=20 +n2=30 +sig=1 +sig2=0.1 + +np.random.seed(0) + +phi=np.arange(n)[:,None] +xs=phi+sig*np.random.randn(n,1) +ys=np.vstack((np.ones((n//2,1)),0*np.ones((n//2,1))))+sig2*np.random.randn(n,1) + +phi2=np.arange(n2)[:,None] +xt=phi2+sig*np.random.randn(n2,1) +yt=np.vstack((np.ones((n2//2,1)),0*np.ones((n2//2,1))))+sig2*np.random.randn(n2,1) +yt= yt[::-1,:] + +p=ot.unif(n) +q=ot.unif(n2) + +#%% plot the distributions + +pl.close(10) +pl.figure(10,(7,7)) + +pl.subplot(2,1,1) + +pl.scatter(ys,xs,c=phi,s=70) +pl.ylabel('Feature value a',fontsize=20) +pl.title('$\mu=\sum_i \delta_{x_i,a_i}$',fontsize=25, usetex=True, y=1) +pl.xticks(()) +pl.yticks(()) +pl.subplot(2,1,2) +pl.scatter(yt,xt,c=phi2,s=70) +pl.xlabel('coordinates x/y',fontsize=25) +pl.ylabel('Feature value b',fontsize=20) +pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$',fontsize=25, usetex=True, y=1) +pl.yticks(()) +pl.tight_layout() +pl.show() + + +#%% Structure matrices and across-features distance matrix +C1=ot.dist(xs) +C2=ot.dist(xt).T +M=ot.dist(ys,yt) +w1=ot.unif(C1.shape[0]) +w2=ot.unif(C2.shape[0]) +Got=ot.emd([],[],M) + +#%% +cmap='Reds' +pl.close(10) +pl.figure(10,(5,5)) +fs=15 +l_x=[0,5,10,15] +l_y=[0,5,10,15,20,25] +gs = pl.GridSpec(5, 5) + +ax1=pl.subplot(gs[3:,:2]) + +pl.imshow(C1,cmap=cmap,interpolation='nearest') +pl.title("$C_1$",fontsize=fs) +pl.xlabel("$k$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) +pl.xticks(l_x) +pl.yticks(l_x) + +ax2=pl.subplot(gs[:3,2:]) + +pl.imshow(C2,cmap=cmap,interpolation='nearest') +pl.title("$C_2$",fontsize=fs) +pl.ylabel("$l$",fontsize=fs) +#pl.ylabel("$l$",fontsize=fs) +pl.xticks(()) +pl.yticks(l_y) +ax2.set_aspect('auto') + +ax3=pl.subplot(gs[3:,2:],sharex=ax2,sharey=ax1) +pl.imshow(M,cmap=cmap,interpolation='nearest') +pl.yticks(l_x) +pl.xticks(l_y) +pl.ylabel("$i$",fontsize=fs) +pl.title("$M_{AB}$",fontsize=fs) +pl.xlabel("$j$",fontsize=fs) +pl.tight_layout() +ax3.set_aspect('auto') +pl.show() + + +#%% Computing FGW and GW +alpha=1e-3 + +ot.tic() +Gwg,logw=fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=alpha,verbose=True,log=True) +ot.toc() + +#%reload_ext WGW +Gg,log=gromov_wasserstein(C1,C2,p,q,loss_fun='square_loss',verbose=True,log=True) + +#%% visu OT matrix +cmap='Blues' +fs=15 +pl.figure(2,(13,5)) +pl.clf() +pl.subplot(1,3,1) +pl.imshow(Got,cmap=cmap,interpolation='nearest') +#pl.xlabel("$y$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) +pl.xticks(()) + +pl.title('Wasserstein ($M$ only)') + +pl.subplot(1,3,2) +pl.imshow(Gg,cmap=cmap,interpolation='nearest') +pl.title('Gromov ($C_1,C_2$ only)') +pl.xticks(()) +pl.subplot(1,3,3) +pl.imshow(Gwg,cmap=cmap,interpolation='nearest') +pl.title('FGW ($M+C_1,C_2$)') + +pl.xlabel("$j$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) + +pl.tight_layout() +pl.show() \ No newline at end of file diff --git a/ot/bregman.py b/ot/bregman.py index b017c1a..9040429 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -5,7 +5,7 @@ Bregman projections for regularized OT # Author: Remi Flamary # Nicolas Courty -# +# Titouan Vayer # License: MIT License import numpy as np diff --git a/ot/gromov.py b/ot/gromov.py index 0278e99..7491664 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -9,17 +9,18 @@ Gromov-Wasserstein transport method # Author: Erwan Vautier # Nicolas Courty # RĂ©mi Flamary -# +# Titouan Vayer # License: MIT License import numpy as np + from .bregman import sinkhorn from .utils import dist from .optim import cg -def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'): +def init_matrix(C1, C2, p, q, loss_fun='square_loss'): """ Return loss matrices and tensors for Gromov-Wasserstein fast computation Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss @@ -77,16 +78,16 @@ def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'): if loss_fun == 'square_loss': def f1(a): - return (a**2) / 2 + return (a**2) def f2(b): - return (b**2) / 2 + return (b**2) def h1(a): return a def h2(b): - return b + return 2*b elif loss_fun == 'kl_loss': def f1(a): return a * np.log(a + 1e-15) - a @@ -268,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs): return np.exp(np.divide(tmpsum, ppt)) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): """ Returns the gromov-wasserstein transport between (C1,p) and (C2,q) @@ -306,6 +307,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations log : bool, optional record log if True + amijo : bool, optional + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. **kwargs : dict parameters can be directly pased to the ot.optim.cg solver @@ -329,9 +333,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): """ - T = np.eye(len(p), len(q)) - - constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -342,14 +344,79 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0,log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) return res, log else: - return cg(p, q, 0, 1, f, df, G0, **kwargs) + return cg(p, q, 0, 1, f, df, G0,amijo=amijo, **kwargs) + +def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=False,**kwargs): + """ + Computes the FGW distance between two graphs see [3] + .. math:: + \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + s.t. \gamma 1 = p + \gamma^T 1= q + \gamma\geq 0 + where : + - M is the (ns,nt) metric cost matrix + - :math:`f` is the regularization term ( and df is its gradient) + - a and b are source and target weights (sum to 1) + The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + Parameters + ---------- + M : ndarray, shape (ns, nt) + Metric cost matrix between features across domains + C1 : ndarray, shape (ns, ns) + Metric cost matrix respresentative of the structure in the source space + C2 : ndarray, shape (nt, nt) + Metric cost matrix espresentative of the structure in the target space + p : ndarray, shape (ns,) + distribution in the source space + q : ndarray, shape (nt,) + distribution in the target space + loss_fun : string,optionnal + loss function used for the solver + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + amijo : bool, optional + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. + **kwargs : dict + parameters can be directly pased to the ot.optim.cg solver + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + References + ---------- + .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + + constC,hC1,hC2=init_matrix(C1,C2,p,q,loss_fun) + + G0=p[:,None]*q[None,:] + + def f(G): + return gwloss(constC,hC1,hC2,G) + def df(G): + return gwggrad(constC,hC1,hC2,G) + + return cg(p,q,M,alpha,f,df,G0,amijo=amijo,C1=C1,C2=C2,constC=constC,**kwargs) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs): """ Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) @@ -387,7 +454,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations log : bool, optional record log if True - + amijo : bool, optional + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. Returns ------- gw_dist : float @@ -407,9 +476,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): """ - T = np.eye(len(p), len(q)) - - constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -418,7 +485,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs) + res, log = cg(p, q, 0, 1, f, df, G0, log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs) log['gw_dist'] = gwloss(constC, hC1, hC2, res) log['T'] = res if log: @@ -495,7 +562,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, T = np.outer(p, q) # Initialization - constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) cpt = 0 err = 1 @@ -815,3 +882,210 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, cpt += 1 return C + +def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,p=None,loss_fun='square_loss', + max_iter=100, tol=1e-9,verbose=False,log=True,init_C=None,init_X=None): + + """ + Compute the fgw barycenter as presented eq (5) in [3]. + ---------- + N : integer + Desired number of samples of the target barycenter + Ys: list of ndarray, each element has shape (ns,d) + Features of all samples + Cs : list of ndarray, each element has shape (ns,ns) + Structure matrices of all samples + ps : list of ndarray, each element has shape (ns,) + masses of all samples + lambdas : list of float + list of the S spaces' weights + alpha : float + Alpha parameter for the fgw distance + fixed_structure : bool + Wether to fix the structure of the barycenter during the updates + fixed_features : bool + Wether to fix the feature of the barycenter during the updates + init_C : ndarray, shape (N,N), optional + initialization for the barycenters' structure matrix. If not set random init + init_X : ndarray, shape (N,d), optional + initialization for the barycenters' features. If not set random init + Returns + ---------- + X : ndarray, shape (N,d) + Barycenters' features + C : ndarray, shape (N,N) + Barycenters' structure matrix + log_: + T : list of (N,ns) transport matrices + Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns) + References + ---------- + .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + S = len(Cs) + d = Ys[0].shape[1] #dimension on the node features + if p is None: + p = np.ones(N)/N + + Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] + Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)] + + lambdas = np.asarray(lambdas, dtype=np.float64) + + if fixed_structure: + if init_C is None: + C=Cs[0] + else: + C=init_C + else: + if init_C is None: + xalea = np.random.randn(N, 2) + C = dist(xalea, xalea) + else: + C = init_C + + if fixed_features: + if init_X is None: + X=Ys[0] + else : + X= init_X + else: + if init_X is None: + X=np.zeros((N,d)) + else: + X = init_X + + T=[np.outer(p,q) for q in ps] + + # X is N,d + # Ys is ns,d + Ms = [np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))] + # Ms is N,ns + + cpt = 0 + err_feature = 1 + err_structure = 1 + + if log: + log_={} + log_['err_feature']=[] + log_['err_structure']=[] + log_['Ts_iter']=[] + + while((err_feature > tol or err_structure > tol) and cpt < max_iter): + Cprev = C + Xprev = X + + if not fixed_features: + Ys_temp=[y.T for y in Ys] + X=update_feature_matrix(lambdas,Ys_temp,T,p) + + # X must be N,d + # Ys must be ns,d + Ms=[np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))] + + if not fixed_structure: + if loss_fun == 'square_loss': + # T must be ns,N + # Cs must be ns,ns + # p must be N,1 + T_temp=[t.T for t in T] + C = update_sructure_matrix(p, lambdas, T_temp, Cs) + + # Ys must be d,ns + # Ts must be N,ns + # p must be N,1 + # Ms is N,ns + # C is N,N + # Cs is ns,ns + # p is N,1 + # ps is ns,1 + + T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] + + # T is N,ns + + log_['Ts_iter'].append(T) + err_feature = np.linalg.norm(X - Xprev.reshape(d,N)) + err_structure = np.linalg.norm(C - Cprev) + + if log: + log_['err_feature'].append(err_feature) + log_['err_structure'].append(err_structure) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err_structure)) + print('{:5d}|{:8e}|'.format(cpt, err_feature)) + + cpt += 1 + log_['T']=T # ce sont les matrices du barycentre de la target vers les Ys + log_['p']=p + log_['Ms']=Ms #Ms sont de tailles N,ns + + return X.T,C,log_ + + +def update_sructure_matrix(p, lambdas, T, Cs): + """ + Updates C according to the L2 Loss kernel with the S Ts couplings + calculated at each iteration + Parameters + ---------- + p : ndarray, shape (N,) + masses in the targeted barycenter + lambdas : list of float + list of the S spaces' weights + T : list of S np.ndarray(ns,N) + the S Ts couplings calculated at each iteration + Cs : list of S ndarray, shape(ns,ns) + Metric cost matrices + Returns + ---------- + C : ndarray, shape (nt,nt) + updated C matrix + """ + tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))]) + ppt = np.outer(p, p) + + return np.divide(tmpsum, ppt) + +def update_feature_matrix(lambdas,Ys,Ts,p): + + """ + Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3] + calculated at each iteration + Parameters + ---------- + p : ndarray, shape (N,) + masses in the targeted barycenter + lambdas : list of float + list of the S spaces' weights + Ts : list of S np.ndarray(ns,N) + the S Ts couplings calculated at each iteration + Ys : list of S ndarray, shape(d,ns) + The features + Returns + ---------- + X : ndarray, shape (d,N) + + References + ---------- + .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + + p=np.diag(np.array(1/p).reshape(-1,)) + + tmpsum = sum([lambdas[s] * np.dot(Ys[s],Ts[s].T).dot(p) for s in range(len(Ts))]) + + return tmpsum + + diff --git a/ot/optim.py b/ot/optim.py index f31fae2..a774865 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -4,7 +4,7 @@ Optimization algorithms for OT """ # Author: Remi Flamary -# +# Titouan Vayer # License: MIT License import numpy as np @@ -71,9 +71,70 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, return alpha, fc[0], phi1 +def do_linesearch(cost,G,deltaG,Mi,f_val, + amijo=False,C1=None,C2=None,reg=None,Gc=None,constC=None,M=None): + """ + Solve the linesearch in the FW iterations + Parameters + ---------- + cost : method + The FGW cost + G : ndarray, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : ndarray (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + Mi : ndarray (ns,nt) + Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost + f_val : float + Value of the cost at G + amijo : bool, optionnal + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. + C1 : ndarray (ns,ns), optionnal + Structure matrix in the source domain. Only used when amijo=False + C2 : ndarray (nt,nt), optionnal + Structure matrix in the target domain. Only used when amijo=False + reg : float, optionnal + Regularization parameter. Corresponds to the alpha parameter of FGW. Only used when amijo=False + Gc : ndarray (ns,nt) + Optimal map found by linearization in the FW algorithm. Only used when amijo=False + constC : ndarray (ns,nt) + Constant for the gromov cost. See [3]. Only used when amijo=False + M : ndarray (ns,nt), optionnal + Cost matrix between the features. Only used when amijo=False + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + f_val : float + The value of the cost for the next iteration + References + ---------- + .. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + if amijo: + alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + else: # requires symetric matrices + dot1=np.dot(C1,deltaG) + dot12=dot1.dot(C2) + a=-2*reg*np.sum(dot12*deltaG) + b=np.sum((M+reg*constC)*deltaG)-2*reg*(np.sum(dot12*G)+np.sum(np.dot(C1,G).dot(C2)*deltaG)) + c=cost(G) + + alpha=solve_1d_linesearch_quad_funct(a,b,c) + fc=None + f_val=cost(G+alpha*deltaG) + + return alpha,fc,f_val + def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False): + stopThr=1e-9, verbose=False, log=False,**kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -116,6 +177,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, Print information along iterations log : bool, optional record log if True + kwargs : dict + Parameters for linesearch Returns ------- @@ -177,7 +240,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, deltaG = Gc - G # line search - alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs) G = G + alpha * deltaG @@ -339,3 +402,36 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, return G, log else: return G + +def solve_1d_linesearch_quad_funct(a,b,c): + """ + Solve on 0,1 the following problem: + .. math:: + \min f(x)=a*x^{2}+b*x+c + + Parameters + ---------- + a,b,c : float + The coefficients of the quadratic function + + Returns + ------- + x : float + The optimal value which leads to the minimal cost + + """ + f0=c + df0=b + f1=a+f0+df0 + + if a>0: # convex + minimum=min(1,max(0,-b/(2*a))) + #print('entrelesdeux') + return minimum + else: # non convexe donc sur les coins + if f0>f1: + #print('sur1 f(1)={}'.format(f(1))) + return 1 + else: + #print('sur0 f(0)={}'.format(f(0))) + return 0 -- cgit v1.2.3