14 using namespace shogun;
// Stores f in the fun_obj member. The const_cast converts f to a plain
// `function *`; presumably f arrives as a `const function *` -- TODO confirm.
// NOTE(review): the enclosing definition is not visible in this excerpt.
19 this->fun_obj=
const_cast<function *
>(f);
// Main trust-region iteration (presumably the body of CTron::tron; the
// function header, braces and several statements are not visible in this
// excerpt, so the notes below describe only the lines that are shown).

// Thresholds applied to the predicted reduction (prered) when it is compared
// with the actual reduction (actred) further down.
31 float64_t eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
// Scaling factors used when the trust-region radius delta is updated.
34 float64_t sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4.;
38 float64_t alpha, f, fnew, prered, actred, gs;
// Problem dimension as reported by the objective function object.
41 int n = (int) fun_obj->get_nr_variable();
// inc is passed as the stride argument of every CBLAS call below.
42 int search = 1, iter = 1, inc = 1;
// Initial trust-region radius: the Euclidean norm of the vector g.
53 delta = cblas_dnrm2(n, g, inc);
// Convergence test against eps scaled by gnorm1.
// NOTE(review): the assignments of gnorm/gnorm1 and the body of this `if`
// are not visible here.
57 if (gnorm <= eps*gnorm1)
62 CSignal::clear_cancel();
// Iterate until the iteration budget is exhausted, `search` becomes zero,
// or CSignal reports a cancellation request.
65 while (iter <= max_iter && search && (!CSignal::cancel_computations()))
// Wall-clock limit check, active only when max_train_time is positive.
// NOTE(review): the action taken when the limit is exceeded is not visible.
67 if (max_train_time > 0 && start_time.
cur_time_diff() > max_train_time)
// Call trcg with the current radius; its return value is kept in cg_iter
// and later printed in the "CG" field of the log line.
70 cg_iter = trcg(delta, g, s, r);
// w_new += one * s (the definition of `one` is not visible; presumably 1.0).
73 cblas_daxpy(n, one, s, inc, w_new, inc);
// gs = g . s
75 gs = cblas_ddot(n, g, inc, s, inc);
// prered = -0.5 * (g . s - s . r)
76 prered = -0.5*(gs-cblas_ddot(n, s, inc, r, inc));
// Objective value at the trial point w_new.
77 fnew = fun_obj->fun(w_new);
// Euclidean length of the step s.
83 snorm = cblas_dnrm2(n, s, inc);
// Shrink delta to at most the step length.
// NOTE(review): any guard on this statement is not visible in this excerpt.
85 delta = CMath::min(delta, snorm);
// Choose the step-scaling factor alpha from fnew, f and gs.
// NOTE(review): the statement executed when this condition holds is not
// visible; only the alternative assignment on the next shown line is.
88 if (fnew - f - gs <= 0)
91 alpha = CMath::max(sigma1, -0.5*(gs/(fnew - f - gs)));
// Update delta according to how actred compares with eta0/eta1/eta2 times
// prered. NOTE(review): the assignment of actred is not visible here, and
// the final update below presumably sits under an `else` that is not shown.
94 if (actred < eta0*prered)
95 delta = CMath::min(CMath::max(alpha, sigma1)*snorm, sigma2*delta);
96 else if (actred < eta1*prered)
97 delta = CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma2*delta));
98 else if (actred < eta2*prered)
99 delta = CMath::max(sigma1*delta, CMath::min(alpha*snorm, sigma3*delta));
101 delta = CMath::max(delta, CMath::min(alpha*snorm, sigma3*delta));
// Per-iteration diagnostic line.
103 SG_INFO(
"iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);
// Test whether the actual reduction exceeds eta0 * prered.
// NOTE(review): the body of this `if` is only partly visible; the lines
// that follow are presumably inside it.
105 if (actred > eta0*prered)
// Recompute the norm of g and test it against the same scaled tolerance
// (note the strict `<` here versus `<=` in the test before the loop).
112 gnorm = cblas_dnrm2(n, g, inc);
113 if (gnorm < eps*gnorm1)
// Progress report driven by gnorm on a -log10 scale.
115 SG_SABS_PROGRESS(gnorm, -CMath::log10(gnorm), -CMath::log10(1), -CMath::log10(eps*gnorm1), 6);
// Both reductions are zero in magnitude.
// NOTE(review): the action taken is not visible in this excerpt.
122 if (CMath::abs(actred) <= 0 && CMath::abs(prered) <= 0)
// Both reductions are at most 1e-12 relative to |f|.
// NOTE(review): the action taken is not visible in this excerpt.
127 if (CMath::abs(actred) <= 1.0e-12*CMath::abs(f) &&
128 CMath::abs(prered) <= 1.0e-12*CMath::abs(f))
// Iterative solver confined to the trust region of radius delta: the visible
// code updates s along the direction d and r along Hd. The names and update
// pattern look like a conjugate-gradient iteration -- TODO confirm against
// the full source.
//
// delta : trust-region radius; compared against the norm of s.
// g     : input vector; its norm sets the stopping tolerance cgtol.
// s     : vector updated in place along d.
// r     : vector updated in place along Hd.
// Returns int32_t; the caller stores it as cg_iter.
// NOTE(review): braces, loop headers, the definitions of d/Hd/one/inc and
// several statements are not visible in this excerpt.
143 int32_t CTron::trcg(
float64_t delta,
double* g,
double* s,
double* r)
// Problem dimension as reported by the objective function object.
147 int n = (int) fun_obj->get_nr_variable();
152 double rTr, rnewTrnew, alpha, beta, cgtol;
// Stopping tolerance: one tenth of the Euclidean norm of g.
160 cgtol = 0.1* cblas_dnrm2(n, g, inc);
// rTr = r . r
163 rTr = cblas_ddot(n, r, inc, r, inc);
// Tolerance test on the norm of r.
// NOTE(review): the body (presumably a loop exit) is not visible.
166 if (cblas_dnrm2(n, r, inc) <= cgtol)
// Step length along d: alpha = (r . r) / (d . Hd).
171 alpha = rTr/cblas_ddot(n, d, inc, Hd, inc);
// s += alpha * d
172 cblas_daxpy(n, alpha, d, inc, s, inc);
// The updated s lies outside the trust region.
173 if (cblas_dnrm2(n, s, inc) > delta)
175 SG_INFO(
"cg reaches trust region boundary\n");
// s += alpha * d again.
// NOTE(review): original line 176 is missing from this excerpt; alpha is
// presumably modified there (e.g. negated to undo the step above) -- verify.
177 cblas_daxpy(n, alpha, d, inc, s, inc);
// Terms of the quadratic ||s + alpha*d||^2 = delta^2 in alpha.
179 double std = cblas_ddot(n, s, inc, d, inc);
180 double sts = cblas_ddot(n, s, inc, s, inc);
181 double dtd = cblas_ddot(n, d, inc, d, inc);
182 double dsq = delta*delta;
183 double rad = sqrt(std*std + dtd*(dsq-sts));
// Two algebraically equivalent expressions for the positive root.
// NOTE(review): the condition selecting between them is not visible.
185 alpha = (dsq - sts)/(std + rad);
187 alpha = (rad - std)/dtd;
// s += alpha * d
188 cblas_daxpy(n, alpha, d, inc, s, inc);
// r += alpha * Hd
// NOTE(review): original line 189 is missing; alpha may be modified there.
190 cblas_daxpy(n, alpha, Hd, inc, r, inc);
// r += alpha * Hd
// NOTE(review): intervening lines are missing, so this may belong to a
// different branch than the update above, and alpha may differ here.
194 cblas_daxpy(n, alpha, Hd, inc, r, inc);
// beta = (new r . r) / (previous r . r)
195 rnewTrnew = cblas_ddot(n, r, inc, r, inc);
196 beta = rnewTrnew/rTr;
// d = beta * d + one * r (the definition of `one` is not visible).
197 cblas_dscal(n, beta, d, inc);
198 cblas_daxpy(n, one, r, inc, d, inc);
// Scan x[1..n-1] and raise dmax to any absolute value that is >= the current
// dmax. NOTE(review): the enclosing function header and the initial value of
// dmax are not visible here; starting at i=1 suggests dmax is seeded from
// x[0] -- TODO confirm.
211 for (int32_t i=1; i<n; i++)
212 if (CMath::abs(x[i]) >= dmax)
213 dmax = CMath::abs(x[i]);