import type { QueryParams } from 'features/ModelQueryParams/types';
import type { Model } from 'shared/api/models/types';

import {
  getIsAudioClassification,
  getIsAutoSpeechRecognition,
  getIsDocumentQA,
  getIsImageGeneration,
  getIsOCR,
  getIsSentenceSimilarity,
  getIsSpeechToText,
  getIsTextToSpeech,
  getIsTextToVideo,
  getIsTimeSeriesForecasting,
  getIsUnconditionalImageGen,
} from 'features/ModelQueryParams/helpers/modelParamsChecks';

/**
 * Maps camelCase `queryParams` to the parameter object for `model`, renaming the
 * fields each model kind knows about to the key names used by that kind
 * (mostly snake_case).
 *
 * The model-kind checks run in a fixed order and the first one that returns true
 * wins, so the order of the `if` blocks below is significant.
 *
 * @param queryParams - Current query params, keyed in camelCase.
 * @param model - Selected model; when undefined, no params are produced.
 * @returns `{}` when `model` is undefined. Otherwise, for most model kinds, an
 *   object with the kind's known fields renamed and every other query param passed
 *   through unchanged. The sentence-similarity branch and the fallback branch are
 *   exceptions: they return only a fixed set of keys.
 */
export const getModelParams = (queryParams: QueryParams, model: Model | undefined) => {
  if (!model) return {};

  if (getIsTextToSpeech(queryParams, model.type)) {
    const { samplingRate, ...remaining } = queryParams;
    return {
      ...remaining,
      sampling_rate: samplingRate,
    };
  }

  if (getIsSpeechToText(queryParams, model.type)) {
    const { profanityFilter, ...remaining } = queryParams;
    return {
      ...remaining,
      profanity_filter: profanityFilter,
    };
  }

  if (getIsAudioClassification(queryParams, model.type)) {
    const { returnTimeStamp, samplingRate, topK, ...remaining } = queryParams;
    return { return_timestamps: returnTimeStamp, sampling_rate: samplingRate, top_k: topK, ...remaining };
  }

  if (getIsDocumentQA(queryParams, model.type)) {
    const { returnConfidenceScores, topK, ...remaining } = queryParams;
    return { return_confidence_scores: returnConfidenceScores, top_k: topK, ...remaining };
  }

  if (getIsOCR(queryParams, model.type)) {
    const { returnBoundingBoxes, returnConfidenceScores, topK, ...remaining } = queryParams;
    return {
      return_bounding_boxes: returnBoundingBoxes,
      return_confidence_scores: returnConfidenceScores,
      top_k: topK,
      ...remaining,
    };
  }

  if (getIsUnconditionalImageGen(queryParams, model.type)) {
    const { guidanceScale, imageSize, numInferenceSteps, ...remaining } = queryParams;
    return {
      guidance_scale: guidanceScale,
      image_size: imageSize,
      num_inference_steps: numInferenceSteps,
      ...remaining,
    };
  }

  if (getIsAutoSpeechRecognition(queryParams, model.type)) {
    const { returnTimeStamp, samplingRate, ...remaining } = queryParams;
    return {
      return_timestamps: returnTimeStamp,
      sampling_rate: samplingRate,
      ...remaining,
    };
  }

  if (getIsSentenceSimilarity(queryParams, model.type)) {
    // Only these three keys are forwarded; every other query param is dropped here.
    const { poolingMethod, returnEmbeddings } = queryParams;
    return {
      pooling_method: poolingMethod,
      private: queryParams.private,
      return_embeddings: returnEmbeddings,
    };
  }

  if (getIsImageGeneration(queryParams, model.type)) {
    const { aspectRatio, cfgScale, clipSkip, numInferenceSteps, ...remaining } = queryParams;
    return {
      // Kebab-case, unlike the other keys.
      // NOTE(review): presumably the name the backend expects; confirm.
      'aspect-ratio': aspectRatio,
      cfg_scale: cfgScale,
      clip_skip: clipSkip,
      num_inference_steps: numInferenceSteps,
      ...remaining,
    };
  }

  if (getIsTextToVideo(queryParams, model.type)) {
    const { aspectRatio, cfgScale, numInferenceSteps, ...remaining } = queryParams;
    return {
      // Kebab-case, matching the image-generation branch above.
      'aspect-ratio': aspectRatio,
      cfg_scale: cfgScale,
      num_inference_steps: numInferenceSteps,
      ...remaining,
    };
  }

  if (getIsTimeSeriesForecasting(queryParams, model.type)) {
    const {
      changePoints,
      dateColumn,
      evaluationMetric,
      forecastHorizon,
      learningRate,
      predictionColumn,
      ...remaining
    } = queryParams;
    return {
      change_points: changePoints,
      date_column: dateColumn,
      evaluation_metric: evaluationMetric,
      forecast_horizon: forecastHorizon,
      learning_rate: learningRate,
      prediction_column: predictionColumn,
      ...remaining,
    };
  }

  // Fallback: only the four sampling params are sent.
  // `??` (not `||`) so an explicit 0, e.g. temperature 0, is kept instead of being
  // replaced by the default; only null/undefined fall back.
  return {
    frequency_penalty: queryParams.frequencyPenalty ?? 0,
    presence_penalty: queryParams.presencePenalty ?? 0,
    temperature: queryParams.temperature ?? 1,
    top_p: queryParams.topP ?? 1,
  };
};
