#include <Core/Core.h>

#include "FireBird.h"

NAMESPACE_UPP


// SqlConnection implementation for Firebird built on IBPP.
// Statements are prepared, executed and fetched through the statement and
// transaction objects held by the owning FireBirdSession (see Execute/Fetch).
class FireBirdConnection : public SqlConnection {
protected:
	// SqlConnection interface.
	virtual void        SetParam(int i, const Value& r);
	virtual bool        Execute();
	virtual int         GetRowsProcessed() const;
	virtual Value       GetInsertedId() const;
	virtual bool        Fetch();
	virtual void        GetColumn(int i, Ref f) const;
	virtual void        Cancel();
	virtual SqlSession& GetSession() const;
	virtual String      GetUser() const;
	virtual String      ToString() const;

private:
	FireBirdSession& session; // owning session; supplies statement, transaction and error state
	
	IBPP::Database  conn; // database handle handed to the constructor
	Vector<Value>  param; // parameters collected by SetParam, consumed and cleared by Execute
	int             rows; // affected-row count reported by the last Execute
	int             fetched_row; //-1, if not fetched yet
	int64			last_id; // id captured after the last INSERT, Null when none was obtained

	String          ErrorMessage();
	String          ErrorCode();

public:
	FireBirdConnection(FireBirdSession& a_session, IBPP::Database a_conn);
	// Destruction cancels the connection, which rolls back a started session transaction.
	virtual ~FireBirdConnection() { Cancel(); }
};

// Copies one quoted token (string literal or quoted identifier) from a script to `stmt`.
//
// s     points at the opening quote character (either ' or ").
// stmt  receives the token, including both quote characters.
// Returns the position just past the closing quote, or the terminating '\0'
// when the token is not closed.
//
// A doubled apostrophe ('') is the SQL escape for a literal apostrophe and is
// copied unchanged; emitting only one apostrophe would end the literal early in
// the statement that is later executed.
const char *FireBirdReadString(const char *s, String& stmt)
{
	//TODO: to clear this, currently this is based on sqlite
	stmt.Cat(*s);
	int c = *s++;	// the quote character that closes this token
	for(;;) {
		if(*s == '\0')
			break;
		if(*s == '\'' && s[1] == '\'') {
			// keep the escape intact: both apostrophes go to the output
			stmt.Cat('\'');
			stmt.Cat('\'');
			s += 2;
		}
		else
		if(*s == c) {
			stmt.Cat(c);
			s++;
			break;
		}
		else
		if(*s == '\\') {
			// backslash: copy it together with the following character, if any
			stmt.Cat('\\');
			if(*++s)
				stmt.Cat(*s++);
		}
		else
			stmt.Cat(*s++);
	}
	return s;
}

// Runs a script made of ';'-separated statements through `se`.
//
// Quoted tokens are copied with FireBirdReadString, so a ';' inside quotes does
// not split a statement. Before every statement `progress_canceled` is called
// with the current offset and the total script length; a true result aborts.
// Returns false when cancelled or when a statement fails, true otherwise.
bool FireBirdPerformScript(const String& txt, StatementExecutor& se, Gate2<int, int> progress_canceled)
{
	const char *pos = txt;
	for(;;) {
		// skip leading blanks and control characters
		while(*pos > 0 && *pos <= 32)
			pos++;
		if(*pos == '\0')
			break;

		String stmt;
		while(*pos != '\0' && *pos != ';') {
			if(*pos == '\'' || *pos == '\"')
				pos = FireBirdReadString(pos, stmt);
			else
				stmt.Cat(*pos++);
		}

		if(progress_canceled(pos - txt.Begin(), txt.GetLength()))
			return false;
		if(!se.Execute(stmt))
			return false;

		// step over the ';' separator
		if(*pos)
			pos++;
	}
	return true;
}

// Text of the most recent error, as reported by GetLastErrorMessage().
String FireBirdConnection::ErrorMessage()
{
	String message = GetLastErrorMessage();
	return message;
}

// Most recent error value from GetLastError(), converted to text.
String FireBirdConnection::ErrorCode()
{
	String code = AsString(GetLastError());
	return code;
}

// Text of the most recent error, as reported by GetLastErrorMessage().
String FireBirdSession::ErrorMessage()
{
	String message = GetLastErrorMessage();
	return message;
}

// Most recent error value from GetLastError(), converted to text.
String FireBirdSession::ErrorCode()
{
	String code = AsString(GetLastError());
	return code;
}

// Returns the user names reported by the database handle's Users() call.
Vector<String> FireBirdSession::EnumUsers()
{
	std::vector<std::string> names;
	conn->Users(names);

	Vector<String> result;
	for(std::vector<std::string>::iterator it = names.begin(); it != names.end(); ++it)
		result.Add(*it);
	return result;
}

// Lists the *.fdb and then the *.gdb files found in the directory of the
// currently used database file. Only file names are returned, not full paths.
Vector<String> FireBirdSession::EnumDatabases()
{
	Vector<String> names;
	String folder = GetFileDirectory(conn->DatabaseName());

	FindFile finder(folder + "*.fdb");
	for(; finder; finder.Next())
		names.Add(finder.GetName());

	finder.Search(folder + "*.gdb");
	for(; finder; finder.Next())
		names.Add(finder.GetName());

	return names;
}

// Returns the names of all non-system relations that are not views
// (rdb$view_blr IS NULL). The list is empty when the metadata query fails;
// the session error state then describes the failure.
Vector<String> FireBirdSession::EnumTables()
{
	Vector<String> vec;
	Sql sql("SELECT rdb$relation_name FROM rdb$relations WHERE (rdb$system_flag IS NULL OR rdb$system_flag = 0) AND rdb$view_blr IS NULL", *this);
	// Fetch only from a statement that actually executed.
	if(sql.Execute())
		while(sql.Fetch())
			vec.Add(sql[0]);

	return vec;
}

// Returns the names of all non-system views (relations whose rdb$view_blr is
// not NULL). The list is empty when the metadata query fails; the session
// error state then describes the failure.
Vector<String> FireBirdSession::EnumViews()
{
	Vector<String> vec;
	Sql sql("SELECT rdb$relation_name FROM rdb$relations WHERE (rdb$system_flag IS NULL OR rdb$system_flag = 0) AND rdb$view_blr IS NOT NULL", *this);
	// Fetch only from a statement that actually executed.
	if(sql.Execute())
		while(sql.Fetch())
			vec.Add(sql[0]);

	return vec;
}

// Returns the names of all non-system generators. The list is empty when the
// metadata query fails; the session error state then describes the failure.
Vector<String> FireBirdSession::EnumGenerators()
{
	Vector<String> vec;
	Sql sql("SELECT rdb$generator_name FROM rdb$generators WHERE (rdb$system_flag IS NULL OR rdb$system_flag = 0)", *this);
	// Fetch only from a statement that actually executed.
	if(sql.Execute())
		while(sql.Fetch())
			vec.Add(sql[0]);

	return vec;
}

// Maps a numeric column type code (EnumColumns passes rdb$field_type) and its
// precision to a Value type id. Codes not listed here, including the formerly
// considered blob code 261, are reported as STRING_V.
int SqlTypeToValueType(int type, int precision) {
	switch(type) {
		case 7:
		case 8:
		case 9:
			return INT_V;
		case 16:
			// no precision means a plain integer column, otherwise a scaled numeric
			if(precision == 0 || IsNull(precision))
				return INT_V;
			return DOUBLE_V;
		case 10:
		case 27:
			return DOUBLE_V;
		case 12:
			return DATE_V;
		case 13:
		case 35:
			return TIME_V;
		default:
			return STRING_V;
	}
}

// Describes the columns of `table` (matched upper-cased) from the system tables.
//
// Result columns of the query, by index:
//   0 field name, 1 field type, 2 field length, 3 sub type, 4 null flag,
//   5 character length, 6 precision, 7 scale
// Returns an empty list when the metadata query fails.
// NOTE(review): `table` is interpolated into the SQL text unescaped — confirm
// that callers never pass untrusted names.
Vector<SqlColumnInfo> FireBirdSession::EnumColumns(String table)
{
	Vector<SqlColumnInfo> vec;
	
	Sql sql(Format(
			"SELECT rdb$relation_fields.rdb$field_name, rdb$fields.rdb$field_type, rdb$fields.rdb$field_length, rdb$fields.rdb$field_sub_type, rdb$relation_fields.rdb$null_flag, rdb$fields.rdb$character_length, rdb$fields.rdb$field_precision, rdb$fields.rdb$field_scale "
			"FROM rdb$relation_fields INNER JOIN rdb$fields ON rdb$relation_fields.rdb$field_source=rdb$fields.rdb$field_name "
			"WHERE (rdb$relation_fields.rdb$system_flag IS NULL OR rdb$relation_fields.rdb$system_flag = 0) "
			"AND rdb$relation_fields.rdb$relation_name=\'%s\'", ToUpper(table)), *this);
	
	// Fetch only from a statement that actually executed.
	if(sql.Execute())
		while(sql.Fetch())
		{
			SqlColumnInfo &ci = vec.Add();
			ci.name = sql[0];
			// Width is the byte length, replaced by the character length when
			// one is available. The test must look at the same column (5) that
			// is assigned; testing the precision (6) nulled the width of
			// numeric columns.
			ci.width = sql[2];
			if(!IsNull(sql[5])) ci.width=sql[5];

			ci.type=SqlTypeToValueType(sql[1], sql[6]);
			
			ci.precision = sql[6];
			ci.scale = sql[7];
			// rdb$null_flag is set for NOT NULL columns, so a column is
			// nullable when the flag is absent or zero.
			Value null_flag = sql[4];
			ci.nullable = IsNull(null_flag) || int(null_flag) == 0;
		}
	return vec;
}

// Not implemented for this driver: always yields an empty list.
Vector<String> FireBirdSession::EnumPrimaryKey(String database, String table)
{
	return Vector<String>();
}
// Not implemented for this driver: always yields Null.
String FireBirdSession::EnumRowID(String database, String table)
{
	return Null;
}
// Not implemented for this driver: always yields an empty list.
Vector<String> FireBirdSession::EnumReservedWords()
{
	return Vector<String>();
}

// Creates a new connection object bound to this session and its database
// handle. The object is allocated with new and returned as a raw pointer.
SqlConnection * FireBirdSession::CreateConnection()
{
	SqlConnection *created = new FireBirdConnection(*this, conn);
	return created;
}

// Executes `st` directly on the session statement.
//
// If no transaction is running, one is started for this call and committed
// afterwards; a transaction that was already running is left open on success.
// On an SQL error the error is recorded and the transaction is rolled back.
void FireBirdSession::ExecTrans(const char * st)
{
	if(trace)
		*trace << st << "\n";

	try {
		const bool already_running = transaction->Started();
		if(!already_running)
			transaction->Start();

		statement->Execute(st);

		// commit only what this call opened
		if(!already_running)
			transaction->Commit();
		SetError(Null, Null);
	}
	catch(IBPP::SQLException& ex) {
		SetError(ex.ErrorMessage(), st, ex.SqlCode(), ex.Origin());

		if(trace)
			*trace << st << " failed: " << ex.ErrorMessage() << "\n";

		transaction->Rollback();
	}
}

// Opens the session from a ';'-separated connection string.
//
// Four fields are required. They are passed, in order, as the four arguments
// of IBPP::DatabaseFactory; fields 0, 2 and 3 are also passed to
// IBPP::ServiceFactory (presumably server;database;user;password — confirm
// against the IBPP factory signatures).
// Returns true on success. On failure the error is recorded with SetError and
// false is returned.
bool FireBirdSession::Open(const char *connect)
{
	Close();
	Vector<String> con_prop=Split(String(connect), ';');
	// Indexing below needs four fields; reject shorter strings instead of
	// reading past the end of the vector.
	if(con_prop.GetCount() < 4) {
		SetError("Invalid connection string: four ';'-separated fields expected", "Opening database");
		return false;
	}
	try {
		conn=IBPP::DatabaseFactory(con_prop[0], con_prop[1], con_prop[2], con_prop[3]);
		conn->Connect();
		transaction=IBPP::TransactionFactory(conn);
		statement=IBPP::StatementFactory(conn, transaction);
		service=IBPP::ServiceFactory(con_prop[0], con_prop[2], con_prop[3]);
		SetError(Null, Null);
		return true;
	}
	catch(IBPP::Exception& ex) {
		SetError(ex.ErrorMessage(), "Opening database");
		Close();
		return false;
	}
}

// Disconnects the database handle, if there is one. When the global SQL object
// is open on this very session it is cancelled and detached first (unless the
// build defines flagNOAPPSQL).
void FireBirdSession::Close()
{
	if(conn==NULL)
		return;
#ifndef flagNOAPPSQL
	const bool global_sql_is_ours = SQL.IsOpen() && &SQL.GetSession() == this;
	if(global_sql_is_ours) {
		SQL.Cancel();
		SQL.Detach();
	}
#endif
	conn->Disconnect();
}

// Starts the session transaction unless it is already running.
void FireBirdSession::Begin()
{
	if(transaction->Started())
		return;
	transaction->Start();
}

// Commits the session transaction; does nothing when none is running.
void FireBirdSession::Commit()
{
	if(!transaction->Started())
		return;
	transaction->Commit();
}

// Rolls back the session transaction; does nothing when none is running.
void FireBirdSession::Rollback()
{
	if(!transaction->Started())
		return;
	transaction->Rollback();
}

// Stores the value for parameter `i` (0-based) until the next Execute.
//
// The slot is assigned explicitly: the two-argument At(i, r) uses `r` only as
// the fill value when the vector has to grow, so an already existing slot
// (parameters set out of order, or left over from an Execute that failed
// before clearing them) would keep its old value.
void FireBirdConnection::SetParam(int i, const Value& r)
{
	param.At(i) = r;
}

// Binds the value `r` to parameter `i` of the prepared statement `st`.
// `i` is the 0-based index used by the caller; the statement is addressed with
// i + 1. Null values are bound as SQL NULL; value types without a case below
// are left unbound.
static void SetParameter(IBPP::Statement st, int i, Value& r) {
	const int pos = i + 1;
	if(IsNull(r)) {
		st->SetNull(pos);
		return;
	}

	switch(r.GetType()) {
		case BOOL_V:
			st->Set(pos, (bool)r);
			break;
		case INT_V:
			st->Set(pos, (int)r);
			break;
		case INT64_V:
			st->Set(pos, (int64)r);
			break;
		case DECIMAL_V: {
			// sent as a scaled integer: value * 10^PrintPrecision()
			decimal temp=r;
			int64 mult=1;
			for(int j=0;j<temp.PrintPrecision();j++) mult*=10;
			temp*=mult;
			st->Set(pos, (int64)temp);
			break;
		}
		case DOUBLE_V:
			st->Set(pos, (double)r);
			break;
		case STRING_V:
		case WSTRING_V: {
			String text=r;
			st->Set(pos, std::string(text.Begin(), text.End()));
			break;
		}
		case DATE_V: {
			Date date=r;
			IBPP::Date ib_date(date.year, date.month, date.day);
			st->Set(pos, ib_date);
			break;
		}
		case TIME_V: {
			Time time=r;
			if(time.month<=0) {
				// no date part: bind as a plain time of day
				IBPP::Time ib_time(time.hour, time.minute, time.second);
				st->Set(pos, ib_time);
			}
			else {
				IBPP::Timestamp ib_stamp(time.year, time.month, time.day, time.hour, time.minute, time.second);
				st->Set(pos, ib_stamp);
			}
			break;
		}
	}
}

bool FireBirdConnection::Execute()
{
	Cancel();
	session.SetError(Null, Null);
	if(statement.GetLength() == 0) {
		session.SetError("Empty statement", statement);
		return false;
	}

	Stream *trace = session.GetTrace();
	dword time;
	if(session.IsTraceTime())
		time = GetTickCount();

	try {
		bool tr_started=session.transaction->Started();
		if(!tr_started) session.transaction->Start();

		try {		
			session.statement->Prepare(statement);
			for(int i=0;i<param.GetCount();i++) 
				SetParameter(session.statement, i, param[i]);
			param.Clear();		
		}
		catch(IBPP::Exception& ex) {
			session.SetError(ex.ErrorMessage(), statement, 0, ex.Origin()); 
			if(trace)
				*trace << statement << " failed: " << ex.ErrorMessage() << "\n";					
			session.transaction->Rollback();
			return false;			
		}				
		
		session.statement->Execute();
		
		if(TrimLeft(ToUpper(statement)).StartsWith("INSERT")) {
			if(ToUpper(statement.Find("RETURNING")>=0)) {
				try {
					session.statement->Get(1, last_id);
				}
				catch(IBPP::Exception& ex) {
					last_id = Null; ex;
				}
			}
			else {
				String table_name = Split(statement, ' ')[2];
				last_id=Null;
				
				try {
					IBPP::Statement st=IBPP::StatementFactory(session.conn, session.transaction);
					st->Execute("SELECT rdb$generator_name FROM rdb$generators WHERE (rdb$system_flag IS NULL OR rdb$system_flag = 0) AND rdb$generator_name LIKE '%"+ToUpper(table_name)+"%'");
					if(st->Fetch()) {
						std::string temp;
						st->Get(1, temp);
						String generator_name(temp);
						generator_name=ToUpper(TrimRight(generator_name));
						st->Execute("SELECT gen_id(\""+generator_name+"\", 0) FROM rdb$database");
						if(st->Fetch()) st->Get(1, last_id);
					}
				}
				catch(IBPP::Exception& ex) {
					last_id = Null; ex;
				}
			}
		}
		
		if(!tr_started && !TrimLeft(ToUpper(statement)).StartsWith("SELECT")) 
			session.transaction->Commit();
	}	
	catch(IBPP::SQLException& ex) {
		session.SetError(ex.ErrorMessage(), statement, ex.SqlCode(), ex.Origin()); 
		if(trace)
			*trace << statement << " failed: " << ex.ErrorMessage() << "\n";		
		session.transaction->Rollback();
		return false;
	}

	if(trace) {
		if(session.IsTraceTime())
			*trace << Format("--------------\nexec %d ms:\n", msecs(time));
	}
	
	fetched_row=-1;
	rows=session.statement->AffectedRows();
	info.Clear();	
	
	return true;
}

// Affected-row count stored by the last Execute (reset to 0 by Cancel).
int FireBirdConnection::GetRowsProcessed() const
{
	const int processed = rows;
	return processed;
}

// Id captured by Execute after the last INSERT, converted to a Value.
Value FireBirdConnection::GetInsertedId() const
{
	Value id = last_id;
	return id;
}

// Maps an IBPP column type (and its scale) to the Value type id stored in the
// column info. Timestamps and blobs get the 0x80 bit set on top of TIME_V and
// STRING_V so that GetColumn can tell them apart from plain times and strings.
// Unknown types are reported as STRING_V.
int ColumnTypeToValueType(int type, int scale=0) {
	switch(type) {
		case IBPP::sdString:
			return STRING_V;
		case IBPP::sdSmallint:
		case IBPP::sdInteger:
			return INT_V;
		case IBPP::sdLargeint:
			// a scaled 64-bit integer carries a fixed-point decimal
			if(scale == 0)
				return INT64_V;
			return DECIMAL_V;
		case IBPP::sdFloat:
		case IBPP::sdDouble:
			return DOUBLE_V;
		case IBPP::sdDate:
			return DATE_V;
		case IBPP::sdTime:
			return TIME_V;
		case IBPP::sdTimestamp:
			return TIME_V | 0x80;
		case IBPP::sdBlob:
			return STRING_V | 0x80;
		default:
			return STRING_V;
	}
}

// Advances to the next row of the executed statement.
//
// On the first fetched row the column descriptions (info) are filled from the
// statement metadata. When no further row exists the session transaction is
// committed, the connection is cancelled and false is returned. An SQL error
// is recorded in the session and also yields false.
bool FireBirdConnection::Fetch()
{
	try {
		session.SetError(Null, Null);
		fetched_row++;

		if(!session.statement->Fetch()) {
			// end of the result set
			session.transaction->Commit();
			Cancel();
			return false;
		}

		if(fetched_row==0) {
			const int fields = session.statement->Columns();
			info.SetCount(fields);
			for(int col = 0; col < fields; col++) {
				const int pos = col + 1;	// the statement is addressed with index + 1
				SqlColumnInfo& f = info[col];
				f.name = ToUpper(session.statement->ColumnName(pos));
				f.width = session.statement->ColumnSize(pos);
				f.precision = (session.statement->ColumnType(pos)==IBPP::sdLargeint) ? 18 : 0;
				f.scale = abs(session.statement->ColumnScale(pos));
				f.nullable = true;
				f.type=ColumnTypeToValueType(session.statement->ColumnType(pos), f.scale);
			}
		}

		return true;
	}
	catch(IBPP::SQLException& ex) {
		session.SetError(ex.ErrorMessage(), "Error while fetching data", ex.SqlCode());
		return false;
	}
}

void FireBirdConnection::GetColumn(int i, Ref f) const
{
	i++;
	if(session.statement->IsNull(i))
	{
		f = Null;
		return;
	}
	
	switch(info[i-1].type)
	{
		case INT64_V: {
			int64 temp;
			session.statement->Get(i, temp);
			f.SetValue(temp);
			}
			break;
		case INT_V: {
			int temp;
			session.statement->Get(i ,temp);
			f.SetValue(temp);
			}
			break;
		case DECIMAL_V: {
			int64 temp;
			session.statement->Get(i, temp);
			String s=AsString(temp);
			s.Insert(s.GetCount()- info[i-1].scale, '.');
			if(s.StartsWith(".")) s="0"+s;
			decimal d(s, info[i-1].scale);
			f.SetValue(d);
			}
			break;
		case DOUBLE_V: {
			double temp;
			session.statement->Get(i, temp);
			f.SetValue(temp);
			}
			break;
		case DATE_V: {
			IBPP::Date temp;
			session.statement->Get(i, temp);
			Date date(temp.Year(), temp.Month(), temp.Day());
			f.SetValue(date);
			}
			break;
		case TIME_V: {
			IBPP::Time temp;
			session.statement->Get(i, temp);
			Time time;
			time.hour=temp.Hours();
			time.minute=temp.Minutes();
			time.second=temp.Seconds();
			f.SetValue(time);
			}
			break;
		case (TIME_V | 0x80): {
			IBPP::Timestamp temp;
			session.statement->Get(i, temp);
			Time time(temp.Year(), temp.Month(), temp.Day(), temp.Hours(), temp.Minutes(), temp.Seconds());
			f.SetValue(time);
			}
			break;
		case (STRING_V | 0x80): {
			IBPP::Blob temp=IBPP::BlobFactory(session.conn, session.transaction);
			session.statement->Get(i, temp);
			std::string data;
			temp->Load(data);
			f.SetValue(String(data));
			}
			break;
		default: {
			std::string temp;
			session.statement->Get(i, temp);
			String s(temp);
			f.SetValue(s);
		}
	}
}

// Resets the per-statement state (column info, row counters) and rolls back
// the session transaction if one is currently started.
void FireBirdConnection::Cancel()
{
	fetched_row = -1;
	rows = 0;
	info.Clear();

	if(session!=0 && session.transaction!=0 && session.transaction->Started())
		session.transaction->Rollback();
}

// The session this connection was created for.
SqlSession& FireBirdConnection::GetSession() const
{
	SqlSession& owner = session;
	return owner;
}

// User name as reported by the owning session.
String FireBirdConnection::GetUser() const
{
	String user = session.GetUser();
	return user;
}

// Textual form of the connection: the current statement text.
String FireBirdConnection::ToString() const
{
	String text = statement;
	return text;
}

// Binds the connection to its session and database handle.
// All counters start in the same state Cancel() leaves them in, so that
// GetRowsProcessed() and Fetch() never read an uninitialised value when they
// are called before the first Execute().
FireBirdConnection::FireBirdConnection(FireBirdSession& a_session, IBPP::Database a_conn)
  : session(a_session), conn(a_conn), rows(0), fetched_row(-1)
{
	last_id = Null;
}

/*Value PgSequence::Get()
{
#ifndef NOAPPSQL
	Sql sql(session ? *session : SQL.GetSession());
#else
	ASSERT(session);
	Sql sql(*session);
#endif
	if(!sql.Execute(Select(NextVal(seq)).Get()) || !sql.Fetch())
		return ErrorValue();
	return sql[0];
}*/

END_UPP_NAMESPACE

