diff --git "a/examples/Zamba_2_1_2B.ipynb" "b/examples/Zamba_2_1_2B.ipynb" new file mode 100644--- /dev/null +++ "b/examples/Zamba_2_1_2B.ipynb" @@ -0,0 +1,2010 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "f2dbda2ff7ba4b019af0f17975444078": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_31a19787f40749c3bc4f3fdaf14135b1", + "IPY_MODEL_00aee0329e044d31bd4c4ebe24097f2c", + "IPY_MODEL_a2dc535a5a7e45629d16d0ffdcf485bb" + ], + "layout": "IPY_MODEL_a52adc24c6244cc9b7b9d1fd1c255515" + } + }, + "31a19787f40749c3bc4f3fdaf14135b1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0ee238c7f9bf439e9102126c9e9c5564", + "placeholder": "​", + "style": "IPY_MODEL_b8427e036de44f1fa305b74437578d70", + "value": "config.json: 100%" + } + }, + "00aee0329e044d31bd4c4ebe24097f2c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b560b0df040349fea0a1b05b4834c719", + "max": 1337, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c6a427f52ce94de99ed357d54f2b1725", + "value": 1337 + } + }, + "a2dc535a5a7e45629d16d0ffdcf485bb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c3ec07880d864473af81da2726d42dc0", + "placeholder": "​", + "style": "IPY_MODEL_e3f0244893824653bd080fd2d5b50c61", + "value": " 1.34k/1.34k [00:00<00:00, 19.8kB/s]" + } + }, + "a52adc24c6244cc9b7b9d1fd1c255515": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0ee238c7f9bf439e9102126c9e9c5564": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b8427e036de44f1fa305b74437578d70": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b560b0df040349fea0a1b05b4834c719": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c6a427f52ce94de99ed357d54f2b1725": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c3ec07880d864473af81da2726d42dc0": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e3f0244893824653bd080fd2d5b50c61": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a08ec1110d7c4fe3acd58b5d7cf3051d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7d36442fdd35419f91246fe08cede692", + "IPY_MODEL_867f946fe3a0459f965114d9585d4d76", + "IPY_MODEL_1e10ded196cd4fa58ccf4e755b476ad1" + ], + "layout": "IPY_MODEL_245953fc4e04473986803901006aa94e" + } + }, + "7d36442fdd35419f91246fe08cede692": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_79326f7e145f47c79d95ea2ea5d49441", + "placeholder": "​", + "style": "IPY_MODEL_4f81a97adef7417a8235efe631f65aff", + "value": "config.json: 100%" + } + }, + "867f946fe3a0459f965114d9585d4d76": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c64dbd4e5eaf4b8b96d9a7707606b249", + "max": 1337, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a2b9f172590c4d988f66cef60270d9ee", + "value": 1337 + } + }, + "1e10ded196cd4fa58ccf4e755b476ad1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_86ae8996a0be4ff0b4d5df3ed9b0629e", + "placeholder": "​", + "style": "IPY_MODEL_51be35e151df4b37bd42d9075667cd93", + "value": " 1.34k/1.34k [00:00<00:00, 14.2kB/s]" + } + }, + "245953fc4e04473986803901006aa94e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "79326f7e145f47c79d95ea2ea5d49441": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4f81a97adef7417a8235efe631f65aff": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c64dbd4e5eaf4b8b96d9a7707606b249": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a2b9f172590c4d988f66cef60270d9ee": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "86ae8996a0be4ff0b4d5df3ed9b0629e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "51be35e151df4b37bd42d9075667cd93": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3e5f50646e04474f92718dc4882cbe4b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_f69224acb6874b9abce791367de857a2", + "IPY_MODEL_55e667af53844d1f9f4f55a339bbba87", + "IPY_MODEL_0afa674b2cae440b8aa229b557c8dc8f" + ], + "layout": "IPY_MODEL_0c4d444eb00545ceb6cdccfa77171202" + } + }, + "f69224acb6874b9abce791367de857a2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c5c27aa06cb745dba82328b6b422588e", + "placeholder": "​", + "style": "IPY_MODEL_7347744554da4cd6bdf9b3397c3da1d6", + "value": "model.safetensors: 100%" + } + }, + "55e667af53844d1f9f4f55a339bbba87": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_178e5bbfd3bf4ad081654f25110d253f", + "max": 2430175992, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_25c014e006d5444d93d8990d4646e1b4", + "value": 2430175992 + } + }, + "0afa674b2cae440b8aa229b557c8dc8f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_93cce87d3dfc46fc9ce2e84b4e485bbf", + "placeholder": "​", + "style": "IPY_MODEL_b5835f0bb0764998ac3b587335bc6638", + "value": " 2.43G/2.43G [00:20<00:00, 86.1MB/s]" + } + }, + "0c4d444eb00545ceb6cdccfa77171202": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c5c27aa06cb745dba82328b6b422588e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7347744554da4cd6bdf9b3397c3da1d6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "178e5bbfd3bf4ad081654f25110d253f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "25c014e006d5444d93d8990d4646e1b4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "93cce87d3dfc46fc9ce2e84b4e485bbf": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b5835f0bb0764998ac3b587335bc6638": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3cfa20cc4dd24b499eca1c36a7f65320": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_03f47e6071d945afa9f422cfb21a8e30", + "IPY_MODEL_5dc0b8840cf84516a9f7e3a5deaead80", + "IPY_MODEL_a736cede537d4124b5da9e72fb4d1ba1" + ], + "layout": "IPY_MODEL_5e321e040972405eb9312a5f43be4228" + } + }, + "03f47e6071d945afa9f422cfb21a8e30": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a189c93099a24d3796944497c3ae45b4", + "placeholder": "​", + "style": "IPY_MODEL_be36fa86b54c4003ba948f303e97b9ff", + "value": "generation_config.json: 100%" + } + }, + "5dc0b8840cf84516a9f7e3a5deaead80": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_752ffddbba474ad58534d5471f24085b", + "max": 137, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b1cb53eb6c5942d0bae3171e4eea22d8", + "value": 137 + } + }, + "a736cede537d4124b5da9e72fb4d1ba1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_85234acea70f4cf28c41e07bbb7eb8c2", + "placeholder": "​", + "style": "IPY_MODEL_3bddfd87e42440778a098dfdfaefa263", + "value": " 137/137 [00:00<00:00, 6.42kB/s]" + } + }, + "5e321e040972405eb9312a5f43be4228": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a189c93099a24d3796944497c3ae45b4": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be36fa86b54c4003ba948f303e97b9ff": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "752ffddbba474ad58534d5471f24085b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b1cb53eb6c5942d0bae3171e4eea22d8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "85234acea70f4cf28c41e07bbb7eb8c2": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3bddfd87e42440778a098dfdfaefa263": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "

Zamba 2-1.2B

" + ], + "metadata": { + "id": "h9DjgpWw2C27" + } + }, + { + "cell_type": "markdown", + "source": [ + "This is a mamba based model with some transformer blocks in it. It is on its own fork of the transformers library so there are extra steps in installing its dependencies." + ], + "metadata": { + "id": "-bIoeS_I2MvR" + } + }, + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/Zyphra/transformers_zamba2.git" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TIdVR89J2fA4", + "outputId": "19ac8288-1fd8-4e53-b137-29aa8ecb2cad" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'transformers_zamba2'...\n", + "remote: Enumerating objects: 182990, done.\u001b[K\n", + "remote: Counting objects: 100% (200/200), done.\u001b[K\n", + "remote: Compressing objects: 100% (105/105), done.\u001b[K\n", + "remote: Total 182990 (delta 109), reused 152 (delta 78), pack-reused 182790 (from 1)\u001b[K\n", + "Receiving objects: 100% (182990/182990), 206.54 MiB | 9.13 MiB/s, done.\n", + "Resolving deltas: 100% (130848/130848), done.\n", + "Updating files: 100% (4310/4310), done.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install -e transformers_zamba2" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "y22A557n2pVF", + "outputId": "039b0d88-927e-472b-94ce-27337858ee06" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Obtaining file:///content/transformers_zamba2\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n", + " Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (3.16.0)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (0.24.6)\n", + "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (2024.5.15)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (2.32.3)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (0.19.1)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (0.4.5)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (4.66.5)\n", + "Collecting mamba-ssm==2.1.0 (from transformers==4.43.0.dev0)\n", + " Downloading mamba_ssm-2.1.0.tar.gz (84 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.0/84.0 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting ninja (from transformers==4.43.0.dev0)\n", + " Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl.metadata (5.3 kB)\n", + "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from transformers==4.43.0.dev0) (0.8.0)\n", + "Collecting triton (from transformers==4.43.0.dev0)\n", + " Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\n", + "Collecting causal-conv1d==1.3.0.post1 (from transformers==4.43.0.dev0)\n", + " Downloading causal_conv1d-1.3.0.post1.tar.gz (8.1 kB)\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (2.4.0+cu121)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers==4.43.0.dev0) (2024.6.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers==4.43.0.dev0) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.43.0.dev0) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.43.0.dev0) (3.8)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.43.0.dev0) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.43.0.dev0) (2024.8.30)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (1.13.2)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (3.1.4)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (2.1.5)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->causal-conv1d==1.3.0.post1->transformers==4.43.0.dev0) (1.3.0)\n", + "Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m307.2/307.2 kB\u001b[0m \u001b[31m18.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.4/209.4 MB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hBuilding wheels for collected packages: transformers, causal-conv1d, mamba-ssm\n", + " Building editable for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for transformers: filename=transformers-4.43.0.dev0-0.editable-py3-none-any.whl size=17274 sha256=2671982986cfe13f2155d632e8fa0ea7dfea3fabcb3ca6f1f13b2dd2276e53fb\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-b5wf_y35/wheels/bf/e2/60/f15fb160f61df44c33784d7a3d842b1a6d21217a8d1d7572c6\n", + " Building wheel for causal-conv1d (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for causal-conv1d: filename=causal_conv1d-1.3.0.post1-cp310-cp310-linux_x86_64.whl size=103872557 sha256=bd8a9034d54da9480c8bc41af83660796430673d5ea41a00389646105cfcb324\n", + " Stored in directory: /root/.cache/pip/wheels/6a/25/f8/ffbc841b608eb39aaa91ff5d1a5b830b7780855dd65796b774\n", + " Building wheel for mamba-ssm (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for mamba-ssm: filename=mamba_ssm-2.1.0-cp310-cp310-linux_x86_64.whl size=323986419 sha256=4e6940c8e19fdc6d9a17171546a89ace24984f65917e51494bc6f1ed820097b4\n", + " Stored in directory: /root/.cache/pip/wheels/62/1a/a0/88447e865ca478b954b6560317096bd11c79fb03f6312a64bc\n", + "Successfully built transformers causal-conv1d mamba-ssm\n", + "Installing collected packages: ninja, triton, causal-conv1d, mamba-ssm, transformers\n", + " Attempting uninstall: transformers\n", + " Found existing installation: transformers 4.44.2\n", + " Uninstalling transformers-4.44.2:\n", + " Successfully uninstalled transformers-4.44.2\n", + "Successfully installed causal-conv1d-1.3.0.post1 mamba-ssm-2.1.0 ninja-1.11.1.1 transformers-4.43.0.dev0 triton-3.0.0\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install accelerate" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zike0qpHVtVm", + "outputId": "f7b08f0a-d79e-4175-f131-7c108e9d7575" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.34.2)\n", + "Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (24.1)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.2)\n", + "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.4.0+cu121)\n", + "Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.24.6)\n", + "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.5)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.0)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2024.6.1)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.66.5)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.13.2)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.4)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.8)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.8.30)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### NOTE\n", + "Restart session after installing the custom transformers fork for these changes to take effect." + ], + "metadata": { + "id": "OXJUYhrC42fd" + } + }, + { + "cell_type": "code", + "source": [ + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "import torch" + ], + "metadata": { + "id": "9gkimtr040n6" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "USE_MAMBA_KERNELS = True if torch.cuda.is_available() else False" + ], + "metadata": { + "id": "0B3xoCWY4-Qo" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 501, + "referenced_widgets": [ + "f2dbda2ff7ba4b019af0f17975444078", + "31a19787f40749c3bc4f3fdaf14135b1", + "00aee0329e044d31bd4c4ebe24097f2c", + "a2dc535a5a7e45629d16d0ffdcf485bb", + "a52adc24c6244cc9b7b9d1fd1c255515", + "0ee238c7f9bf439e9102126c9e9c5564", + "b8427e036de44f1fa305b74437578d70", + "b560b0df040349fea0a1b05b4834c719", + "c6a427f52ce94de99ed357d54f2b1725", + "c3ec07880d864473af81da2726d42dc0", + "e3f0244893824653bd080fd2d5b50c61", + "a08ec1110d7c4fe3acd58b5d7cf3051d", + "7d36442fdd35419f91246fe08cede692", + "867f946fe3a0459f965114d9585d4d76", + "1e10ded196cd4fa58ccf4e755b476ad1", + "245953fc4e04473986803901006aa94e", + "79326f7e145f47c79d95ea2ea5d49441", + "4f81a97adef7417a8235efe631f65aff", + "c64dbd4e5eaf4b8b96d9a7707606b249", + "a2b9f172590c4d988f66cef60270d9ee", + "86ae8996a0be4ff0b4d5df3ed9b0629e", + "51be35e151df4b37bd42d9075667cd93", + "3e5f50646e04474f92718dc4882cbe4b", + "f69224acb6874b9abce791367de857a2", + "55e667af53844d1f9f4f55a339bbba87", + "0afa674b2cae440b8aa229b557c8dc8f", + "0c4d444eb00545ceb6cdccfa77171202", + "c5c27aa06cb745dba82328b6b422588e", + "7347744554da4cd6bdf9b3397c3da1d6", + "178e5bbfd3bf4ad081654f25110d253f", + "25c014e006d5444d93d8990d4646e1b4", + "93cce87d3dfc46fc9ce2e84b4e485bbf", + "b5835f0bb0764998ac3b587335bc6638", + "3cfa20cc4dd24b499eca1c36a7f65320", + "03f47e6071d945afa9f422cfb21a8e30", + "5dc0b8840cf84516a9f7e3a5deaead80", + "a736cede537d4124b5da9e72fb4d1ba1", + "5e321e040972405eb9312a5f43be4228", + "a189c93099a24d3796944497c3ae45b4", + "be36fa86b54c4003ba948f303e97b9ff", + "752ffddbba474ad58534d5471f24085b", + "b1cb53eb6c5942d0bae3171e4eea22d8", + "85234acea70f4cf28c41e07bbb7eb8c2", + "3bddfd87e42440778a098dfdfaefa263" + ] + }, + "id": "kuKpJo2gVUg1", + "outputId": "f5c1963f-d60a-4a48-b6ea-d8090814c033" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "config.json: 0%| | 0.00/1.34k [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0minput_ids\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_text\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_tensors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"pt\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDEVICE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_new_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mctx_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/content/transformers_zamba2/src/transformers/generation/utils.py\u001b[0m in \u001b[0;36mgenerate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1967\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1968\u001b[0m \u001b[0;31m# 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1969\u001b[0;31m result = self._sample(\n\u001b[0m\u001b[1;32m 1970\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1971\u001b[0m \u001b[0mlogits_processor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprepared_logits_processor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/content/transformers_zamba2/src/transformers/generation/utils.py\u001b[0m in \u001b[0;36m_sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)\u001b[0m\n\u001b[1;32m 2911\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2912\u001b[0m \u001b[0;31m# forward pass to get next token\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2913\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2914\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2915\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msynced_gpus\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mthis_peer_finished\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1560\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1561\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/content/transformers_zamba2/src/transformers/models/zamba2/modeling_zamba2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)\u001b[0m\n\u001b[1;32m 1433\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1434\u001b[0m \u001b[0;31m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1435\u001b[0;31m outputs = self.model(\n\u001b[0m\u001b[1;32m 1436\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1437\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1560\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1561\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/content/transformers_zamba2/src/transformers/models/zamba2/modeling_zamba2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 1277\u001b[0m )\n\u001b[1;32m 1278\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1279\u001b[0;31m layer_outputs = next(mamba_layers)(\n\u001b[0m\u001b[1;32m 1280\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1281\u001b[0m \u001b[0mtransformer_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransformer_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1560\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1561\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/content/transformers_zamba2/src/transformers/models/zamba2/modeling_zamba2.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, transformer_hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)\u001b[0m\n\u001b[1;32m 982\u001b[0m )\n\u001b[1;32m 983\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput_layernorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 984\u001b[0;31m hidden_states = self.mamba(\n\u001b[0m\u001b[1;32m 985\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 986\u001b[0m \u001b[0minference_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpast_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1553\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1560\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1561\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1563\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/content/transformers_zamba2/src/transformers/models/zamba2/mamba2_layer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, u, from_shared_proj, seqlen, seq_idx, inference_params, attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 228\u001b[0m ) # (B, L, self.d_ssm + 2 * ngroups * d_state)\n\u001b[1;32m 229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m xBC = causal_conv1d_fn(\n\u001b[0m\u001b[1;32m 231\u001b[0m \u001b[0mxBC\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"d 1 w -> d w\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/causal_conv1d/causal_conv1d_interface.py\u001b[0m in \u001b[0;36mcausal_conv1d_fn\u001b[0;34m(x, weight, bias, seq_idx, initial_states, return_final_states, final_states_out, activation)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseqlen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \"\"\"\n\u001b[0;32m--> 121\u001b[0;31m return CausalConv1dFn.apply(\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0;31m# See NOTE: [functorch vjp and autograd interaction]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_functorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munwrap_dead_wrappers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 574\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 575\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_setup_ctx_defined\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/causal_conv1d/causal_conv1d_interface.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(ctx, x, weight, bias, seq_idx, initial_states, return_final_states, final_states_out, activation)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mfinal_states_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mactivation\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"silu\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"swish\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m out = causal_conv1d_cuda.causal_conv1d_fwd(\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseq_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitial_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfinal_states_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m )\n", + "\u001b[0;31mRuntimeError\u001b[0m: Expected x.is_cuda() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)" + ] + } + ], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(\"Zyphra/Zamba2-1.2B\")\n", + "model = AutoModelForCausalLM.from_pretrained(\"Zyphra/Zamba2-1.2B\",\n", + " device_map=DEVICE,\n", + " torch_dtype=torch.bfloat16,\n", + " use_mamba_kernels=USE_MAMBA_KERNELS,\n", + " # force_download=True,\n", + " )\n", + "\n", + "input_text = \"A funny prompt would be \"\n", + "input_ids = tokenizer(input_text, return_tensors=\"pt\").to(DEVICE)\n", + "\n", + "outputs = model.generate(**input_ids, max_new_tokens=100)\n", + "print(tokenizer.decode(outputs[0]))" + ] + }, + { + "cell_type": "code", + "source": [ + "help(tokenizer(input_text, return_tensors=\"pt\").to)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Jbft3pGh8c6X", + "outputId": "e6877b68-a529-4eea-eadd-572035761607" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Help on method to in module transformers.tokenization_utils_base:\n", + "\n", + "to(device: Union[str, ForwardRef('torch.device')]) -> 'BatchEncoding' method of transformers.tokenization_utils_base.BatchEncoding instance\n", + " Send all values to device by calling `v.to(device)` (PyTorch only).\n", + " \n", + " Args:\n", + " device (`str` or `torch.device`): The device to put the tensors on.\n", + " \n", + " Returns:\n", + " [`BatchEncoding`]: The same instance after modification.\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "help(AutoModelForCausalLM.from_pretrained)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "r5f0R-jI4A1I", + "outputId": "8aadb38d-92d6-49dc-d944-0730b562db78" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Help on method from_pretrained in module transformers.models.auto.auto_factory:\n", + "\n", + "from_pretrained(*model_args, **kwargs) method of builtins.type instance\n", + " Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.\n", + " \n", + " The model class to instantiate is selected based on the `model_type` property of the config object (either\n", + " passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by\n", + " falling back to using pattern matching on `pretrained_model_name_or_path`:\n", + " \n", + " - **bart** -- [`BartForCausalLM`] (BART model)\n", + " - **bert** -- [`BertLMHeadModel`] (BERT model)\n", + " - **bert-generation** -- [`BertGenerationDecoder`] (Bert Generation model)\n", + " - **big_bird** -- [`BigBirdForCausalLM`] (BigBird model)\n", + " - **bigbird_pegasus** -- [`BigBirdPegasusForCausalLM`] (BigBird-Pegasus model)\n", + " - **biogpt** -- [`BioGptForCausalLM`] (BioGpt model)\n", + " - **blenderbot** -- [`BlenderbotForCausalLM`] (Blenderbot model)\n", + " - **blenderbot-small** -- [`BlenderbotSmallForCausalLM`] (BlenderbotSmall model)\n", + " - **bloom** -- [`BloomForCausalLM`] (BLOOM model)\n", + " - **camembert** -- [`CamembertForCausalLM`] (CamemBERT model)\n", + " - **code_llama** -- [`LlamaForCausalLM`] (CodeLlama model)\n", + " - **codegen** -- [`CodeGenForCausalLM`] (CodeGen model)\n", + " - **cohere** -- [`CohereForCausalLM`] (Cohere model)\n", + " - **cpmant** -- [`CpmAntForCausalLM`] (CPM-Ant model)\n", + " - **ctrl** -- [`CTRLLMHeadModel`] (CTRL model)\n", + " - **data2vec-text** -- [`Data2VecTextForCausalLM`] (Data2VecText model)\n", + " - **dbrx** -- [`DbrxForCausalLM`] (DBRX model)\n", + " - **electra** -- [`ElectraForCausalLM`] (ELECTRA model)\n", + " - **ernie** -- [`ErnieForCausalLM`] (ERNIE model)\n", + " - **falcon** -- [`FalconForCausalLM`] (Falcon model)\n", + " - **fuyu** -- [`FuyuForCausalLM`] (Fuyu model)\n", + " - **gemma** -- [`GemmaForCausalLM`] (Gemma model)\n", + " - **gemma2** -- [`Gemma2ForCausalLM`] (Gemma2 model)\n", + " - **git** -- [`GitForCausalLM`] (GIT model)\n", + " - **gpt-sw3** -- [`GPT2LMHeadModel`] (GPT-Sw3 model)\n", + " - **gpt2** -- [`GPT2LMHeadModel`] (OpenAI GPT-2 model)\n", + " - **gpt_bigcode** -- [`GPTBigCodeForCausalLM`] (GPTBigCode model)\n", + " - **gpt_neo** -- [`GPTNeoForCausalLM`] (GPT Neo model)\n", + " - **gpt_neox** -- [`GPTNeoXForCausalLM`] (GPT NeoX model)\n", + " - **gpt_neox_japanese** -- [`GPTNeoXJapaneseForCausalLM`] (GPT NeoX Japanese model)\n", + " - **gptj** -- [`GPTJForCausalLM`] (GPT-J model)\n", + " - **jamba** -- [`JambaForCausalLM`] (Jamba model)\n", + " - **jetmoe** -- [`JetMoeForCausalLM`] (JetMoe model)\n", + " - **llama** -- [`LlamaForCausalLM`] (LLaMA model)\n", + " - **mamba** -- [`MambaForCausalLM`] (Mamba model)\n", + " - **marian** -- [`MarianForCausalLM`] (Marian model)\n", + " - **mbart** -- [`MBartForCausalLM`] (mBART model)\n", + " - **mega** -- [`MegaForCausalLM`] (MEGA model)\n", + " - **megatron-bert** -- [`MegatronBertForCausalLM`] (Megatron-BERT model)\n", + " - **mistral** -- [`MistralForCausalLM`] (Mistral model)\n", + " - **mixtral** -- [`MixtralForCausalLM`] (Mixtral model)\n", + " - **mpt** -- [`MptForCausalLM`] (MPT model)\n", + " - **musicgen** -- [`MusicgenForCausalLM`] (MusicGen model)\n", + " - **musicgen_melody** -- [`MusicgenMelodyForCausalLM`] (MusicGen Melody model)\n", + " - **mvp** -- [`MvpForCausalLM`] (MVP model)\n", + " - **olmo** -- [`OlmoForCausalLM`] (OLMo model)\n", + " - **open-llama** -- [`OpenLlamaForCausalLM`] (OpenLlama model)\n", + " - **openai-gpt** -- [`OpenAIGPTLMHeadModel`] (OpenAI GPT model)\n", + " - **opt** -- [`OPTForCausalLM`] (OPT model)\n", + " - **pegasus** -- [`PegasusForCausalLM`] (Pegasus model)\n", + " - **persimmon** -- [`PersimmonForCausalLM`] (Persimmon model)\n", + " - **phi** -- [`PhiForCausalLM`] (Phi model)\n", + " - **phi3** -- [`Phi3ForCausalLM`] (Phi3 model)\n", + " - **plbart** -- [`PLBartForCausalLM`] (PLBart model)\n", + " - **prophetnet** -- [`ProphetNetForCausalLM`] (ProphetNet model)\n", + " - **qdqbert** -- [`QDQBertLMHeadModel`] (QDQBert model)\n", + " - **qwen2** -- [`Qwen2ForCausalLM`] (Qwen2 model)\n", + " - **qwen2_moe** -- [`Qwen2MoeForCausalLM`] (Qwen2MoE model)\n", + " - **recurrent_gemma** -- [`RecurrentGemmaForCausalLM`] (RecurrentGemma model)\n", + " - **reformer** -- [`ReformerModelWithLMHead`] (Reformer model)\n", + " - **rembert** -- [`RemBertForCausalLM`] (RemBERT model)\n", + " - **roberta** -- [`RobertaForCausalLM`] (RoBERTa model)\n", + " - **roberta-prelayernorm** -- [`RobertaPreLayerNormForCausalLM`] (RoBERTa-PreLayerNorm model)\n", + " - **roc_bert** -- [`RoCBertForCausalLM`] (RoCBert model)\n", + " - **roformer** -- [`RoFormerForCausalLM`] (RoFormer model)\n", + " - **rwkv** -- [`RwkvForCausalLM`] (RWKV model)\n", + " - **speech_to_text_2** -- [`Speech2Text2ForCausalLM`] (Speech2Text2 model)\n", + " - **stablelm** -- [`StableLmForCausalLM`] (StableLm model)\n", + " - **starcoder2** -- [`Starcoder2ForCausalLM`] (Starcoder2 model)\n", + " - **transfo-xl** -- [`TransfoXLLMHeadModel`] (Transformer-XL model)\n", + " - **trocr** -- [`TrOCRForCausalLM`] (TrOCR model)\n", + " - **whisper** -- [`WhisperForCausalLM`] (Whisper model)\n", + " - **xglm** -- [`XGLMForCausalLM`] (XGLM model)\n", + " - **xlm** -- [`XLMWithLMHeadModel`] (XLM model)\n", + " - **xlm-prophetnet** -- [`XLMProphetNetForCausalLM`] (XLM-ProphetNet model)\n", + " - **xlm-roberta** -- [`XLMRobertaForCausalLM`] (XLM-RoBERTa model)\n", + " - **xlm-roberta-xl** -- [`XLMRobertaXLForCausalLM`] (XLM-RoBERTa-XL model)\n", + " - **xlnet** -- [`XLNetLMHeadModel`] (XLNet model)\n", + " - **xmod** -- [`XmodForCausalLM`] (X-MOD model)\n", + " - **zamba2** -- [`Zamba2ForCausalLM`] (Zamba2 model)\n", + " \n", + " The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are\n", + " deactivated). To train the model, you should first set it back in training mode with `model.train()`\n", + " \n", + " Args:\n", + " pretrained_model_name_or_path (`str` or `os.PathLike`):\n", + " Can be either:\n", + " \n", + " - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.\n", + " - A path to a *directory* containing model weights saved using\n", + " [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.\n", + " - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In\n", + " this case, `from_tf` should be set to `True` and a configuration object should be provided as\n", + " `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a\n", + " PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.\n", + " model_args (additional positional arguments, *optional*):\n", + " Will be passed along to the underlying model `__init__()` method.\n", + " config ([`PretrainedConfig`], *optional*):\n", + " Configuration for the model to use instead of an automatically loaded configuration. Configuration can\n", + " be automatically loaded when:\n", + " \n", + " - The model is a model provided by the library (loaded with the *model id* string of a pretrained\n", + " model).\n", + " - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the\n", + " save directory.\n", + " - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a\n", + " configuration JSON file named *config.json* is found in the directory.\n", + " state_dict (*Dict[str, torch.Tensor]*, *optional*):\n", + " A state dictionary to use instead of a state dictionary loaded from saved weights file.\n", + " \n", + " This option can be used if you want to create a model from a pretrained configuration but load your own\n", + " weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and\n", + " [`~PreTrainedModel.from_pretrained`] is not a simpler option.\n", + " cache_dir (`str` or `os.PathLike`, *optional*):\n", + " Path to a directory in which a downloaded pretrained model configuration should be cached if the\n", + " standard cache should not be used.\n", + " from_tf (`bool`, *optional*, defaults to `False`):\n", + " Load the model weights from a TensorFlow checkpoint save file (see docstring of\n", + " `pretrained_model_name_or_path` argument).\n", + " force_download (`bool`, *optional*, defaults to `False`):\n", + " Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n", + " cached versions if they exist.\n", + " resume_download:\n", + " Deprecated and ignored. All downloads are now resumed by default when possible.\n", + " Will be removed in v5 of Transformers.\n", + " proxies (`Dict[str, str]`, *optional*):\n", + " A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',\n", + " 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n", + " output_loading_info(`bool`, *optional*, defaults to `False`):\n", + " Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.\n", + " local_files_only(`bool`, *optional*, defaults to `False`):\n", + " Whether or not to only look at local files (e.g., not try downloading the model).\n", + " revision (`str`, *optional*, defaults to `\"main\"`):\n", + " The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a\n", + " git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any\n", + " identifier allowed by git.\n", + " trust_remote_code (`bool`, *optional*, defaults to `False`):\n", + " Whether or not to allow for custom models defined on the Hub in their own modeling files. This option\n", + " should only be set to `True` for repositories you trust and in which you have read the code, as it will\n", + " execute code present on the Hub on your local machine.\n", + " code_revision (`str`, *optional*, defaults to `\"main\"`):\n", + " The specific revision to use for the code on the Hub, if the code leaves in a different repository than\n", + " the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based\n", + " system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier\n", + " allowed by git.\n", + " kwargs (additional keyword arguments, *optional*):\n", + " Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,\n", + " `output_attentions=True`). Behaves differently depending on whether a `config` is provided or\n", + " automatically loaded:\n", + " \n", + " - If a configuration is provided with `config`, `**kwargs` will be directly passed to the\n", + " underlying model's `__init__` method (we assume all relevant updates to the configuration have\n", + " already been done)\n", + " - If a configuration is not provided, `kwargs` will be first passed to the configuration class\n", + " initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that\n", + " corresponds to a configuration attribute will be used to override said attribute with the\n", + " supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute\n", + " will be passed to the underlying model's `__init__` function.\n", + " \n", + " Examples:\n", + " \n", + " ```python\n", + " >>> from transformers import AutoConfig, AutoModelForCausalLM\n", + " \n", + " >>> # Download model and configuration from huggingface.co and cache.\n", + " >>> model = AutoModelForCausalLM.from_pretrained(\"google-bert/bert-base-cased\")\n", + " \n", + " >>> # Update configuration during loading\n", + " >>> model = AutoModelForCausalLM.from_pretrained(\"google-bert/bert-base-cased\", output_attentions=True)\n", + " >>> model.config.output_attentions\n", + " True\n", + " \n", + " >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)\n", + " >>> config = AutoConfig.from_pretrained(\"./tf_model/bert_tf_model_config.json\")\n", + " >>> model = AutoModelForCausalLM.from_pretrained(\n", + " ... \"./tf_model/bert_tf_checkpoint.ckpt.index\", from_tf=True, config=config\n", + " ... )\n", + " ```\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "K62B1nCQ7ul3" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file