"""Tests the collect_learnings function in the cli/collect module."""

import pytest

# NOTE(review): the test below is currently disabled (commented out). Several
# names it uses -- json, MagicMock, rudder_analytics, simple_gen,
# FileRepositories, OnDiskRepository, collect_learnings, extract_learning,
# steps_file_hash -- are not imported in this file. They must be imported
# before it can be re-enabled.
#
# def test_collect_learnings(monkeypatch):
#     monkeypatch.setattr(rudder_analytics, "track", MagicMock())
#
#     model = "test_model"
#     temperature = 0.5
#     steps = [simple_gen]
#     dbs = FileRepositories(
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#     )
#     dbs.input = {
#         "prompt": "test prompt\n with newlines",
#         "feedback": "test feedback",
#     }
#     code = "this is output\n\nit contains code"
#     dbs.logs = {steps[0].__name__: json.dumps([{"role": "system", "content": code}])}
#     dbs.memory = {"all_output.txt": "test workspace\n" + code}
#
#     collect_learnings(model, temperature, steps, dbs)
#
#     learnings = extract_learning(
#         model, temperature, steps, dbs, steps_file_hash=steps_file_hash()
#     )
#     assert rudder_analytics.track.call_count == 1
#     assert rudder_analytics.track.call_args[1]["event"] == "learning"
#     a = {
#         k: v
#         for k, v in rudder_analytics.track.call_args[1]["properties"].items()
#         if k != "timestamp"
#     }
#     b = {k: v for k, v in learnings.to_dict().items() if k != "timestamp"}
#     assert a == b
#
#     assert json.dumps(code) in learnings.logs
#     assert code in learnings.workspace


if __name__ == "__main__":
    # Pass this file's path explicitly. With no path argument, pytest collects
    # from the rootdir/cwd rather than just this module. Propagate pytest's
    # exit status so a failing run is visible to the calling shell.
    raise SystemExit(pytest.main(["-v", __file__]))