#include "osl/move_probability/featureSet.h"
#include "osl/progress/ml/newProgress.h"
#include "osl/record/csaString.h"
#include "osl/record/csaRecord.h"
#include "osl/hash/hashKey.h"
#include "osl/oslConfig.h"

#include <cppunit/TestCase.h>
#include <cppunit/extensions/HelperMacros.h>

#include <fstream>
#include <iostream>
#include <string>

class MPFeatureSetTest : public CppUnit::TestFixture 
{
  CPPUNIT_TEST_SUITE(MPFeatureSetTest);
  CPPUNIT_TEST(testInitialize);
  CPPUNIT_TEST(testPass);
  CPPUNIT_TEST(testGenerateLogProb);
  CPPUNIT_TEST_SUITE_END();
public:
  void setUp()
  {
    osl::OslConfig::setUp();
  }
  void testInitialize();
  void testPass();
  void testGenerateLogProb();
};

// Register MPFeatureSetTest with CppUnit's default test factory registry so
// the test runner discovers this fixture without manual wiring.
CPPUNIT_TEST_SUITE_REGISTRATION(MPFeatureSetTest);

using namespace osl;
using namespace osl::move_probability;

void MPFeatureSetTest::testInitialize()
{
  /* const osl::move_probability::StandardFeatureSet& osl_feature_set
     = */ osl::move_probability::StandardFeatureSet::instance();
}

void MPFeatureSetTest::testPass()
{
  const StandardFeatureSet& feature_set
    = StandardFeatureSet::instance();

  NumEffectState state;
  const Move pass_b = Move::PASS(BLACK);
  const Move pass_w = Move::PASS(WHITE);
  state.makeMove(pass_b);
  state.makeMove(pass_w);

  MoveStack history;
  history.push(pass_b);
  history.push(pass_w);

  typedef progress::ml::NewProgress progress_t;
  progress_t progress(state);
  StateInfo info(state, progress.progress16(), history);

  MoveLogProbVector a;
  feature_set.generateLogProb(info, a);
  CPPUNIT_ASSERT(! a.empty());
}

void MPFeatureSetTest::testGenerateLogProb()
{
  const StandardFeatureSet& feature_set
    = StandardFeatureSet::instance();
  std::ifstream ifs(OslConfig::testCsaFile("FILES"));
  CPPUNIT_ASSERT(ifs);
  std::string csafilename;
  int i=0;
  while((ifs >> csafilename) && csafilename != "" && ++i<128){
    record::Record record=CsaFile(OslConfig::testCsaFile(csafilename)).getRecord();
    NumEffectState state(record.getInitialState());
    vector<osl::Move> moves=record.getMoves();
    typedef osl::progress::ml::NewProgress progress_t;
    progress_t progress(state);
    MoveStack history;
    StateInfo info(state, progress.progress16(), history);
    for (size_t i=0; i<moves.size(); i++) {
      MoveLogProbVector a;
      feature_set.generateLogProb(info, a);
      CPPUNIT_ASSERT(!a.empty());

      state.makeMove(moves[i]);
      progress.update(state, moves[i]);
      history.push(moves[i]);
      info.reset(state, progress.progress16(), history);
      StateInfo info2(state, progress.progress16(), history);

      {
	const StateInfo& l = info;
	const StateInfo& r = info2;
	for (int x=1; x<=9; ++x) {
	  for (int y=1; y<=9; ++y) {
	    const Square position(x,y);
	    if (! (l.pattern_cache[position.index()]
		   == r.pattern_cache[position.index()])) {
	      std::cerr << state << position
			<< " " << moves[i] << "\n";
	    }
	    CPPUNIT_ASSERT((l.pattern_cache[position.index()]
			    == r.pattern_cache[position.index()]));
	  }
	}
	CPPUNIT_ASSERT(HashKey(*l.state) == HashKey(*r.state));
	CPPUNIT_ASSERT(*l.history == *r.history);
	CPPUNIT_ASSERT(l.pin_by_opposing_sliders == r.pin_by_opposing_sliders);
	CPPUNIT_ASSERT(l.king8_long_pieces == r.king8_long_pieces);
	CPPUNIT_ASSERT(l.threatened == r.threatened);
	CPPUNIT_ASSERT(l.long_attack_cache == r.long_attack_cache);
	CPPUNIT_ASSERT(l.attack_shadow == r.attack_shadow);
	CPPUNIT_ASSERT(l.progress16 == r.progress16);
	CPPUNIT_ASSERT(l.last_move_ptype5 == r.last_move_ptype5);
	CPPUNIT_ASSERT(l.last_add_effect == r.last_add_effect);
	CPPUNIT_ASSERT(l.pin == r.pin);
	CPPUNIT_ASSERT(l.threatmate_move == r.threatmate_move);
	CPPUNIT_ASSERT(l.sendoffs == r.sendoffs);
	CPPUNIT_ASSERT(l.exchange_pins == r.exchange_pins);
	CPPUNIT_ASSERT(l.move_candidate_exists == r.move_candidate_exists);
	CPPUNIT_ASSERT(HashKey(l.copy) == HashKey(r.copy));
      }
      CPPUNIT_ASSERT(info == info2);
    }
  }  
}

/* ------------------------------------------------------------------------- */
// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
