00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00028 #ifndef __LWPR_HH
00029 #define __LWPR_HH
00030
00031 #include <lwpr.h>
00032 #include <lwpr_math.h>
00033 #include <lwpr_binio.h>
00034 #include <lwpr_xml.h>
00035 #include <string.h>
00036 #include <vector>
00037
/** \brief Vector of doubles; used for all vector-valued arguments and
 *  return values of the C++ wrapper classes below. */
00042 typedef std::vector<double> doubleVec;
00043
/** \brief Exception type thrown by the LWPR C++ wrapper classes.
 *
 *  Carries a single error code; getString() maps that code to a fixed,
 *  human-readable message.
 */
class LWPR_Exception {
   public:

   /** \brief Error categories reported by the wrapper. */
   typedef enum {
      OUT_OF_MEMORY,       ///< storage could not be allocated
      BAD_INPUT_DIM,       ///< input vector has the wrong size
      BAD_OUTPUT_DIM,      ///< output vector has the wrong size
      BAD_INIT_D,          ///< initial distance metric rejected
      UNKNOWN_KERNEL,      ///< kernel name not recognised
      IO_ERROR,            ///< reading or writing a model failed
      OUT_OF_RANGE,        ///< index argument out of range
      UNSPECIFIED_ERROR    ///< anything else
   } Code;

   /** \brief Creates an exception carrying the error code \a c. */
   LWPR_Exception(Code c) : code(c) {}

   /** \brief Returns the stored error code. */
   Code getCode() const {
      return code;
   }

   /** \brief Returns a static, human-readable description of the error code. */
   const char *getString() const {
      // Indexed by Code value: the order must match the enumerators
      // OUT_OF_MEMORY .. OUT_OF_RANGE above.
      static const char *const messages[] = {
         "Insufficient memory to allocate storage.",
         "Input dimensionality does not match.",
         "Output dimensionality does not match.",
         "Invalid initial distance metric (not positive definite).",
         "Passed kernel name was not recognised.",
         "An error occurred during I/O operations.",
         "Index parameter out of range."
      };
      const int index = static_cast<int>(code);
      if (index >= 0 && index < static_cast<int>(UNSPECIFIED_ERROR)) {
         return messages[index];
      }
      return "Oops: Unspecified error.";
   }

   private:

   /** \brief The error code passed to the constructor. */
   Code code;
};
00104
00105 class LWPR_Object;
00106
00107
00113 class LWPR_ReceptiveFieldObject {
00114 friend class LWPR_Object;
00115
00116 public:
00117
00119 int nReg() const {
00120 return RF->nReg;
00121 }
00122
00124 doubleVec meanX() const {
00125 doubleVec mx(nIn);
00126 memcpy(&mx[0], RF->mean_x, sizeof(double)*nIn);
00127 return mx;
00128 }
00129
00131 doubleVec varX() const {
00132 doubleVec vx(nIn);
00133 memcpy(&vx[0], RF->var_x, sizeof(double)*nIn);
00134 return vx;
00135 }
00136
00138 doubleVec center() const {
00139 doubleVec c(nIn);
00140 memcpy(&c[0], RF->c, sizeof(double)*nIn);
00141 return c;
00142 }
00143
00145 bool trustworthy() const {
00146 return (bool) RF->trustworthy;
00147 }
00148
00150 std::vector<doubleVec> D() const {
00151 std::vector<doubleVec> ds(nIn);
00152 for (int i=0;i<nIn;i++) {
00153 ds[i].resize(nIn);
00154 memcpy(&ds[i][0], RF->D + i*nInS, sizeof(double)*nIn);
00155 }
00156 return ds;
00157 }
00158
00161 std::vector<doubleVec> M() const {
00162 std::vector<doubleVec> ms(nIn);
00163 for (int i=0;i<nIn;i++) {
00164 ms[i].resize(i+1);
00165 memcpy(&ms[i][0], RF->M + i*nInS, sizeof(double)*(i+1));
00166 }
00167 return ms;
00168 }
00169
00171 std::vector<doubleVec> U() const {
00172 std::vector<doubleVec> us(RF->nReg);
00173 for (int i=0;i<RF->nReg;i++) {
00174 us[i].resize(nIn);
00175 memcpy(&us[i][0], RF->U + i*nInS, sizeof(double)*nIn);
00176 }
00177 return us;
00178 }
00179
00181 std::vector<doubleVec> P() const {
00182 std::vector<doubleVec> ps(RF->nReg);
00183 for (int i=0;i<RF->nReg;i++) {
00184 ps[i].resize(nIn);
00185 memcpy(&ps[i][0], RF->P + i*nInS, sizeof(double)*nIn);
00186 }
00187 return ps;
00188 }
00189
00191 double beta0() const {
00192 return RF->beta0;
00193 }
00194
00196 doubleVec beta() const {
00197 doubleVec be(RF->nReg);
00198 memcpy(&be[0], RF->beta, sizeof(double)*RF->nReg);
00199 return be;
00200 }
00201
00203 doubleVec numData() const {
00204 doubleVec nd(RF->nReg);
00205 memcpy(&nd[0], RF->n_data, sizeof(double)*RF->nReg);
00206 return nd;
00207 }
00208
00209
00211 doubleVec slope() const {
00212 doubleVec s(nIn);
00213 doubleVec t(nIn);
00214
00215 if (RF->slopeReady) {
00216 memcpy(&s[0], RF->slope, sizeof(double)*RF->nReg);
00217 } else {
00218
00219
00220
00221 lwpr_math_scalar_vector(&s[0], RF->beta[0], RF->U, nIn);
00222 for (int i=1;i<RF->nReg;i++) {
00223 lwpr_math_scalar_vector(&t[0], RF->beta[i], RF->U + i*nInS, nIn);
00224 for (int j=i-1;j>=0;j--) {
00225
00226 double dp = lwpr_math_dot_product(&t[0],RF->P + j*nInS, nIn);
00227 lwpr_math_add_scalar_vector(&t[0], dp, RF->U + j*nInS, nIn);
00228 }
00229 for (int m=0;m<nIn;m++) s[m]+=t[m];
00230 }
00231 }
00232 return s;
00233 }
00234
00235 private:
00236
00238 LWPR_ReceptiveFieldObject(LWPR_ReceptiveField *rf) {
00239 this->RF = rf;
00240 nIn = rf->model->nIn;
00241 nInS = rf->model->nInStore;
00242 }
00243
00245 const LWPR_ReceptiveField *RF;
00246 int nIn;
00247 int nInS;
00248 };
00249
00250
00254 class LWPR_Object {
00255 public:
00256
00265 LWPR_Object(int nIn, int nOut) {
00266 if (!lwpr_init_model(&model, nIn, nOut, NULL)) {
00267 throw LWPR_Exception(LWPR_Exception::OUT_OF_MEMORY);
00268 }
00269 }
00270
00278 LWPR_Object(const LWPR_Object& otherObj) {
00279 if (!lwpr_duplicate_model(&(this->model), &(otherObj.model))) {
00280 throw LWPR_Exception(LWPR_Exception::OUT_OF_MEMORY);
00281 }
00282 }
00283
00292 LWPR_Object(const char *filename) {
00293 int ok;
00294
00295 ok = lwpr_read_binary(&model, filename);
00296 #if HAVE_LIBEXPAT
00297 if (!ok) {
00298 int numErr, numWar;
00299 numErr = lwpr_read_xml(&model, filename, &numWar);
00300 ok = (numErr == 0);
00301 }
00302 #endif
00303 if (!ok) throw LWPR_Exception(LWPR_Exception::IO_ERROR);
00304 }
00305
00307 ~LWPR_Object() {
00308 lwpr_free_model(&model);
00309 }
00310
00317 int writeXML(const char *filename) {
00318 return lwpr_write_xml(&model, filename);
00319 }
00320
00327 int writeBinary(const char *filename) {
00328 return lwpr_write_binary(&model, filename);
00329 }
00330
00345 doubleVec update(const doubleVec& x, const doubleVec& y) {
00346 doubleVec yp(model.nOut);
00347
00348 if (x.size()!=(unsigned) model.nIn) {
00349 throw LWPR_Exception(LWPR_Exception::BAD_INPUT_DIM);
00350 }
00351
00352 if (y.size()!=(unsigned) model.nOut) {
00353 throw LWPR_Exception(LWPR_Exception::BAD_OUTPUT_DIM);
00354 }
00355
00356 if (!lwpr_update(&model, &x[0], &y[0], &yp[0], NULL)) {
00357 throw LWPR_Exception(LWPR_Exception::OUT_OF_MEMORY);
00358 }
00359 return yp;
00360 }
00361
00372 doubleVec predict(const doubleVec& x, double cutoff = 0.001) {
00373 doubleVec yp(model.nOut);
00374
00375 if (x.size()!=(unsigned) model.nIn) {
00376 throw LWPR_Exception(LWPR_Exception::BAD_INPUT_DIM);
00377 }
00378
00379 lwpr_predict(&model, &x[0], cutoff, &yp[0], NULL, NULL);
00380 return yp;
00381 }
00382
00396 doubleVec predict(const doubleVec& x, doubleVec& confidence, double cutoff = 0.001) {
00397 doubleVec yp(model.nOut);
00398
00399 if (x.size()!=(unsigned) model.nIn) {
00400 throw LWPR_Exception(LWPR_Exception::BAD_INPUT_DIM);
00401 }
00402 if (confidence.size()!=(unsigned) model.nOut) confidence.resize(model.nOut);
00403
00404 lwpr_predict(&model, &x[0], cutoff, &yp[0], &confidence[0], NULL);
00405 return yp;
00406 }
00407
00423 doubleVec predict(const doubleVec& x, doubleVec& confidence, doubleVec& maxW, double cutoff = 0.001) {
00424 doubleVec yp(model.nOut);
00425
00426 if (x.size()!=(unsigned) model.nIn) {
00427 throw LWPR_Exception(LWPR_Exception::BAD_INPUT_DIM);
00428 }
00429 if (confidence.size()!=(unsigned) model.nOut) confidence.resize(model.nOut);
00430 if (maxW.size()!=(unsigned) model.nOut) maxW.resize(model.nOut);
00431
00432 lwpr_predict(&model, &x[0], cutoff, &yp[0], &confidence[0], &maxW[0]);
00433 return yp;
00434 }
00435
00441 void setInitD(double delta) {
00442 if (!lwpr_set_init_D_spherical(&model,delta)) {
00443 throw LWPR_Exception(LWPR_Exception::BAD_INIT_D);
00444 }
00445 }
00446
00454 void setInitD(const doubleVec& initD) {
00455 if (initD.size()==(unsigned) model.nIn) {
00456 if (!lwpr_set_init_D_diagonal(&model,&initD[0])) {
00457 throw LWPR_Exception(LWPR_Exception::BAD_INIT_D);
00458 }
00459 } else if (initD.size()==(unsigned) (model.nIn*model.nIn)) {
00460 if (!lwpr_set_init_D(&model,&initD[0],model.nIn)) {
00461 throw LWPR_Exception(LWPR_Exception::BAD_INIT_D);
00462 }
00463 } else {
00464 throw LWPR_Exception(LWPR_Exception::BAD_INPUT_DIM);
00465 }
00466 }
00467
00469 void setInitAlpha(double alpha) {
00470 lwpr_set_init_alpha(&model,alpha);
00471 }
00472
00474 void wGen(double w_gen) { model.w_gen = w_gen; }
00475
00477 void wPrune(double w_prune) { model.w_prune = w_prune; }
00478
00480 void penalty(double pen) { model.penalty = pen; }
00481
00483 void initLambda(double iLam) { model.init_lambda = iLam; }
00484
00486 void tauLambda(double tLam) { model.tau_lambda = tLam; }
00487
00489 void finalLambda(double fLam) { model.final_lambda = fLam; }
00490
00492 void initS2(double init_s2) { model.init_S2 = init_s2; }
00493
00495 void updateD(bool update) { model.update_D = update ? 1:0; }
00496
00498 void diagOnly(bool dOnly) { model.diag_only = dOnly ? 1:0; }
00499
00501 void useMeta(bool meta) { model.meta = meta ? 1:0; }
00502
00504 void metaRate(double rate) { model.meta_rate = rate; }
00505
00507 void kernel(LWPR_Kernel kern) { model.kernel = kern; }
00508
00510 void kernel(const char *str) {
00511 if (!strcmp(str,"Gaussian")) {
00512 model.kernel = LWPR_GAUSSIAN_KERNEL;
00513 return;
00514 }
00515 if (!strcmp(str,"BiSquare")) {
00516 model.kernel = LWPR_BISQUARE_KERNEL;
00517 return;
00518 }
00519 throw LWPR_Exception(LWPR_Exception::UNKNOWN_KERNEL);
00520 }
00521
00523 int nData() const { return model.n_data; }
00524
00526 int nIn() const { return model.nIn; }
00527
00529 int nOut() const { return model.nOut; }
00530
00532 double wGen() const { return model.w_gen; }
00533
00535 double wPrune() const { return model.w_prune; }
00536
00538 double penalty() const { return model.penalty; }
00539
00541 double initLambda() const { return model.init_lambda; }
00542
00544 double tauLambda() const { return model.tau_lambda; }
00545
00547 double finalLambda() const { return model.final_lambda; }
00548
00550 double initS2() const { return model.init_S2; }
00551
00553 bool updateD() { return (bool) model.update_D; }
00554
00556 bool diagOnly() { return (bool) model.diag_only; }
00557
00559 bool useMeta() { return (bool) model.meta; }
00560
00562 double metaRate() { return model.meta_rate; }
00563
00565 LWPR_Kernel kernel() { return model.kernel; }
00566
00568 doubleVec meanX() {
00569 doubleVec mx(model.nIn);
00570 memcpy(model.mean_x,&mx[0],sizeof(double)*model.nIn);
00571 return mx;
00572 }
00573
00575 doubleVec varX() {
00576 doubleVec vx(model.nIn);
00577 memcpy(model.var_x, &vx[0],sizeof(double)*model.nIn);
00578 return vx;
00579 }
00580
00583 void normIn(const doubleVec& norm) {
00584 if (norm.size()!=(unsigned) model.nIn) {
00585 throw LWPR_Exception(LWPR_Exception::BAD_INPUT_DIM);
00586 }
00587 memcpy(model.norm_in,&norm[0],sizeof(double)*model.nIn);
00588 }
00589
00591 doubleVec normIn() const {
00592 doubleVec norm(model.nIn);
00593 memcpy(&norm[0],model.norm_in,sizeof(double)*model.nIn);
00594 return norm;
00595 }
00596
00599 void normOut(const doubleVec& norm) {
00600 if (norm.size()!=(unsigned) model.nOut) {
00601 throw LWPR_Exception(LWPR_Exception::BAD_OUTPUT_DIM);
00602 }
00603 memcpy(model.norm_out,&norm[0],sizeof(double)*model.nOut);
00604 }
00605
00607 doubleVec normOut() const {
00608 doubleVec norm(model.nOut);
00609 memcpy(&norm[0],model.norm_out,sizeof(double)*model.nOut);
00610 return norm;
00611 }
00612
00614 int numRFS(int outDim) {
00615 if (outDim < 0 || outDim >= model.nOut) return 0;
00616 return model.sub[outDim].numRFS;
00617 }
00618
00620 std::vector<int> numRFS() {
00621 std::vector<int> num(model.nOut);
00622 for (int i=0;i<model.nOut;i++) num[i] = model.sub[i].numRFS;
00623 return num;
00624 }
00625
00638 LWPR_ReceptiveFieldObject getRF(int outDim, int index) const {
00639 if (outDim < 0 || outDim >= model.nOut) {
00640 throw LWPR_Exception(LWPR_Exception::OUT_OF_RANGE);
00641 }
00642 if (index < 0 || index >= model.sub[outDim].numRFS) {
00643 throw LWPR_Exception(LWPR_Exception::OUT_OF_RANGE);
00644 }
00645 return LWPR_ReceptiveFieldObject(model.sub[outDim].rf[index]);
00646 }
00647
00649 LWPR_Model model;
00650 };
00651
00652 #endif