import { Ref, computed, ref, watch } from '@vue/composition-api';
import _findKey from 'lodash/findKey';
import _find from 'lodash/find';
import { Tensor, InferenceSession } from 'onnxruntime-web';

import { Asset } from '@/types';
import { or } from '@/utils';
import { firstTruthy } from '@/utils';
import useApi from '@/api/useApi';
import { parse } from '@/components/ImageAnnotator/sam/npy';

/**
 * SAM image embeddings loaded for an asset, together with the name of the
 * model they were computed for.
 */
export interface SamEmbeddings {
  /** Name of the model the embeddings belong to, e.g. 'sam_vit_h'. */
  model: string;
  /** Embedding tensor built from the downloaded `.npy` payload. */
  tensor: Tensor;
}

/**
 * Find the key ("media slug", e.g. "camera") of the first entry in
 * `asset.media` whose MIME type is an image type.
 *
 * @param asset - Asset whose media record is searched.
 * @returns The media slug, or `null` when the asset has no image media.
 */
function parseImageMediaSlug(asset: Asset): string | null {
  // `_findKey` tolerates a missing collection and yields `undefined` when
  // nothing matches; entries without a `media_type` are simply skipped.
  const imageMediaSlug = _findKey(
    asset.media,
    val => val?.media_type?.startsWith('image/') ?? false
  );

  // Normalize lodash's `undefined` to the declared `null`.
  return imageMediaSlug ?? null;
}

/**
 * SAM model names in order of preference. Larger models come first, so they
 * are chosen whenever embeddings are available for them.
 */
const MODEL_PRIORITY: readonly string[] = [
  'sam_vit_h',
  'sam_vit_l',
  'sam_vit_b'
];

/**
 * Resolve which model (and therefore which embeddings) to use from those
 * available.
 *
 * @param availableModels - Model names for which embeddings exist.
 * @param priority - Candidate model names, most preferred first. Defaults to
 *   `MODEL_PRIORITY`.
 * @returns The first entry of `priority` that is present in
 *   `availableModels`, or `undefined` when none of them is available.
 */
function resolveModel(
  availableModels: readonly string[],
  priority: readonly string[] = MODEL_PRIORITY
): string | undefined {
  // Plain membership test: unlike checking the truthiness of a found value,
  // this cannot miss a (falsy) empty-string model name.
  return priority.find(model => availableModels.includes(model));
}

/**
 * Reactively load SAM embeddings for the image media of `asset`.
 *
 * Runs immediately and again whenever `asset` changes. Each run:
 * - lists the embeddings available for the asset's image media,
 * - picks the preferred SAM model among them (see `resolveModel`),
 * - downloads that model's `.npy` payload and wraps it in a float32 `Tensor`.
 *
 * Only the most recently started fetch may update the returned state. A slow
 * response for a previous asset therefore cannot overwrite the embeddings, the
 * error or the `loading` flag of the current one.
 *
 * @param asset - Asset to load embeddings for; `undefined` resets the state.
 * @returns `loading` while a fetch is in flight, the last `error` (or `null`),
 *   and the loaded `embeddings` (or `null` if none are available).
 */
export function useSamEmbeddings({ asset }: { asset: Ref<Asset | undefined> }) {
  const loading = ref(false);
  const error = ref<Error | null>(null);
  const embeddings = ref<SamEmbeddings | null>(null);

  const api = useApi();

  // Id of the most recently started fetch; any other in-flight fetch is stale.
  let latestFetchId = 0;

  async function fetch() {
    const fetchId = ++latestFetchId;
    const isStale = () => fetchId !== latestFetchId;

    console.log(`Initializing embeddings for asset`, asset.value);
    embeddings.value = null;
    error.value = null;

    const currentAsset = asset.value;
    if (!currentAsset) {
      loading.value = false;
      return;
    }

    // Media slug like "camera"
    const imageMediaSlug = parseImageMediaSlug(currentAsset);

    if (!imageMediaSlug) {
      console.log(`No image in asset, no embeddings`);
      loading.value = false;
      return;
    }

    loading.value = true;

    try {
      // Load list of available embeddings
      const availableEmbeddings = await api.assets.listEmbeddings({
        assetId: currentAsset.id,
        mediaSlug: imageMediaSlug
      });
      if (isStale()) return;

      const availableEmbeddingModels = availableEmbeddings.map(e => e.model);

      if (availableEmbeddingModels.length === 0) {
        console.log(`No embeddings found for media`);
        return;
      }

      console.log(
        `Found ${availableEmbeddingModels.length} models with embeddings`,
        availableEmbeddingModels
      );
      // Pick the first model of the prioritized list that has embeddings
      const model = resolveModel(availableEmbeddingModels);
      if (!model) {
        console.log(`No SAM embeddings found`);
        return;
      }
      // `model` was taken from this very list, so a match exists
      const embedding = _find(availableEmbeddings, e => e.model === model);

      if (embedding.format !== 'npy') {
        throw Error(
          `Unknown data format for embeddings: ${embedding.format}, expected: 'npy'`
        );
      }

      console.log(
        `Loading embeddings for model: ${embedding.model} from URL: ${embedding.url}`
      );
      const embeddingData = await api.getBinaryDataFromServer(embedding.url);
      if (isStale()) return;
      console.log(
        `Downloaded embeddings of size: ${embeddingData.byteLength} B`
      );
      const npyData = parse(embeddingData);
      console.log(
        `Parsed NumPy data of type: ${npyData.dtype} and shape ${npyData.shape}`
      );

      // NOTE(review): npyData.dtype is only logged, not checked; this assumes
      // the payload is float32 — confirm against the embedding producer.
      const tensor = new Tensor('float32', npyData.data, npyData.shape);
      embeddings.value = { model, tensor };
    } catch (err) {
      console.error(`Failed retrieving SAM embeddings`, err);
      if (!isStale()) {
        error.value = err instanceof Error ? err : new Error(String(err));
      }
    } finally {
      // A superseded fetch must not clear the flag of the fetch replacing it
      if (!isStale()) {
        loading.value = false;
      }
    }
  }

  watch(asset, fetch, { immediate: true });

  return { loading, error, embeddings };
}

/**
 * Create an ONNX Runtime inference session for a serialized model, running
 * on the WebAssembly ('wasm') execution provider.
 *
 * @param model - Raw bytes of the ONNX model.
 * @returns The newly created `InferenceSession`.
 */
export async function createSession(
  model: ArrayBuffer
): Promise<InferenceSession> {
  // TODO Initialize session only once
  console.log(
    `Initializing inference session for model of size: ${model.byteLength} B`
  );
  return await InferenceSession.create(model, {
    executionProviders: ['wasm']
  });
}

/**
 * Reactively download the SAM model named `modelName` while `load` is true.
 *
 * Runs immediately and whenever `load` or `modelName` changes. Each run first
 * drops the previously loaded model. Only the most recently started download
 * may update the returned state. As a result, an outdated download can neither
 * repopulate `model` nor clear `loading`.
 *
 * @param modelName - Name of the model to download; `undefined` unloads it.
 * @param load - Whether the model should be loaded at all.
 * @returns The downloaded `model` bytes (or `null`), a `loading` flag and the
 *   last `error` (or `null`).
 */
export function useSamModel({
  modelName,
  load
}: {
  modelName: Ref<string | undefined>;
  load: Ref<boolean>;
}) {
  const model = ref<ArrayBuffer | null>(null);
  const loading = ref(false);
  const error = ref<Error | null>(null);

  const api = useApi();

  // Id of the most recently started download; any other one is stale.
  let latestLoadId = 0;

  async function init() {
    const loadId = ++latestLoadId;
    const isStale = () => loadId !== latestLoadId;

    error.value = null;
    // Drop the previous model right away so it is not paired with embeddings
    // computed for a different model while the new one downloads.
    model.value = null;

    const name = modelName.value;
    if (!load.value || !name) {
      // A superseded download no longer resets the flag, so do it here.
      loading.value = false;
      return;
    }
    console.log(`Downloading model with name: ${name}...`);
    loading.value = true;
    try {
      const downloaded = await api.models.getModel(name);
      if (isStale()) return;
      model.value = downloaded;
      console.log(`Successfully loaded model`);
    } catch (err) {
      // Report the name that was actually requested, even if it changed since
      console.error(`Error loading SAM model: ${name}`, err);
      if (!isStale()) {
        error.value = err instanceof Error ? err : new Error(String(err));
      }
    } finally {
      if (!isStale()) {
        loading.value = false;
      }
    }
  }

  watch([load, modelName], init, { immediate: true });

  return { model, loading, error };
}

/**
 * Reactively create an ONNX Runtime `InferenceSession` from the SAM model
 * bytes in `model`.
 *
 * Runs immediately and whenever `model` changes. The current session is
 * cleared first. Only the most recently started initialization may update the
 * returned state, so a session for an outdated model is never published.
 *
 * @param model - Serialized model bytes; `null` clears the session.
 * @returns The `session` (or `null`), a `loading` flag and the last `error`
 *   (or `null`).
 */
export function useSamSession({ model }: { model: Ref<ArrayBuffer | null> }) {
  const session = ref<InferenceSession | null>(null);
  const loading = ref(false);
  const error = ref<Error | null>(null);

  // Id of the most recently started initialization; any other one is stale.
  let latestInitId = 0;

  async function init() {
    const initId = ++latestInitId;
    const isStale = () => initId !== latestInitId;

    error.value = null;
    session.value = null;

    const currentModel = model.value;
    if (!currentModel) {
      // A superseded initialization no longer resets the flag, so do it here.
      loading.value = false;
      return;
    }
    console.log(`Initializing InferenceSession for SAM...`);
    loading.value = true;
    try {
      const created = await createSession(currentModel);
      // NOTE(review): a stale session is dropped without being released;
      // consider `release()` if the onnxruntime-web version in use has it.
      if (isStale()) return;
      session.value = created;
      console.log(`Successfully initialized InferenceSession for SAM`);
    } catch (err) {
      console.error('Failed initializing SAM InferenceSession', err);
      if (!isStale()) {
        error.value = err instanceof Error ? err : new Error(String(err));
      }
    } finally {
      if (!isStale()) {
        loading.value = false;
      }
    }
  }
  watch(model, init, { immediate: true });
  return { session, loading, error };
}

/**
 * Compose the SAM pipeline for `asset`: embeddings → model → inference
 * session.
 *
 * The loaded embeddings determine which model name is requested. The model is
 * only downloaded while `load` is true, and a session is created from the
 * model bytes once they are available.
 *
 * @param asset - Asset whose image embeddings drive the pipeline.
 * @param load - Whether the model (and thus the session) should be loaded.
 * @returns The `embeddings` and `session` refs, plus two combined values:
 *   - `loading`, built with `or(...)` over the three stages' loading flags;
 *   - `error`, built with `firstTruthy(...)` over the three stages' errors.
 */
export function useSAM({
  asset,
  load
}: {
  asset: Ref<Asset>;
  load: Ref<boolean>;
}) {
  console.log(`Preparing SAM...`);

  const embeddingState = useSamEmbeddings({ asset });
  // The model to fetch follows whichever embeddings were found for the asset
  const modelState = useSamModel({
    modelName: computed(() => embeddingState.embeddings.value?.model),
    load
  });
  const sessionState = useSamSession({ model: modelState.model });

  return {
    embeddings: embeddingState.embeddings,
    session: sessionState.session,
    loading: or(
      embeddingState.loading,
      modelState.loading,
      sessionState.loading
    ),
    error: firstTruthy(
      embeddingState.error,
      modelState.error,
      sessionState.error
    )
  };
}
