#include "Image.h"

#include "_bmp.h"

// Reads a BMP bitmap header (BMP_INFOHEADER plus the colour table or the
// BI_BITFIELDS channel masks) from the current position of `stream`.
//
// stream - open stream in loading mode, positioned at the start of a .bmp file,
//          or (when `icon` is true) at an icon image resource, which has no
//          BMP_FILEHEADER.
// header - output; zeroed, then filled from the stream. When the file leaves
//          biSizeImage at 0 (allowed for BI_RGB only) it is computed here.
//          For <= 8 bpp, header.palette receives the colour table. For 16/32 bpp
//          BI_BITFIELDS, palette[0..2] receive the red, green and blue masks.
// icon   - true to skip reading and checking the BMP_FILEHEADER ('BM' signature).
//
// Returns false if the stream is closed, a read fails, the signature or bit depth
// is unsupported, or the header is inconsistent. On success the stream is left
// just past the colour table / masks. bfOffBits is not consulted, so callers read
// the pixel rows from that position.
bool LoadBMPHeader(Stream& stream, BMPHeader& header, bool icon)
{
	if(!stream.IsOpen())
		return false;
	ASSERT(stream.IsLoading());
	if(!icon)
	{
		BMP_FILEHEADER bmfh;
		if(!stream.GetAll(&bmfh, sizeof(bmfh)))
			return false;
		bmfh.EndianSwap(); // on-disk fields are little-endian (the "NUXI problem" on big-endian hosts)
		if(bmfh.bfType != 'B' + 256 * 'M')
			return false;
	}
	Zero(header);
	if(!stream.GetAll(&header, sizeof(BMP_INFOHEADER)))
		return false;
	// Swap BEFORE validating biSize. On a big-endian host the raw value is
	// byte-reversed, and checking it first would let a too-small header through.
	header.EndianSwap();
	if(header.biSize < sizeof(BMP_INFOHEADER))
		return false;
	if(header.biBitCount != 1 && header.biBitCount != 4 && header.biBitCount != 8
	&& header.biBitCount != 16 && header.biBitCount != 24 && header.biBitCount != 32)
		return false;
	if(header.biSizeImage == 0) {
		// Only BI_RGB bitmaps may omit the image size. Derive it from the 4-byte aligned
		// row stride. biHeight is negative for top-down bitmaps, so use its magnitude.
		if(header.biCompression != 0 /* BI_RGB */)
			return false;
		header.biSizeImage = tabs(header.biHeight) * (((header.biWidth * header.biBitCount + 31) >> 3) & -4);
	}
	bool deep = header.biBitCount == 16 || header.biBitCount == 32;
	bool bitfields = deep && header.biCompression == 3 /* BI_BITFIELDS */;
	if(deep && header.biCompression != 0 /* BI_RGB */ && !bitfields)
		return false;
	// Bytes of header-extension fields (V4/V5 headers etc.) beyond BMP_INFOHEADER.
	dword extra = (dword)(header.biSize - sizeof(BMP_INFOHEADER));
	if(bitfields && header.biSize >= 108 /* BITMAPV4HEADER and later */) {
		// V4/V5 headers store the red/green/blue masks directly after the
		// BMP_INFOHEADER fields. No separate mask block follows the header.
		if(!stream.GetAll(header.palette, 12))
			return false;
		stream.SeekCur(extra - 12);
	}
	else {
		stream.SeekCur(extra);
		// Plain BMP_INFOHEADER + BI_BITFIELDS: three DWORD masks follow the header.
		if(bitfields && !stream.GetAll(header.palette, 12))
			return false;
	}
	if(header.biBitCount <= 8) {
		// The colour table holds biClrUsed entries, or 2^bpp when biClrUsed is 0.
		// biClrUsed comes straight from the file, so refuse tables that would
		// overflow header.palette instead of writing past its end.
		dword colors = header.biClrUsed ? header.biClrUsed : (dword)1 << header.biBitCount;
		if(colors > sizeof(header.palette) / sizeof(RGBQUAD))
			return false;
		if(!stream.GetAll(header.palette, (int)(colors * sizeof(RGBQUAD))))
			return false;
	}
	// Bitmaps deeper than 8 bpp may still carry an optional colour table; skip it.
	if(header.biBitCount >= 16 && header.biClrUsed != 0)
		stream.SeekCur(header.biClrUsed * sizeof(BMP_RGB));
	return true;
}

// Decodes uncompressed BMP images, or the images of an .ico/.cur file, from `stream`.
//
// stream     - open, seekable stream in loading mode, positioned at the start of a
//              .bmp file or, when `icon` is true, of an icon/cursor file.
// page_index - images to decode. For icons each value selects an ICONDIRENTRY
//              (clamped to the valid range). For plain BMP files every value decodes
//              the single bitmap again.
// icon       - true to parse the ICONDIR / ICONDIRENTRY directory first.
//
// Returns the decoded images in page_index order. Decoding stops at the first
// page that fails (bad directory entry, rejected header, unsupported compression,
// invalid dimensions, or truncated pixel data). Images decoded before that point
// are still returned.
//
// Only raw pixel data is decoded: BI_RGB at any depth, BI_BITFIELDS at 16/32 bpp.
// For icons only the first half of the stored bitmap height is read.
// NOTE(review): that second half is presumably the AND transparency mask, which is not applied.
Vector<Image> LoadBMP(Stream& stream, const Vector<int>& page_index, bool icon)
{
	Vector<Image> out;
	if(!stream.IsOpen() || page_index.IsEmpty())
		return out;
	ASSERT(stream.IsLoading());
	int count = 1;
	if(icon) {
		ICONDIR id;
		if(!stream.GetAll(&id, sizeof(id)))
			return out;
		if(id.idReserved != 0 || (id.idType != 1 && id.idType != 2) || id.idCount == 0)
			return out;
		count = id.idCount;
	}
	// For icons: position of the first ICONDIRENTRY. Otherwise: start of the BMP file.
	int dirpos = (int)stream.GetPos();
	for(int pi = 0; pi < page_index.GetCount(); pi++) {
		int pg = page_index[pi];
		if(icon) {
			stream.Seek(dirpos + minmax(pg, 0, count - 1) * sizeof(ICONDIRENTRY));
			ICONDIRENTRY ide;
			if(!stream.GetAll(&ide, sizeof(ide)))
				return out;
			ide.EndianSwap();
			// Reject entries too small to hold a bitmap header, or lying outside the stream.
			if(ide.dwBytesInRes < sizeof(BMP_INFOHEADER) || ide.dwBytesInRes > stream.GetSize()
			|| ide.dwImageOffset > (unsigned)(stream.GetSize() - ide.dwBytesInRes))
				return out;
			// dwImageOffset is relative to where the ICONDIR started.
			stream.Seek(ide.dwImageOffset + dirpos - sizeof(ICONDIR));
		}
		else
			stream.Seek(dirpos);
		BMPHeader header;
		if(!LoadBMPHeader(stream, header, icon))
			return out;
		// The row reader below understands raw pixels only. RLE or embedded
		// JPEG/PNG payloads would otherwise be decoded as garbage.
		bool bitfields = header.biCompression == 3 /* BI_BITFIELDS */;
		if(header.biCompression != 0 /* BI_RGB */
		&& !(bitfields && (header.biBitCount == 16 || header.biBitCount == 32)))
			return out;
		RasterFormat fmt;
		switch(header.biBitCount) {
		case 1: fmt.Set1mf(); break;
		case 4: fmt.Set4mf(); break;
		case 8: fmt.Set8(); break;
		case 16:
			if(bitfields) // channel masks were stored in palette[0..2] by LoadBMPHeader
				fmt.Set16le(*(dword *)(header.palette + 0), *(dword *)(header.palette + 1), *(dword *)(header.palette + 2));
			else
				fmt.Set16le(31 << 10, 31 << 5, 31); // default 5-5-5 layout
			break;
		case 24:
			fmt.Set24le(0xff0000, 0x00ff00, 0x0000ff);
			break;
		case 32:
			if(bitfields)
				fmt.Set32le(*(dword *)(header.palette + 0), *(dword *)(header.palette + 1), *(dword *)(header.palette + 2));
			else
				fmt.Set32le(0xff0000, 0x00ff00, 0x0000ff);
			break;
		}
		Vector<RGBA> palette;
		if(header.biBitCount <= 8) {
			palette.SetCount(1 << header.biBitCount);
			const BMP_RGB *q = header.palette;
			for(int c = 0; c < palette.GetCount(); c++, q++) {
				palette[c].r = q->rgbRed;
				palette[c].g = q->rgbGreen;
				palette[c].b = q->rgbBlue;
				palette[c].a = 255;
			}
		}

		Size size(header.biWidth, tabs(header.biHeight));
		if(icon)
			size.cy >>= 1; // icon resources store twice the image height; decode only the first half
		if(size.cx < 0 || size.cy < 0)
			return out;

		// Rows are padded to 4 bytes. Confirm the stream really holds every row before
		// allocating the image, so a corrupt header cannot request a gigantic buffer
		// (or overflow the int row size).
		int64 row_bytes64 = (((int64)size.cx * header.biBitCount + 31) >> 5) << 2;
		if(row_bytes64 > 0x7fffffff || row_bytes64 * size.cy > stream.GetSize() - stream.GetPos())
			return out;

		ImageBuffer ib(size);
		int row_bytes = (fmt.GetByteCount(size.cx) + 3) & ~3;
		Buffer<byte> scanline(row_bytes);
		for(int y = 0; y < size.cy; y++) {
			if(!stream.GetAll(scanline, row_bytes))
				return out;
			// Positive biHeight means bottom-up storage: the first row read is the last image line.
			fmt.Read(ib[header.biHeight < 0 ? y : size.cy - y - 1], scanline, size.cx, palette);
		}
		out.Add(ib);
	}
	return out;
}

// Decodes the bitmap stored in a .bmp stream.
// Returns an empty Image when nothing could be decoded.
Image LoadBMP(Stream& stream)
{
	Vector<int> pages;
	pages.Add(0);
	Vector<Image> decoded = LoadBMP(stream, pages, false);
	if(decoded.IsEmpty())
		return Image();
	return decoded[0];
}

// Opens `filename` and decodes it as a .bmp file.
// Returns an empty Image if the file cannot be opened or decoded.
Image LoadBMP(const char *filename)
{
	FileIn bmp_file(filename);
	Image result = LoadBMP(bmp_file);
	return result;
}

// Decodes the first image listed in an icon/cursor stream's directory.
// Returns an empty Image when nothing could be decoded.
Image LoadICO(Stream& stream)
{
	Vector<int> pages;
	pages.Add(0);
	Vector<Image> decoded = LoadBMP(stream, pages, true);
	if(decoded.IsEmpty())
		return Image();
	return decoded[0];
}

// Opens `filename` and decodes the first image of that icon/cursor file.
// Returns an empty Image if the file cannot be opened or decoded.
Image LoadICO(const char *filename)
{
	FileIn ico_file(filename);
	Image result = LoadICO(ico_file);
	return result;
}
