##
#  Copyright : Copyright (c) MOSEK ApS, Denmark. All rights reserved.
#
#  File :      test_async.py
#
#  Purpose :   Demonstrates how to submit an optimization problem
#              to the MOSEK OptServer and solve it in asynchronous mode.
##

import requests, sys, argparse, time, json

if __name__ == '__main__':
    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", help="URL of the remote server", type=str, required=True)
    parser.add_argument("--infile", help="Input file to optimize", type=str, required=True)
    parser.add_argument("--intype", help="Content type of the input, for example application/x-mosek-mps", type=str, required=True)
    parser.add_argument("--outtype", help="Content type of the response, for example application/x-mosek-json", type=str, required=True)
    parser.add_argument("--num-tries", help="How many times to poll for solution", type=int, default=5)
    args = parser.parse_args()
    URL = args.url
    infile = args.infile
    intype = args.intype
    outtype = args.outtype
    maxPolls = args.num_tries
    verify = False # Whether to verify SSL certificates

    # A single session is reused for every request so the underlying
    # connection can be kept alive across submit, solve, polling and log calls.
    with requests.Session() as s:
        # POST problem data. The file is only needed for the upload, so it is
        # closed before any further requests are made.
        with open(infile, 'rb') as probdata:
            submit = s.post(URL + "/api/v1/submit",
                            data = probdata,
                            headers = { "Content-Type" : intype },
                            verify = verify )
        if submit.status_code != requests.codes.ok:
            print(f"Error submitting job, status = {submit.status_code}")
            sys.exit(-1)

        # The token returned by the submit call identifies this job in all
        # subsequent requests.
        token = submit.headers['X-Mosek-Job-Token']
        print("Submit: success")

        # Request the server to solve the problem in the background
        solve = s.get(URL + "/api/v1/solve-background",
                      headers = { "X-Mosek-Job-Token" : token },
                      verify = verify )
        if solve.status_code not in (requests.codes.ok, requests.codes.no_content):
            print(f"Error initiating solve, status = {solve.status_code}")
            sys.exit(-1)

        # Begin waiting for the solution
        solved = False
        breakSent = False
        pollCount = 0
        logOffset = 0

        while not solved:
            pollCount += 1
            sol = s.get(URL + "/api/v1/solution",
                        headers = { "X-Mosek-Job-Token" : token ,
                                    "Accept" : outtype },
                        verify = verify )

            if sol.status_code == requests.codes.no_content:
                # Solution not yet available
                print(f"Solution not available in poll {pollCount}, continuing")
                time.sleep(1.0)
            elif sol.status_code == requests.codes.ok:
                # Solution is available
                solved = True
                if outtype in ["application/json", "application/x-mosek-jtask"]:
                    solution = json.loads(sol.text)
                else:
                    solution = sol.text
                res = sol.headers["X-Mosek-Res-Code"]
                trm = sol.headers["X-Mosek-Trm-Code"]
            else:
                # An unexpected status will not fix itself by polling again;
                # looping here would spin without any delay, so give up.
                print(f"Error querying for solution, status = {sol.status_code}")
                sys.exit(-1)

            # After too many attempts we ask the solver to stop. The request is
            # sent only once; afterwards we keep polling until the server
            # reports the (interrupted) solution.
            if not solved and not breakSent and pollCount >= maxPolls:
                s.get(URL + "/api/v1/break",
                      headers = { "X-Mosek-Job-Token" : token },
                      verify = verify )
                breakSent = True

            # Get the log from the last call until now. Only a successful
            # response is log text; an error body must neither be printed as
            # log nor advance the offset used for the next request.
            log = s.get(URL + "/api/v1/log" + f"?offset={logOffset}",
                        headers = { "X-Mosek-Job-Token" : token },
                        verify = verify )
            if log.status_code == requests.codes.ok:
                print(log.text)
                # NOTE(review): assumes the server's offset counts the same
                # units as len(str); for ASCII logs characters == bytes.
                logOffset += len(log.text)

    # The loop only terminates once a solution was received (errors exit above).
    print(f"Solution: {solution}")
    print(f"Response code:    {res}")
    print(f"Termination code: {trm}")
