00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 #ifndef __RVLALG_UMIN_LSQR__
00038 #define __RVLALG_UMIN_LSQR__
00039 
#include <cmath>

#include "alg.hh"
#include "terminator.hh"
#include "linop.hh"
#include "table.hh"
00044 
00045 using namespace RVLAlg;
00046 
00047 namespace RVLUmin {
00048 
00049   using namespace RVL;
00050   using namespace RVLAlg;    
00051 
00076   template<typename Scalar>
00077   class LSQRStep : public Algorithm {
00078 
00079     typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00080 
00081   public:
00082 
00083     LSQRStep(LinearOp<Scalar> const & _A,
00084          Vector<Scalar> & _x,
00085          Vector<Scalar> const & _b,
00086          atype & _rnorm, 
00087          atype & _nrnorm)
00088       : A(_A), x(_x), b(_b), rnorm(_rnorm), nrnorm(_nrnorm), v(A.getDomain()), alphav(A.getDomain()),
00089     u(A.getRange()), betau(A.getRange()), w(A.getDomain()) { 
00090       
00091       
00092       beta=b.norm();
00093       rnorm=beta;
00094       atype tmp;
00095       if (ProtectedDivision<atype>(ScalarFieldTraits<atype>::One(),beta,tmp)) {
00096     RVLException e;
00097     e<<"Error: LSQRStep constructor\n";
00098     e<<"  RHS has vanishing norm\n";
00099     throw e;
00100       }
00101       u.scale(tmp,b);
00102       A.applyAdjOp(u,v);
00103       alpha = v.norm();
00104       nrnorm = alpha*rnorm;
00105       if (ProtectedDivision<atype>(ScalarFieldTraits<atype>::One(),alpha,tmp)) {
00106     RVLException e;
00107     e<<"Error: LSQRStep constructor\n";
00108     e<<"  Normal residual has vanishing norm\n";
00109     throw e;
00110       }
00111       v.scale(tmp);
00112       w.copy(v);
00113       phibar = beta;
00114       rhobar = alpha;
00115     }
00116       
00120     void run() {
00121       try {
00122     
00123     A.applyOp(v,betau);
00124     
00125     Scalar stmp = -alpha;
00126     betau.linComb(stmp,u);
00127     
00128     beta=betau.norm();
00129     atype tmp;
00130     if (ProtectedDivision<atype>(ScalarFieldTraits<atype>::One(),beta,tmp)) {
00131       RVLException e;
00132       e<<"Error: LSQRStep::run\n";
00133       e<<"  beta vanishes\n";
00134       throw e;
00135     }
00136     
00137     stmp = tmp;
00138     u.scale(stmp,betau);    
00139 
00140     
00141     A.applyAdjOp(u,alphav);
00142     
00143     stmp = -beta;
00144     alphav.linComb(stmp,v);
00145     
00146     alpha = alphav.norm();
00147     if (ProtectedDivision<atype>(ScalarFieldTraits<atype>::One(),alpha,tmp)) {
00148       RVLException e;
00149       e<<"Error: LSQRStep::run\n";
00150       e<<"  beta vanishes\n";
00151       throw e;
00152     }
00153     
00154     stmp=tmp;
00155     v.scale(tmp,alphav);        
00156     
00157     
00158         atype rho = sqrt(rhobar*rhobar + beta*beta);
00159     
00160     atype c = rhobar/rho;
00161     
00162     atype s = beta/rho;
00163     
00164     atype theta = s*alpha;
00165     
00166     rhobar = - c*alpha;
00167     
00168     atype phi = c*phibar;
00169     
00170     phibar = s*phibar;
00171 
00172     
00173     x.linComb(phi/rho,w);
00174     
00175     w.scale(-theta/rho);
00176     w.linComb(ScalarFieldTraits<Scalar>::One(),v);
00177 
00178     
00179     rnorm = phibar;
00180     nrnorm = phibar*alpha*abs(c);
00181 
00182       }
00183       catch (RVLException & e) {
00184     e<<"\ncalled from CGNEStep::run()\n";
00185     throw e;
00186       }
00187      
00188     }
00189 
00190     ~LSQRStep() {}
00191 
00192   private:
00193 
00194     
00195     LinearOp<Scalar> const & A;
00196     Vector<Scalar> & x;
00197     Vector<Scalar> const & b;
00198     atype & rnorm;
00199     atype & nrnorm;
00200 
00201     
00202     Vector<Scalar> u;
00203     Vector<Scalar> v;
00204     Vector<Scalar> betau;
00205     Vector<Scalar> alphav;
00206     Vector<Scalar> w;
00207     atype alpha;
00208     atype beta;
00209     atype rhobar;
00210     atype phibar;
00211   };
00212   
00271   template<typename Scalar>
00272   class LSQRAlg: public Algorithm, public Terminator {
00273 
00274     typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00275 
00276   public:
00277 
00317     LSQRAlg(RVL::Vector<Scalar> & _x, 
00318         LinearOp<Scalar> const & _inA, 
00319         Vector<Scalar> const & _rhs, 
00320         atype & _rnorm,
00321         atype & _nrnorm,
00322         atype _rtol = 100.0*numeric_limits<atype>::epsilon(),
00323         atype _nrtol = 100.0*numeric_limits<atype>::epsilon(),
00324         int _maxcount = 10,
00325         atype _maxstep = numeric_limits<atype>::max(),
00326         ostream & _str = cout)
00327     : inA(_inA), 
00328       x(_x), 
00329       rhs(_rhs), 
00330       rnorm(_rnorm), 
00331       nrnorm(_nrnorm), 
00332       rtol(_rtol), 
00333       nrtol(_nrtol), 
00334       maxstep(_maxstep), 
00335       maxcount(_maxcount), 
00336       count(0), 
00337       proj(false), 
00338       str(_str), 
00339       step(inA,x,rhs,rnorm,nrnorm) 
00340     { x.zero(); }
00341 
00342     ~LSQRAlg() {}
00343 
00344     bool query() { return proj; }
00345 
00346     void run() { 
00347       
00348       vector<string> names(2);
00349       vector<atype *> nums(2);
00350       vector<atype> tols(2);
00351       names[0]="Residual Norm"; nums[0]=&rnorm; tols[0]=rtol;
00352       names[1]="Gradient Norm"; nums[1]=&nrnorm; tols[1]=nrtol;
00353       str<<"========================== BEGIN LSQR =========================\n";
00354       VectorCountingThresholdIterationTable<atype> stop1(maxcount,names,nums,tols,str);
00355       stop1.init();
00356       
00357       
00358       BallProjTerminator<Scalar> stop2(x,maxstep,str);
00359       
00360       OrTerminator stop(stop1,stop2);
00361       
00362       LoopAlg doit(step,stop);
00363       doit.run();
00364       
00365       proj = stop2.query();
00366       if (proj) {
00367     Vector<Scalar> temp(inA.getRange());
00368     inA.applyOp(x,temp);
00369     temp.linComb(-1.0,rhs);
00370     rnorm=temp.norm();
00371     Vector<Scalar> temp1(inA.getDomain());
00372     inA.applyAdjOp(temp,temp1);
00373     nrnorm=temp1.norm();
00374       }
00375       count = stop1.getCount();
00376       str<<"=========================== END LSQR ==========================\n";
00377     }
00378 
00379     int getCount() const { return count; }
00380 
00381   private:
00382 
00383     LinearOp<Scalar> const & inA;  
00384     Vector<Scalar> & x;            
00385     Vector<Scalar> const & rhs;    
00386     atype & rnorm;                 
00387     atype & nrnorm;                
00388     atype rtol;                    
00389     atype nrtol;                   
00390     atype maxstep;                 
00391     int maxcount;                  
00392     int count;                     
00393     mutable bool proj;             
00394     ostream & str;                 
00395     LSQRStep<Scalar> step;         
00396 
00397     
00398     LSQRAlg();
00399     LSQRAlg(LSQRAlg<Scalar> const &);
00400 
00401   };
00402 
00405   template<typename Scalar>
00406   class LSQRPolicyData {
00407 
00408     typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00409   
00410   public:
00411 
00412     atype rtol;
00413     atype nrtol;
00414     atype Delta;
00415     int maxcount;
00416     bool verbose;
00417 
00418     LSQRPolicyData(atype _rtol = numeric_limits<atype>::max(),
00419            atype _nrtol = numeric_limits<atype>::max(),
00420            atype _Delta = numeric_limits<atype>::max(),
00421            int _maxcount = 0,
00422            bool _verbose = false)
00423       : rtol(_rtol), nrtol(_nrtol), Delta(_Delta), maxcount(_maxcount), verbose(_verbose) {}
00424       
00425     LSQRPolicyData(LSQRPolicyData<Scalar> const & a) 
00426       : rtol(a.rtol), nrtol(a.nrtol), Delta(a.Delta), maxcount(a.maxcount), verbose(a.verbose) {}
00427 
00428     ostream & write(ostream & str) const {
00429       str<<"\n";
00430       str<<"==============================================\n";
00431       str<<"LSQRPolicyData: \n";
00432       str<<"rtol      = "<<rtol<<"\n";
00433       str<<"nrtol     = "<<nrtol<<"\n";
00434       str<<"Delta     = "<<Delta<<"\n";
00435       str<<"maxcount  = "<<maxcount<<"\n";
00436       str<<"verbose   = "<<verbose<<"\n";
00437       str<<"==============================================\n";
00438       return str;
00439     }
00440   };
00441 
00464   template<typename Scalar> 
00465   class LSQRPolicy {
00466 
00467     typedef typename ScalarFieldTraits<Scalar>::AbsType atype;
00468 
00469   public:
00489     LSQRAlg<Scalar> * build(Vector<Scalar> & x, 
00490                 LinearOp<Scalar> const & A,
00491                 Vector<Scalar> const & d,
00492                 atype & rnorm,
00493                 atype & nrnorm,
00494                 ostream & str) const {
00495       if (verbose) 
00496     return new LSQRAlg<Scalar>(x,A,d,rnorm,nrnorm,rtol,nrtol,maxcount,Delta,str);
00497       else
00498     return new LSQRAlg<Scalar>(x,A,d,rnorm,nrnorm,rtol,nrtol,maxcount,Delta,nullstr);
00499     }
00500 
00506     void assign(atype _rtol, atype _nrtol, atype _Delta, int _maxcount, bool _verbose) {
00507       rtol=_rtol; nrtol=_nrtol; Delta=_Delta; maxcount=_maxcount; verbose=_verbose;
00508     }
00509 
00511     void assign(Table const & t) {
00512       rtol=getValueFromTable<atype>(t,"LSQR_ResTol");
00513       nrtol=getValueFromTable<atype>(t,"LSQR_GradTol"); 
00514       Delta=getValueFromTable<atype>(t,"TR_Delta");
00515       maxcount=getValueFromTable<int>(t,"LSQR_MaxItn"); 
00516       verbose=getValueFromTable<bool>(t,"LSQR_Verbose");
00517     }
00518 
00520     void assign(LSQRPolicyData<Scalar> const & s) {
00521       rtol=s.rtol;
00522       nrtol=s.nrtol;
00523       Delta=s.Delta;
00524       maxcount=s.maxcount;
00525       verbose=s.verbose;
00526     }
00527 
00532     mutable atype Delta;
00533 
00542     LSQRPolicy(atype _rtol = numeric_limits<atype>::max(),
00543            atype _nrtol = numeric_limits<atype>::max(),
00544            atype _Delta = numeric_limits<atype>::max(),
00545            int _maxcount = 0,
00546            bool _verbose = true)
00547       : Delta(_Delta), rtol(_rtol), nrtol(_nrtol), maxcount(_maxcount), verbose(_verbose), nullstr(0) {}
00548 
00549     LSQRPolicy(LSQRPolicy<Scalar> const & p)
00550       :     Delta(p.Delta), 
00551         rtol(p.rtol), 
00552         nrtol(p.nrtol), 
00553         maxcount(p.maxcount), 
00554         verbose(p.verbose), 
00555         nullstr(0) {}
00556       
00557   private:
00558     mutable atype rtol;
00559     mutable atype nrtol;
00560     mutable int maxcount;
00561     mutable bool verbose;
00562     mutable std::ostream nullstr;
00563   };
00564 }
00565 
00566 #endif